TATS: A Long Video Generation Framework with Time-Agnostic VQGAN and Time-Sensitive Transformer

Overview

Long Video Generation with Time-Agnostic VQGAN and Time-Sensitive Transformer

Project Website | Video | Paper

tl;dr We propose TATS, a long video generation framework that is trained on videos with only tens of frames yet can generate videos with thousands of frames using a sliding-window approach.

Setup

  conda create -n tats python=3.8
  conda activate tats
  conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  pip install pytorch-lightning==1.5.4
  pip install einops ftfy h5py imageio imageio-ffmpeg regex scikit-video tqdm
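
The pinned pytorch-lightning==1.5.4 matters for the commands below. As a quick, optional sanity check that the environment is complete (a minimal snippet, not part of the repository):

  # Optional sanity check: confirm versions and CUDA visibility.
  import torch, torchvision, pytorch_lightning as pl
  print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
  print("torchvision", torchvision.__version__)
  print("pytorch-lightning", pl.__version__)  # should match the pinned 1.5.4 above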

Datasets and trained models

UCF-101: official data, VQGAN, TATS-base
Sky-Timelapse: official data, VQGAN, TATS-base
Taichi-HD: official data, VQGAN, TATS-base

Usage

Synthesis

To sample videos of the same length as the training data, use the code under scripts/ with the following flags:

  • gpt_ckpt: path to the trained transformer checkpoint.
  • vqgan_ckpt: path to the trained VQGAN checkpoint.
  • save: path to save the generation results.
  • save_videos: indicate that videos will be saved.
  • class_cond: indicate that class labels are used as conditional information.

To compute the FVD, these flags are required:

  • compute_fvd: indicate that FVD will be calculated.
  • data_path: path to the dataset folder.
  • dataset: dataset name.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. set it to 4 when training on the Taichi-HD dataset.
python sample_vqgan_transformer_short_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} --class_cond \
    --save {SAVEPATH} --data_path {DATAPATH} --batch_size 16 \
    --top_k 2048 --top_p 0.8 --dataset {DATANAME} --compute_fvd --save_videos
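
The --top_k 2048 and --top_p 0.8 flags truncate the transformer's next-token distribution before sampling. As a rough, repository-agnostic illustration of what such filtering does (a standard top-k/nucleus sketch, not the script's actual code; the function name is an assumption):

  import torch
  import torch.nn.functional as F

  def filter_logits(logits, top_k=2048, top_p=0.8):
      """Keep the top-k logits, then the smallest set whose probability mass
      exceeds top_p; everything else is masked out before sampling."""
      logits = logits.clone()
      kth = torch.topk(logits, top_k).values[..., -1, None]
      logits[logits < kth] = float("-inf")                      # top-k cutoff
      sorted_logits, sorted_idx = torch.sort(logits, descending=True)
      cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
      remove = cum_probs > top_p
      remove[..., 1:] = remove[..., :-1].clone()                # keep the token that crosses the threshold
      remove[..., 0] = False                                    # never drop the most likely token
      logits.scatter_(-1, sorted_idx, sorted_logits.masked_fill(remove, float("-inf")))
      return logits

  # probs = F.softmax(filter_logits(next_token_logits), dim=-1)
  # next_token = torch.multinomial(probs, num_samples=1)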

To sample videos longer than the training length with a sliding window, use the following script and flags:

  • sample_length: number of latent frames to be generated.
  • temporal_sample_pos: position of the frame that the sliding window approach generates.
python sample_vqgan_transformer_long_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
    --dataset ucf101 --class_cond --sample_length 16 --temporal_sample_pos 1 --batch_size 5 --n_sample 5 --save_videos
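
Because the VQGAN is time-agnostic, the transformer can keep appending latent frames beyond the training length by always conditioning on the most recent window of tokens. A minimal sketch of that sliding-window idea, with hypothetical names (transformer, sos_token) standing in for the repository's actual classes:

  import torch

  def sample_long_video(transformer, sos_token=0, window=1024,
                        n_latent_frames=64, tokens_per_frame=256):
      """Illustrative sliding-window sampling; `transformer(seq)` is assumed
      to return next-token logits of shape (batch, seq_len, vocab)."""
      tokens = torch.full((1, 1), sos_token, dtype=torch.long)   # conditioning / start token
      for _ in range(n_latent_frames * tokens_per_frame):
          context = tokens[:, -window:]                          # attend only to the latest `window` tokens
          logits = transformer(context)[:, -1]                   # hypothetical forward call
          next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
          tokens = torch.cat([tokens, next_token], dim=1)
      return tokens[:, 1:]                                       # decode with the VQGAN to obtain pixel frames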

Training

Example usages for training the VQGAN and the transformers are shown below. Explanations of the flags that you may want to change for different settings:

  • data_path: path to the dataset folder.
  • default_root_dir: path to save the checkpoints and tensorboard logs.
  • vqvae: path to the trained VQGAN checkpoint.
  • resolution: resolution of the training video clips.
  • sequence_length: number of frames in the training video clips.
  • discriminator_iter_start: the training step at which the GAN losses are enabled.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • unconditional: use this flag when no conditional information is available, e.g. Sky Time-lapse.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. set it to 4 when training on the Taichi-HD dataset.
  • downsample: downsampling rates along the time, height, and width dimensions (see the quick check after this list).
  • no_random_restart: disable the random re-initialization of unused codebook entries.
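
As a quick consistency check of how these flags fit together (assuming the latent grid is simply the clip size divided by the downsample rates, which is how the numbers in the commands below add up):

  # Rough consistency check (assumption: latent size = clip size / downsample rate per dimension).
  sequence_length, resolution = 16, 128          # clips used for the TATS-base transformer below
  dt, dh, dw = 4, 8, 8                           # --downsample 4 8 8
  latent_t = sequence_length // dt               # 4 latent frames
  latent_h, latent_w = resolution // dh, resolution // dw       # 16 x 16 tokens per latent frame
  print(latent_t * latent_h * latent_w)          # 1024, matching --block_size 1024 for TATS-base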

VQGAN

python train_vqgan.py --embedding_dim 256 --n_codes 16384 --n_hiddens 16 --downsample 4 8 8 --no_random_restart \
                      --gpus 8 --sync_batchnorm --batch_size 2 --num_workers 6 --accumulate_grad_batches 6 \
                      --progress_bar_refresh_rate 500 --max_steps 2000000 --gradient_clip_val 1.0 --lr 3e-5 \
                      --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                      --resolution 64 --sequence_length 16 --discriminator_iter_start 10000 --norm_type batch \
                      --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1  --gan_feat_weight 4
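
Here --n_codes 16384 and --embedding_dim 256 define the discrete codebook. For readers new to VQ models, a generic sketch of the quantization step (nearest-codebook lookup with a straight-through gradient, standard VQ-VAE practice rather than the repository's exact implementation):

  import torch
  import torch.nn as nn

  class VectorQuantizer(nn.Module):
      """Generic VQ layer: snap encoder features to their nearest codebook entry."""
      def __init__(self, n_codes=16384, embedding_dim=256):
          super().__init__()
          self.codebook = nn.Embedding(n_codes, embedding_dim)

      def forward(self, z):                                      # z: (batch, positions, embedding_dim)
          w = self.codebook.weight
          d = (z.pow(2).sum(-1, keepdim=True)                    # squared distance to every code
               - 2 * z @ w.t() + w.pow(2).sum(-1))
          idx = d.argmin(dim=-1)                                 # discrete tokens the transformer will model
          z_q = self.codebook(idx)
          z_q = z + (z_q - z).detach()                           # straight-through gradient estimator
          return z_q, idx                                        # a codebook/commitment loss is added in practice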

TATS-base Transformer

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 16 --max_steps 2000000
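
The transformer is a GPT-style model over the flattened VQGAN indices: with the flags above (and assuming the latent grid is the clip size divided by the VQGAN downsample rates), a 16x128x128 clip becomes a 4x16x16 latent grid, i.e. a 1024-token sequence matching --block_size 1024. A schematic of the next-token objective with placeholder tensors (no repository classes are used):

  import torch
  import torch.nn.functional as F

  # Placeholder shapes for a TATS-base-like setup.
  batch, vocab_size = 3, 16384
  latent_t, latent_h, latent_w = 4, 16, 16                       # 16x128x128 clips with --downsample 4 8 8

  indices = torch.randint(vocab_size, (batch, latent_t, latent_h, latent_w))
  seq = indices.flatten(1)                                       # (batch, 1024) token sequence
  # A GPT forward pass on seq[:, :-1] would produce these logits; faked here to show the loss only.
  logits = torch.randn(batch, seq.size(1) - 1, vocab_size)
  loss = F.cross_entropy(logits.reshape(-1, vocab_size), seq[:, 1:].reshape(-1))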

TATS-hierarchical Transformer

TATS-hierarchical consists of two transformers that are trained separately: the first command below trains the interpolation transformer, which fills in the latent frames between given ones, and the second trains the autoregressive transformer on sparsely sampled latent frames (hence the sample_every_n_latent_frames 4 flag).

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1280 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 20 --spatial_length 128 --n_unmasked 256 --max_steps 2000000

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 4 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 64 --sample_every_n_latent_frames 4 --spatial_length 128 --max_steps 2000000

Acknowledgments

Our code is partially built upon VQGAN and VideoGPT.

Citation

@article{ge2022long,
         title={Long Video Generation with Time-Agnostic VQGAN and Time-Sensitive Transformer},
         author={Ge, Songwei and Hayes, Thomas and Yang, Harry and Yin, Xi and Pang, Guan and Jacobs, David and Huang, Jia-Bin and Parikh, Devi},
         journal={arXiv preprint arXiv:2204.03638},
         year={2022}
}

License

TATS is licensed under the MIT license, as found in the LICENSE file.

Comments
  • Great Work!!!!!

    A few queries:

    (a) Can you please provide the evaluation code for reproducing Tables 1(a), 1(b), 1(c), and 1(d)? (b) Can you please let me know the total computation hours needed to train the full model?

    opened by VIROBO-15 30
  • Pre-trained model for UCF-101 without class conditioning & Question about CCVS model for Fig.5(a)

    May I get the pretrained transformer model for the UCF-101 dataset without class conditioning? I found that the checkpoint you attached for the UCF-101 requires class labels.

    opened by Ugness 12
  • The dropout of Transformer

    Dear authors:

    I found that the dropout rates (embd_pdrop, resid_pdrop, attn_pdrop) are set to 0 during GPT training. To verify this, I downloaded the TATS-base checkpoints for UCF101 and Sky-Timelapse from the homepage, and embd_pdrop, resid_pdrop, and attn_pdrop were all set to 0 there as well.

    A value of 0 means dropout has no effect, so I want to check: is this correct, or am I missing something?

    Kang

    opened by kangzhao2 7
  • Training on single GPU

    Hi!

    Thanks for the great work. I have been trying to train on a single GPU but it keeps throwing this error:

    "Default process group has not been initialized, " RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

    Is it possible to configure the model to train on a single GPU?

    full error message:

    Traceback (most recent call last):
      File "/content/TATS/scripts/train_vqgan.py", line 70, in <module>: main()
      File "/content/TATS/scripts/train_vqgan.py", line 66, in main: trainer.fit(model, data)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 738, in fit: self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt: return trainer_fn(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl: self._run(model, ckpt_path=ckpt_path)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run: self._dispatch()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch: self.training_type_plugin.start_training(self)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training: self._results = trainer.run_stage()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage: return self._run_train()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train: self.fit_loop.run()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run: self.advance(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance: self.epoch_loop.run(data_fetcher)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run: self.advance(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance: batch_output = self.batch_loop.run(batch, batch_idx)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run: self.advance(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance: outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run: self.advance(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 219, in advance: self.optimizer_idx,
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 259, in _run_optimization: closure()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__: self._result = self.closure(*args, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure: step_output = self._step_fn()
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step: training_step_output = self.trainer.accelerator.training_step(step_kwargs)
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step: return self.training_type_plugin.training_step(*step_kwargs.values())
      File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 213, in training_step: return self.model.training_step(*args, **kwargs)
      File "/content/TATS/scripts/tats/tats_vqgan.py", line 182, in training_step: recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(x, optimizer_idx)
      File "/content/TATS/scripts/tats/tats_vqgan.py", line 118, in forward: logits_image_fake, pred_image_fake = self.image_discriminator(frames_recon)
      File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl: return forward_call(*input, **kwargs)
      File "/content/TATS/scripts/tats/tats_vqgan.py", line 463, in forward: res.append(model(res[-1]))
      File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl: return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward: input = module(input)
      File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl: return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward: world_size = torch.distributed.get_world_size(process_group)
      File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 867, in get_world_size: return _get_group_size(group)
      File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 325, in _get_group_size: default_pg = _get_default_group()
      File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 430, in _get_default_group: "Default process group has not been initialized, "
    RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

    opened by ndahlqvist 4
  • Cannot open saved model parameters for MUGEN

    Hi @SongweiGe,

    Thanks for sharing this work. While I was looking into the code for text-video generation, the links you provided for downloading the model parameters (VQGAN, TATS-base) were not working. Could you share accessible links? Thanks in advance.

    opened by Crane-YU 2
  • Checkpoints for TATS-Hierarchical

    Hi, thanks for the great work. Do you have any plan to share the checkpoints and the inference code for TATS-Hierarchical?

    If it's hard to share, may I get some metrics or values to monitor about TATS-Hierarchical so that I can reproduce them?

    opened by Ugness 2
  • Checkpoints of the interpolation Transformer and how to use it for inference?

    Hi, thanks for the great work and the released codes! I'm very interested in it. Could you release the checkpoints of the interpolation Transformer and provide instructions on how to use them for inference? (which seems not provided in the current repository)

    opened by llyx97 2
  • What about the text-to-video generation performance if additional VQVAE is applied on text tokens?

    Hi @SongweiGe, while reading the code for encoding the text, I found that an identity layer is used to encode the text tokens, which have already been encoded by a tokenizer. I'm wondering whether performance would improve if an additional VQVAE were applied to the encoded tokens. Have you done any experiments of this sort?

    opened by Crane-YU 1
  • The Training of Interpolation Transformer

    Dear author:

    In the training of the interpolation Transformer, given that the latent space is 5 * 16 * 16, I found that the first 16 * 16 and the last 16 * 16 tokens participate in gradient propagation. But during inference of the interpolation Transformer, the first and last 16 * 16 tokens are given. So, in my opinion, the first 16 * 16 and the last 16 * 16 tokens should not take part in gradient back-propagation during training. Please correct me if I'm wrong.

    Kang

    opened by kangzhao2 1
  • Environment Setting

    In my setup, using conda to install pytorch and pip to install pytorch-lightning caused version-mismatch problems, and the code failed to run.

    Here is my setting

      conda create -n tats python=3.8
      conda activate tats
      pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
      pip install pytorch-lightning
      pip install einops ftfy h5py imageio imageio-ffmpeg regex scikit-video tqdm av
    

    but I still get some warnings, as follows:

    IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (259, 259) to (272, 272) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).
    [swscaler @ 0x6c83700] Warning: data is not aligned! This can lead to a speed loss
    

    and

    ~/miniconda3/envs/tats/lib/python3.8/site-packages/torchvision/io/video.py:162: UserWarning: The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.
      warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
    

    Does this affect the final result? Looking forward to your suggestions and help.

    opened by ludanruan 1