DEIS: Fast Sampling of Diffusion Models with Exponential Integrator


Qinsheng Zhang, Yongxin Chen

A clean implementation of DEIS and iPNDM


Update

  • BREAKING CHANGE: the v1.0 API changes significantly as we add the ρRK-DEIS and ρAB-DEIS algorithms and more choices for time scheduling. If you are only interested in tAB-DEIS / iPNDM or the previous codebase, check v0.1

Usage

# PyTorch users: install a CPU build of JAX as well
pip install "jax[cpu]"

If diffusion models are trained with continuous time

import jax.numpy as jnp

import jax_deis as deis

def eps_fn(x_t, scalar_t):
    # broadcast the scalar time into a batch of timestamps for the network
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)

# pytorch
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     with th.no_grad():
#         return eps_model(x_t, vec_t)

# mappings between t and alpha in the VPSDE;
# popular linear and cosine mappings are provided
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)
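# note (our illustration, assuming the standard continuous-time VPSDE):
# the linear schedule gives alpha(t) = exp(-0.5 * (beta_1 - beta_0) * t**2 - beta_0 * t)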

vpsde = deis.VPSDE(
    t2alpha_fn, 
    alpha2t_fn,
    sampling_eps, # sampling end time t_0
    sampling_T # sampling starting time t_T
)

sampler_fn = deis.get_sampler(
    # args for the diffusion model
    vpsde,
    eps_fn,
    # args for timestep scheduling
    ts_phase="t", # supports "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # choice of DEIS algorithm
    method="t_ab", # deis sampling algorithms: supports "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3, # must be greater than 0; used by the "rho_ab" and "t_ab" algorithms, ignored by the others
    rk_method="3kutta" # used by the "rho_rk" algorithm, ignored by the others
)

sample = sampler_fn(noise)
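
For completeness, a minimal sketch of producing the initial noise (the shape and PRNG seed are illustrative; under the VPSDE the reverse-time sampler starts from approximately standard Gaussian noise at $t_T$):

import jax

key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, (4, 32, 32, 3)) # hypothetical image batch
sample = sampler_fn(noise)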

If diffusion models are trained with discrete time

#! by default the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0,
#! len(discrete_alpha) steps in total if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
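
For example, a minimal sketch of building the discrete SDE (the DDPM-style linear beta schedule below is an assumption for illustration; discrete_alpha is the cumulative product usually written $\bar{\alpha}$):

import numpy as np
import jax_deis as deis

# DDPM-style linear beta schedule with 1000 training steps
betas = np.linspace(1e-4, 0.02, 1000)
# discrete_alpha[i] = prod_{j <= i} (1 - beta_j), i.e. alpha_bar
discrete_alpha = np.cumprod(1.0 - betas)

vpsde = deis.DiscreteVPSDE(discrete_alpha)
# the resulting vpsde can be passed to deis.get_sampler exactly as above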

A short derivation for DEIS

Exponential integrator in diffusion model

The key idea of the exponential integrator is to take advantage of all the mathematical structure present in the ODE, with the goal of making the discretization error as small as possible.

In diffusion models, this structure consists of the semilinear form of the ODE and the analytic formulas for the drift and diffusion coefficients.

Below we present a short derivation of the exponential integrator applied to diffusion models.

Forward SDE

$$ dx = F_tx dt + G_td\mathbf{w} $$

Backward ODE

$$ dx = F_tx dt + 0.5 G_tG_t^T L_t^{-T} \epsilon(x, t) dt $$

where $L_t L_t^{T} = \Sigma_t$ and $\Sigma_t$ is the covariance of $p_{0t}(x_t | x_0)$.
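
As a concrete illustration (our example, assuming the standard continuous-time VPSDE convention with $\alpha_t$ the t-to-alpha mapping from the usage section), taking $F_t = -\frac{1}{2}\beta_t$, $G_t = \sqrt{\beta_t}$, and $\Sigma_t = (1 - \alpha_t) I$ recovers the familiar probability-flow ODE

$$ dx = -\tfrac{1}{2}\beta_t x\, dt + \tfrac{1}{2}\beta_t \frac{\epsilon(x, t)}{\sqrt{1-\alpha_t}}\, dt $$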

Exponential Integrator

We can remove the semilinear structure with the exponential integrator by introducing a new variable $y$

$$ y_t = \Psi(t) x_t, \quad \Psi(t) = \exp\left(-\int_0^{t} F_\tau\, d\tau\right) $$

The ODE then simplifies to

$$ \dot{y}_t = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x(y), t) $$

where $x(y)$ maps $y_t$ back to $x_t$.
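
Continuing the VPSDE illustration, $F_t = -\frac{1}{2}\beta_t$ gives

$$ \Psi(t) = \exp\left(\tfrac{1}{2}\int_0^{t} \beta_\tau\, d\tau\right) = \frac{1}{\sqrt{\alpha_t}}, \qquad y_t = \frac{x_t}{\sqrt{\alpha_t}}, $$

i.e. $y$ is the state rescaled by the inverse of the mean decay, which removes the linear drift from the dynamics.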

Time scaling

We can go one step further when $F_t, G_t$ are scalars by rescaling time

$$ \dot{v}_\rho = \epsilon(x(v), t(\rho)) $$

where $y_t = v_\rho$ and $d\rho = 0.5\, \Psi(t) G_t G_t^T L_t^{-T} dt$. Here $x(v)$ maps $v_\rho$ to $x_t$, and $t(\rho)$ maps $\rho$ back to $t$.
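
In the VPSDE illustration, this time change evaluates to

$$ d\rho = \frac{\beta_t}{2\sqrt{\alpha_t (1 - \alpha_t)}}\, dt, $$

so $\rho$ stretches time precisely where the right-hand side would otherwise vary rapidly.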

High order solver

By absorbing all of this structure, we arrive at the following ODE

$$ \dot{v}_\rho = \epsilon(x(v), t(\rho)) $$

Since the RHS is a neural network, we cannot simplify the ODE further without knowledge of this black-box function. At this point we can apply well-established ODE solvers, such as multistep and Runge–Kutta methods.
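
To make the multistep option concrete, below is a minimal sketch (ours, not the repository's implementation) of a classical Adams–Bashforth solver for $\dot{v}_\rho = \epsilon(x(v), t(\rho))$ on a uniform $\rho$ grid; DEIS instead fits polynomial-extrapolation coefficients that stay valid on non-uniform grids. `eps_fn`, `v_init`, and `rho_grid` are hypothetical stand-ins for the network in the rescaled variables, the initial noise, and the $\rho$ schedule.

# classical fixed-step Adams-Bashforth coefficients for orders 1-3
AB_COEF = {1: (1.0,), 2: (1.5, -0.5), 3: (23 / 12, -16 / 12, 5 / 12)}

def ab_solve(eps_fn, v_init, rho_grid, order=3):
    """Integrate dv/drho = eps_fn(v, rho) with Adams-Bashforth.

    Assumes rho_grid is uniformly spaced; the classical AB
    coefficients are only valid for a constant step size.
    """
    v, history = v_init, []
    for i in range(len(rho_grid) - 1):
        h = rho_grid[i + 1] - rho_grid[i]
        history.insert(0, eps_fn(v, rho_grid[i]))  # newest evaluation first
        k = min(order, len(history))               # warm up with lower orders
        v = v + h * sum(c * e for c, e in zip(AB_COEF[k], history[:k]))
        history = history[:order]                  # keep only reusable history
    return v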


Reference

@article{zhang2022fast,
  title={Fast Sampling of Diffusion Models with Exponential Integrator},
  author={Zhang, Qinsheng and Chen, Yongxin},
  journal={arXiv preprint arXiv:2204.13902},
  year={2022}
}