# Fast Sampling of Diffusion Models with Exponential Integrator

A clean implementation of DEIS and iPNDM

# Update

• BREAKING CHANGE: the v1.0 API changes significantly as we add the ρRK-DEIS and ρAB-DEIS algorithms and more choices for time scheduling. If you are only interested in tAB-DEIS / iPNDM or the previous codebase, check v0.1

# Usage

```shell
# install the JAX backend (CPU build shown; PyTorch users install torch and use th_deis instead)
pip install "jax[cpu]"
```

## If diffusion models are trained with continuous time

```python
import jax.numpy as jnp
import jax_deis as deis

def eps_fn(x_t, scalar_t):
    # broadcast the scalar time to a vector of times, one per batch element
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)

# pytorch
# import torch as th
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     return eps_model(x_t, vec_t)
```

```python
# mappings between t and alpha in the VPSDE;
# we provide the popular linear and cosine mappings
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)

vpsde = deis.VPSDE(
    t2alpha_fn,
    alpha2t_fn,
    sampling_eps,  # sampling end time t_0
    sampling_T,    # sampling starting time t_T
)
```
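For intuition, the linear t ↔ alpha mapping can be sketched in plain NumPy. This is a hedged sketch, not the library's internal implementation; it assumes `alpha_t = exp(-∫ beta(s) ds)` with a linear `beta(s)`, and reuses the `beta_0`/`beta_1` values from the call above:

```python
import numpy as np

def t2alpha_fn(t, beta_0=0.01, beta_1=20.0):
    # alpha_t = exp(-int_0^t beta(s) ds) with linear beta(s) = beta_0 + s * (beta_1 - beta_0)
    return np.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)

def alpha2t_fn(alpha, beta_0=0.01, beta_1=20.0):
    # invert the quadratic 0.5 * (beta_1 - beta_0) * t^2 + beta_0 * t + log(alpha) = 0,
    # taking the non-negative root (log(alpha) <= 0, so the discriminant is >= beta_0^2)
    log_alpha = np.log(alpha)
    disc = beta_0 ** 2 - 2.0 * (beta_1 - beta_0) * log_alpha
    return (-beta_0 + np.sqrt(disc)) / (beta_1 - beta_0)
```

The two functions are exact inverses of each other, which is what the sampler relies on when converting between time and noise-level parameterizations.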

```python
sampler_fn = deis.get_sampler(
    # args for the diffusion model
    vpsde,
    eps_fn,
    # args for timestep scheduling
    ts_phase="t",       # supports "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # DEIS choice
    method="t_ab",      # sampling algorithm: supports "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3,         # must be > 0; used by "rho_ab" and "t_ab", ignored by other algorithms
    rk_method="3kutta"  # used by "rho_rk", ignored by other algorithms
)

sample = sampler_fn(noise)
```
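The `ts_phase` / `ts_order` arguments control where the `num_step + 1` timestamps are placed between the start and end times. As one plausible reading (a sketch only; the library's internals may differ), the `"t"` phase with `ts_order=2.0` spaces steps uniformly in $t^{1/\text{ts\_order}}$, which clusters evaluations near the small-noise end:

```python
import numpy as np

def power_ts(t_start, t_end, num_step, ts_order=2.0):
    # uniform grid in t**(1/ts_order), mapped back by raising to ts_order;
    # ts_order=1.0 recovers a schedule uniform in t
    u = np.linspace(t_start ** (1.0 / ts_order), t_end ** (1.0 / ts_order), num_step + 1)
    return u ** ts_order

# hypothetical values matching the get_sampler call above
ts = power_ts(t_start=1.0, t_end=1e-3, num_step=10, ts_order=2.0)
```

A higher `ts_order` pushes more of the step budget toward small `t`, where the ODE trajectory typically changes fastest.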

## If diffusion models are trained with discrete time

```python
#! by default the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0,
#! totally len(discrete_alpha) steps if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
```
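Here `discrete_alpha` is the cumulative product $\bar{\alpha}_i$ from the trained model's beta schedule. For example, with the common DDPM linear schedule (the specific values below are the usual DDPM defaults, not mandated by this library):

```python
import numpy as np

beta = np.linspace(1e-4, 0.02, 1000)       # per-step noise levels beta_i
discrete_alpha = np.cumprod(1.0 - beta)    # \bar{alpha}_i = prod_{j<=i} (1 - beta_j)
# vpsde = deis.DiscreteVPSDE(discrete_alpha)  # then construct the discrete SDE
```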

# A short derivation for DEIS

Exponential integrators in diffusion models

The key insight of the exponential integrator is to take advantage of all the mathematical structure present in the ODE, with the goal of making the discretization error as small as possible.

The mathematical structure in diffusion models includes the semilinear structure of the ODE and the analytic formulas for the drift and diffusion coefficients.

Below we present a short derivation of the exponential integrator applied to diffusion models.

## Forward SDE

$$dx = F_tx dt + G_td\mathbf{w}$$

## Backward ODE

$$dx = F_tx dt + 0.5 G_tG_t^T L_t^{-T} \epsilon(x, t) dt$$

where $L_t L_t^{T} = \Sigma_t$ and $\Sigma_t$ is the covariance of $p_{0t}(x_t | x_0)$.
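As a concrete instance, for the VPSDE one has $F_t = -\frac{1}{2}\beta_t$, $G_t = \sqrt{\beta_t}$, and $\Sigma_t = (1-\alpha_t) I$, hence $L_t = \sqrt{1-\alpha_t}\, I$, and the backward ODE becomes

$$dx = -\frac{1}{2}\beta_t x\, dt + \frac{\beta_t}{2\sqrt{1-\alpha_t}}\, \epsilon(x, t)\, dt$$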

## Exponential Integrator

We can get rid of the semilinear structure with the exponential integrator by introducing a new variable $y$

$$y_t = \Psi(t) x_t, \quad \Psi(t) = \exp\left(-\int_0^{t} F_\tau \, d\tau\right)$$

and the ODE simplifies to

$$\dot{y}_t = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x(y), t)$$

where $x(y)$ maps $y_t$ to $x_t$ .
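To see why integrating the linear part exactly matters, consider a toy scalar ODE $\dot{x} = -a x + c$ with a stiff linear part (this example is only an illustration, unrelated to any diffusion model). The exponential (exact-linear) update stays accurate at step sizes where explicit Euler blows up:

```python
import math

def solve(a=50.0, c=1.0, x0=0.0, h=0.05, n=40):
    """Integrate x' = -a*x + c with explicit Euler and with exponential Euler."""
    x_eu, x_exp = x0, x0
    for _ in range(n):
        # explicit Euler: unstable once h * a > 2 (here h * a = 2.5)
        x_eu = x_eu + h * (-a * x_eu + c)
        # exponential Euler: linear part handled exactly, only the
        # nonhomogeneous term is frozen over the step (exact here since c is constant)
        x_exp = math.exp(-a * h) * x_exp + (1.0 - math.exp(-a * h)) / a * c
    # closed-form solution for comparison
    x_true = (x0 - c / a) * math.exp(-a * h * n) + c / a
    return x_eu, x_exp, x_true
```

In diffusion sampling the role of $-a x$ is played by $F_t x$, which is known analytically, so the exponential integrator spends its entire error budget on the neural-network term.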

## Time scaling

We can go one step further when $F_t, G_t$ are scalars by rescaling time

$$\dot{v}_\rho = \epsilon(x(v), t(\rho))$$

where $y_t = v_\rho$ and $d\rho = 0.5\, \Psi(t) G_t G_t^T L_t^{-T} dt$. Here $x(v)$ maps $v_\rho$ to $x_t$, and $t(\rho)$ maps $\rho$ to $t$.

## High order solver

By absorbing all of this mathematical structure, we arrive at the following ODE

$$\dot{v}_\rho = \epsilon(x(v), t(\rho))$$

Since the RHS is a neural network, we cannot simplify the ODE further without knowledge of the black-box function. We can, however, apply well-established ODE solvers, such as multistep and Runge-Kutta methods.
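In the multistep (Adams-Bashforth) case, the solver fits a polynomial to past $\epsilon$ evaluations and integrates it exactly over the step, which reduces to weighting those evaluations by integrals of Lagrange basis polynomials. A hedged sketch of those weights on a possibly non-uniform grid (`ab_coefs` is an illustrative helper, not this library's API):

```python
import numpy as np

def ab_coefs(ts_prev, t_cur, t_next):
    """Weights c_j such that int_{t_cur}^{t_next} eps(t) dt
    ~= sum_j c_j * eps(ts_prev[j]), where eps is replaced by its
    Lagrange interpolant through the previous timestamps ts_prev."""
    ts_prev = np.asarray(ts_prev, dtype=float)
    coefs = []
    for j, tj in enumerate(ts_prev):
        others = np.delete(ts_prev, j)
        # Lagrange basis l_j(t) = prod_k (t - t_k) / (t_j - t_k), as polynomial coefficients
        poly = np.poly(others) / np.prod(tj - others)
        # integrate the basis polynomial exactly over [t_cur, t_next]
        antideriv = np.polyint(poly)
        coefs.append(np.polyval(antideriv, t_next) - np.polyval(antideriv, t_cur))
    return np.array(coefs)
```

On a uniform grid this recovers the classical Adams-Bashforth coefficients (e.g. 3/2 and -1/2 for the 2-step scheme); on the non-uniform grids used for sampling, the same construction yields step-dependent weights.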

# Reference

```bibtex
@article{zhang2022fast,
  title={Fast Sampling of Diffusion Models with Exponential Integrator},
  author={Zhang, Qinsheng and Chen, Yongxin},
  journal={arXiv preprint arXiv:2204.13902},
  year={2022}
}
```
