# Fast Sampling of Diffusion Models with Exponential Integrator

A clean implementation of DEIS and iPNDM

# Update

**BREAKING CHANGE**: the v1.0 API changes significantly, as we add the `ρRK-DEIS` and `ρAB-DEIS` algorithms and more choices for time scheduling. If you are only interested in `tAB-DEIS`/`iPNDM` or the previous codebase, check v0.1.

# Usage

```
# for PyTorch users
pip install "jax[cpu]"
```

## If diffusion models are trained with continuous time

```
import jax.numpy as jnp
import jax_deis as deis

# eps_model, sampling_eps, sampling_T and noise below are user-provided
def eps_fn(x_t, scalar_t):
    # broadcast the scalar time to one time value per batch element
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)

# pytorch
# import torch as th
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     with th.no_grad():
#         return eps_model(x_t, vec_t)

# mappings between t and alpha in VPSDE
# we provide popular linear and cos mappings
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)

vpsde = deis.VPSDE(
    t2alpha_fn,
    alpha2t_fn,
    sampling_eps, # sampling end time t_0
    sampling_T    # sampling starting time t_T
)

sampler_fn = deis.get_sampler(
    # args for diffusion model
    vpsde,
    eps_fn,
    # args for timestamps scheduling
    ts_phase="t", # support "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # deis choice
    method="t_ab", # deis sampling algorithms: support "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3, # greater than 0, used for "rho_ab", "t_ab" algorithms, other algorithms will ignore the arg
    rk_method="3kutta" # used for "rho_rk" algorithms, other algorithms will ignore the arg
)

# noise: standard Gaussian initial state at sampling_T, shaped like a data batch
sample = sampler_fn(noise)
```
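The `t2alpha_fn`/`alpha2t_fn` pair maps between time $t$ and $\alpha_t$. As a minimal sketch, assuming a linear schedule $\beta(t) = \beta_0 + (\beta_1 - \beta_0)t$ on $t \in [0, 1]$ and $\alpha_t = \exp(-\int_0^t \beta(\tau)\,d\tau)$, both maps have closed forms. The helper `make_linear_alpha_fns` is hypothetical and is not the library's `get_linear_alpha_fns`, whose exact convention may differ.

```python
import math


def make_linear_alpha_fns(beta_0, beta_1):
    """Hypothetical helper: closed-form t <-> alpha maps for a linear beta schedule.

    Assumes beta(t) = beta_0 + (beta_1 - beta_0) * t on t in [0, 1],
    alpha_t = exp(-integral_0^t beta(s) ds), and beta_1 != beta_0.
    """
    a = 0.5 * (beta_1 - beta_0)

    def t2alpha_fn(t):
        # log alpha_t = -(beta_0 * t + a * t^2)
        return math.exp(-(beta_0 * t + a * t ** 2))

    def alpha2t_fn(alpha):
        # solve a * t^2 + beta_0 * t + log(alpha) = 0 and keep the root t >= 0
        disc = beta_0 ** 2 - 4.0 * a * math.log(alpha)
        return (-beta_0 + math.sqrt(disc)) / (2.0 * a)

    return t2alpha_fn, alpha2t_fn


t2alpha_fn, alpha2t_fn = make_linear_alpha_fns(beta_0=0.01, beta_1=20.0)
alpha = t2alpha_fn(0.5)
print(alpha, alpha2t_fn(alpha))  # the round trip recovers t = 0.5 up to float error
```

Solving the quadratic in closed form keeps `alpha2t_fn` exact and cheap, which matters because the sampler converts between the two time parameterizations at every step.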

## If diffusion models are trained with discrete time

```
#! by default the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0,
#! len(discrete_alpha) steps in total if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
```
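`discrete_alpha` is the noise schedule the model was trained with. As a hedged sketch, assuming it holds the cumulative products $\alpha_i = \prod_{j \le i}(1 - \beta_j)$ of a DDPM-style schedule (here the linear schedule from $10^{-4}$ to $0.02$ over 1000 steps used by DDPM), it could be built as below. Check your model's training code for the exact schedule and the array type `DiscreteVPSDE` expects.

```python
import numpy as np

# Assumption: discrete_alpha[i] is the cumulative product of (1 - beta_j) for j <= i,
# i.e. x_i ~ N(sqrt(alpha_i) * x_0, (1 - alpha_i) * I), as in DDPM.
num_train_steps = 1000
betas = np.linspace(1e-4, 0.02, num_train_steps, dtype=np.float64)
discrete_alpha = np.cumprod(1.0 - betas)

# alpha decreases from close to 1 (almost clean data) towards 0 (almost pure noise)
print(discrete_alpha[0], discrete_alpha[-1])
```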

# A short derivation for DEIS

## Exponential integrator in diffusion model

The key idea of the exponential integrator is to exploit all of the mathematical structure present in the ODE, so that the discretization error is as small as possible.

In diffusion models, this structure includes the semilinear form of the ODE and the analytic formulas for the drift and diffusion coefficients.

Below we present a short derivation of the exponential integrator applied to diffusion models.

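## Forward SDE

$$dx = F_t x \, dt + G_t \, dw$$

## Backward ODE

$$\frac{dx}{dt} = F_t x + \frac{1}{2} G_t G_t^T L_t^{-T} \epsilon_\theta(x, t)$$

where $\epsilon_\theta$ is the noise-prediction network and $L_t L_t^T = \Sigma_t$ is the covariance of the forward transition $p_{0t}(x_t \mid x_0)$. This is the probability flow ODE $\frac{dx}{dt} = F_t x - \frac{1}{2} G_t G_t^T \nabla \log p_t(x)$ with the score parameterized as $\nabla \log p_t(x) \approx -L_t^{-T} \epsilon_\theta(x, t)$.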

## Exponential Integrator

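We can get rid of the semilinear structure with the **Exponential Integrator** by introducing a new variable

$$\hat{x}_t = \Psi(s, t)\, x_t$$

and the ODE is simplified into

$$\frac{d\hat{x}_t}{dt} = \frac{1}{2} \Psi(s, t)\, G_t G_t^T L_t^{-T} \epsilon_\theta(x_t, t)$$

where $\Psi(t, s)$ is the transition matrix of the linear drift, satisfying $\frac{\partial}{\partial t} \Psi(t, s) = F_t \Psi(t, s)$ and $\Psi(s, s) = I$, and $x_t = \Psi(t, s)\, \hat{x}_t$. The linear term $F_t x$ is absorbed into $\Psi$ and handled exactly, so only the network term needs to be discretized.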
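For VPSDE ($F_t = -\tfrac{1}{2}\beta_t$, $G_t = \sqrt{\beta_t}$, $x_t = \sqrt{\alpha_t}\,x_0 + \sqrt{1-\alpha_t}\,\epsilon$), the transition matrix is the scalar $\Psi(t, s) = \sqrt{\alpha_t / \alpha_s}$. Holding $\epsilon_\theta$ fixed over one step and integrating the simplified ODE exactly gives the simplest exponential-integrator update, which coincides with the deterministic DDIM step. The sketch below uses a hypothetical helper `exp_integrator_step` and writes the integrated coefficient through $\rho_t = \sqrt{(1-\alpha_t)/\alpha_t}$; it checks that the step is exact when the noise prediction is perfect.

```python
import numpy as np


def exp_integrator_step(x_t, eps, alpha_t, alpha_next):
    """Hypothetical first-order exponential-integrator step for VPSDE.

    The linear part is integrated exactly through Psi(t_next, t) = sqrt(alpha_next / alpha_t);
    eps is held fixed over the step. This is the deterministic DDIM update.
    """
    rho_t = np.sqrt((1.0 - alpha_t) / alpha_t)
    rho_next = np.sqrt((1.0 - alpha_next) / alpha_next)
    psi = np.sqrt(alpha_next / alpha_t)
    # exact integral of the eps coefficient over [t, t_next], times the frozen eps
    return psi * x_t + np.sqrt(alpha_next) * (rho_next - rho_t) * eps


rng = np.random.default_rng(0)
x_0 = rng.standard_normal(4)
noise = rng.standard_normal(4)
alpha_t, alpha_next = 0.3, 0.8  # step from a noisier time to a cleaner one
x_t = np.sqrt(alpha_t) * x_0 + np.sqrt(1.0 - alpha_t) * noise

# with a perfect noise prediction the step lands exactly on the forward marginal sample
x_next = exp_integrator_step(x_t, noise, alpha_t, alpha_next)
expected = np.sqrt(alpha_next) * x_0 + np.sqrt(1.0 - alpha_next) * noise
print(np.max(np.abs(x_next - expected)))
```

Because the linear drift never enters a discretization, the only error left comes from how $\epsilon_\theta$ varies within a step, which is what the higher-order solvers address.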

## Time scaling

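We can take one step further when

$$\frac{1}{2} \Psi(s, t)\, G_t G_t^T L_t^{-T} = \frac{d\rho_t}{dt}\, I$$

for a scalar function $\rho_t$, which lets us use $\rho$ instead of $t$ as the time variable. For VPSDE with $p_{0t}(x_t \mid x_0) = \mathcal{N}(\sqrt{\alpha_t}\, x_0, (1 - \alpha_t) I)$ and $\alpha_0 = 1$, taking $s = 0$, this holds with

$$\rho_t = \sqrt{\frac{1 - \alpha_t}{\alpha_t}}$$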
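A quick numerical check of the time rescaling for VPSDE: with $s = 0$, $\Psi(0, t) = 1/\sqrt{\alpha_t}$, $G_t^2 = \beta_t$ and $L_t = \sqrt{1 - \alpha_t}$, the coefficient $\tfrac{1}{2}\Psi(0, t)\, G_t^2 L_t^{-1} = \frac{\beta_t}{2\sqrt{\alpha_t}\sqrt{1-\alpha_t}}$ should equal $d\rho_t/dt$. The sketch assumes a linear $\beta$ schedule on $t \in [0, 1]$ with illustrative constants and compares both sides using a central finite difference.

```python
import numpy as np

# Illustrative (assumed) linear beta schedule on t in [0, 1]
beta_0, beta_1 = 0.01, 20.0


def beta(t):
    return beta_0 + (beta_1 - beta_0) * t


def alpha(t):
    # alpha_t = exp(-integral_0^t beta(s) ds), so alpha_0 = 1
    return np.exp(-(beta_0 * t + 0.5 * (beta_1 - beta_0) * t ** 2))


def rho(t):
    a = alpha(t)
    return np.sqrt((1.0 - a) / a)


def ode_coefficient(t):
    # 0.5 * Psi(0, t) * G_t^2 * L_t^{-1} with Psi(0, t) = 1 / sqrt(alpha_t),
    # G_t^2 = beta_t and L_t = sqrt(1 - alpha_t)
    a = alpha(t)
    return 0.5 * beta(t) / (np.sqrt(a) * np.sqrt(1.0 - a))


ts = np.linspace(0.05, 0.95, 10)
h = 1e-6
drho_dt = (rho(ts + h) - rho(ts - h)) / (2.0 * h)
rel_err = np.max(np.abs(drho_dt / ode_coefficient(ts) - 1.0))
print(rel_err)  # small: only finite-difference error remains
```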

## High order solver

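By absorbing all of the mathematical structure, we reach the following ODE

$$\frac{d\hat{x}_\rho}{d\rho} = \epsilon_\theta\big(\Psi(t_\rho, s)\, \hat{x}_\rho,\ t_\rho\big)$$

where $t_\rho$ is the time corresponding to $\rho$. Since the right-hand side is a neural network, we cannot simplify the ODE further without knowledge of this black-box function. We therefore use well-established ODE solvers, such as multistep (Adams-Bashforth) and Runge-Kutta methods.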
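As a minimal sketch of a multistep solver on $d\hat{x}/d\rho = \epsilon_\theta$ (in the spirit of the `rho_ab` method, not the library's implementation), the network outputs at the most recent $\rho$ values are extrapolated with a Lagrange polynomial, and that polynomial is integrated exactly over the next interval. For VPSDE this assumes $\hat{x} = x/\sqrt{\alpha_t}$ and $\rho_t = \sqrt{(1-\alpha_t)/\alpha_t}$. The helpers `ab_coefficients` and `rho_ab_step` are hypothetical, and a toy linear function stands in for the network.

```python
import numpy as np


def ab_coefficients(rho_hist, rho_next):
    """Integrals over [rho_hist[-1], rho_next] of the Lagrange basis polynomials
    through the points rho_hist (most recent last)."""
    coeffs = []
    for j, rho_j in enumerate(rho_hist):
        basis = np.poly1d([1.0])
        for k, rho_k in enumerate(rho_hist):
            if k != j:
                basis = basis * np.poly1d([1.0, -rho_k]) * (1.0 / (rho_j - rho_k))
        antideriv = basis.integ()
        coeffs.append(antideriv(rho_next) - antideriv(rho_hist[-1]))
    return np.array(coeffs)


def rho_ab_step(x_hat, rho_hist, eps_hist, rho_next):
    """One multistep update of d x_hat / d rho = eps from the stored eps history."""
    c = ab_coefficients(rho_hist, rho_next)
    return x_hat + sum(c_j * eps_j for c_j, eps_j in zip(c, eps_hist))


def toy_eps(r):
    # stands in for the network output; linear in rho, so a 2-step update is exact
    return 2.0 * r + 1.0


rho_hist = [3.0, 2.5]  # sampling runs towards smaller rho (less noise)
eps_hist = [toy_eps(r) for r in rho_hist]
x_next = rho_ab_step(0.0, rho_hist, eps_hist, rho_next=2.0)
# exact value: integral of (2 r + 1) from 2.5 to 2.0 = (4 + 2) - (6.25 + 2.5) = -2.75
print(x_next)
```

Integrating the extrapolating polynomial exactly, instead of using fixed Adams-Bashforth weights, keeps the update correct for non-uniform $\rho$ grids. With fewer stored points than the desired order (for example at the first steps), the same function simply falls back to a lower-order polynomial.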

# Demo

- Continuous VPSDE: based on the score_sde codebase; CIFAR10 images in 7 steps
- Discrete VPSDE: based on the PNDM codebase

# Reference

```
@article{zhang2022fast,
title={Fast Sampling of Diffusion Models with Exponential Integrator},
author={Zhang, Qinsheng and Chen, Yongxin},
journal={arXiv preprint arXiv:2204.13902},
year={2022}
}
```