[sample gifs: machine imagined fireworks; these fireworks do not exist]

Video Diffusion - Pytorch (wip)

Text to video, it is happening! Official Project Page

Implementation of Video Diffusion Models, Jonathan Ho's new paper extending DDPMs to video generation, in Pytorch. It uses a special space-time factored U-net, extending generation from 2d images to 3d videos.

Install

$ pip install video-diffusion-pytorch

Usage

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(1, 3, 5, 32, 32) # video (batch, channels, frames, height, width)
loss = diffusion(videos)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(batch_size = 4)
sampled_videos.shape # (4, 3, 5, 32, 32)
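
To eyeball samples, you can write the tensor out as a gif. A minimal sketch using Pillow (video_tensor_to_gif here is my own helper, not necessarily this repo's API; pixel values are assumed to lie in [0, 1]):

import torch
from PIL import Image

def video_tensor_to_gif(video, path, duration = 250):
    # video: (channels, frames, height, width), values assumed in [0, 1]
    frames = video.clamp(0., 1.).permute(1, 2, 3, 0)  # -> (frames, height, width, channels)
    images = [Image.fromarray((frame.numpy() * 255).astype('uint8')) for frame in frames]
    images[0].save(path, save_all = True, append_images = images[1:], duration = duration, loop = 0)

video_tensor_to_gif(sampled_videos[0].cpu(), 'sample.gif')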

For conditioning on text, the authors derived text embeddings by first passing the tokenized text through BERT-large. Then you just have to train it like so:

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    cond_dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(2, 3, 5, 32, 32) # video (batch, channels, frames, height, width)
text = torch.randn(2, 64)             # toy text embeddings; must match cond_dim above (BERT-large actually outputs dimension 1024)

loss = diffusion(videos, cond = text)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(cond = text)
sampled_videos.shape # (2, 3, 5, 32, 32)
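
This example assumes you compute the BERT embeddings yourself. A hedged sketch using the HuggingFace transformers library (an external dependency, and embed_text is my own helper, not part of this repo; since BERT-large's hidden size is 1024, cond_dim would be set to 1024 rather than the toy 64 above):

import torch
from transformers import BertTokenizer, BertModel  # external dependency, not installed by this repo

tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
bert = BertModel.from_pretrained('bert-large-cased').eval()

@torch.no_grad()
def embed_text(texts):
    tokens = tokenizer(texts, return_tensors = 'pt', padding = True, truncation = True)
    hidden = bert(**tokens).last_hidden_state  # (batch, seq_len, 1024)
    return hidden[:, 0]                        # [CLS] token embedding, (batch, 1024)

text = embed_text(['fireworks with blue and green sparkles'])  # (1, 1024), pass as cond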

You can also pass in the video descriptions directly as strings, if you plan on using BERT-base for text conditioning:

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    use_bert_text_cond = True,  # this must be set to True to auto-use the bert model dimensions
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,    # height and width of frames
    num_frames = 5,     # number of video frames
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(3, 3, 5, 32, 32) # video (batch, channels, frames, height, width)

text = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

loss = diffusion(videos, cond = text)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(cond = text, cond_scale = 2)
sampled_videos.shape # (3, 3, 5, 32, 32)
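
The cond_scale argument controls classifier-free guidance at sampling time: the model is evaluated both with and without the text condition, and the two outputs are blended. The standard combination rule, as a sketch (the exact in-repo implementation may differ):

def classifier_free_guidance(cond_out, uncond_out, cond_scale):
    # cond_scale = 1 recovers the purely conditional output;
    # cond_scale > 1 pushes samples further toward the text condition
    return uncond_out + (cond_out - uncond_out) * cond_scale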

Training

This repository also contains a handy Trainer class for training on a folder of gifs. Each gif must match the specified image_size and num_frames.

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer

model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    num_frames = 10,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
).cuda()

trainer = Trainer(
    diffusion,
    './data',                         # this folder path needs to contain all your training data, as .gif files, of correct image size and number of frames
    train_batch_size = 32,
    train_lr = 2e-5,
    save_and_sample_every = 1000,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # turn on mixed precision
)

trainer.train()

Sample videos (as gif files) will be saved to ./results periodically, as will the diffusion model parameters.
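
If your source gifs do not yet match image_size and num_frames, a hedged preprocessing sketch with Pillow (preprocess_gif is a hypothetical helper, not part of this repo, and './raw_gifs' is a placeholder path):

from pathlib import Path
from PIL import Image, ImageSequence

def preprocess_gif(src, dst, image_size = 64, num_frames = 10):
    # resize every frame, then cycle or truncate to the expected frame count
    gif = Image.open(src)
    frames = [frame.convert('RGB').resize((image_size, image_size)) for frame in ImageSequence.Iterator(gif)]
    frames = (frames * num_frames)[:num_frames]
    frames[0].save(dst, save_all = True, append_images = frames[1:], duration = gif.info.get('duration', 100), loop = 0)

for path in Path('./raw_gifs').glob('*.gif'):
    preprocess_gif(path, Path('./data') / path.name)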

Co-training Images and Video

One of the claims in the paper is that by doing factored space-time attention, one can force the network to attend only to the present frame, allowing images and videos to be trained on jointly, which leads to better results.

It was not clear how they achieved this, but I ventured a guess.

To arrest attention to the present moment, simply pass focus_on_the_present = True to the diffusion forward method:

loss = diffusion(videos, cond = text, focus_on_the_present = True)
loss.backward()
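
My reading of what this should do: the temporal attention is masked down to the identity, so each frame attends only to itself and the network behaves like an image model over the frames. An illustrative sketch of such a mask (not the repo's actual internals):

import torch

def present_only_mask(num_frames):
    # boolean temporal attention mask: each frame may attend only to itself
    return torch.eye(num_frames, dtype = torch.bool)

mask = present_only_mask(5)  # (5, 5) identity, to be broadcast over batch and heads in the attention layer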

If you have a better idea how this is done, just open a github issue.

Todo

  • wire up text conditioning, use classifier free guidance
  • relative positional encodings in attention (space and time) - use T5 relative positional bias instead of what they used
  • add a forward keyword argument that arrests attention across time (as reported / claimed in the paper, this type of image + video simultaneous training improves results)
  • consider doing a 3d version of CLIP, so one can eventually apply the lessons of DALL-E2 to video https://github.com/lucidrains/dalle2-video
  • offer way for Trainer to curtail or pad frames, if gif is too long
  • find a good torchvideo-like library (torchvideo seems immature) for training on fireworks
  • project text into 4-8 tokens, and use them as memory key / values to condition both time and space in attention blocks
  • prepare a jax version for large scale TPU training
  • have Trainer take care of conditional video synthesis, with text offered as corresponding {video_filename}.txt within the same folder
  • see if ffcv or squirrel-core is a good fit

Citations

@misc{ho2022video,
  title   = {Video Diffusion Models}, 
  author  = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
  year    = {2022},
  eprint  = {2204.03458},
  archivePrefix = {arXiv},
  primaryClass = {cs.CV}
}
Comments
  • attn mask should be `~mask`

    hi Phil, might be a small typo? https://github.com/lucidrains/video-diffusion-pytorch/blob/aded599bedb62f7d6c595b3486b1be4a0053106c/video_diffusion_pytorch/video_diffusion_pytorch.py#L321

    opened by CiaoHe 2
  • Elucidated version

    Is it possible to get an elucidated video diffusion model, or should I turn to your imagen repo? (I see you have accomplished elucidated text-to-video there.)

    opened by martinriven 1
  • Unconditioned UNet

    When creating the Unet3D without the cond_dim argument, an error is thrown.

    Error:

      File "/Users/shehan360/opt/anaconda3/envs/video-gpt/lib/python3.7/site-packages/video_diffusion_pytorch/video_diffusion_pytorch.py", line 320, in __init__
        cond_dim = time_dim + int(cond_dim)
    TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
    

    To reproduce -

    model = Unet3D(
        dim = 64,
        dim_mults = (1, 2, 4, 8)
    )
    
    opened by shehan360 1
  • the commit: one more residual

    After updating past this commit, the color of the sampled videos fades (not sure why); I am using the UCF101 dataset, unconditional training with a 10k-step warmup.

    opened by martinriven 28
  • Noisy output & "text_use_bert_cls" error

    The "name text_use_bert_cls is not defined" error occurs when trying to use explicit text strings, as in the 3rd example; the variable is not bound to the class in the "p_losses" function. After fixing that, the generated output samples are random noise. I ran inference at 1K and 50K steps respectively. Can you please advise if I am missing any step?

    Attaching the output generated.

    opened by GoutamKelam 30
  • Reason for combining rotary and relative positional embedding?

    Hi,

    Awesome work, first of all. Is there a reason why you combine both rotary and relative positional embeddings in your Attention class? I would assume one of the two is enough to incorporate the frame positions into the attention model?

    opened by oxjohanndiep 1
  • Duplicate dividing in relative positional encoding

    Hey @lucidrains, thanks for keeping these models implemented. In line 88 https://github.com/lucidrains/video-diffusion-pytorch/blob/f55f1b0824b1be7d2bb555ed7a5d612eff8ad5d0/video_diffusion_pytorch/video_diffusion_pytorch.py#L84-L88 you set max_exact to half of num_buckets, whose value was already halved in line 84.

    I think the halving is duplicated, and the line should be changed to:

     max_exact = num_buckets
    
    opened by SongweiGe 1
  • Gradient method for conditional sampling

    Thanks for your effort on this implementation. I did not find any code using the autograd package to compute the gradient shown in Eq. (6) of the paper. Have you implemented this technique, or can you tell me which line corresponds to it? I am interested in its effectiveness. Thanks.

    opened by LuckyDC 2