A JAX implementation of stochastic addition.

Overview

Jochastic: stochastically rounded operations between JAX tensors.

This repository contains a JAX software-based implementation of some stochastically rounded operations.

When encoding the weights of a neural network in low precision (such as bfloat16), one runs into stagnation problems: updates end up being too small relative to the precision of the encoding. This leads to weights becoming stuck and the model's accuracy being significantly reduced.

Stochastic arithmetic lets you perform the operations in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem (see figure 4 of "Revisiting BFloat16 Training") without increasing the memory usage (as might happen if one were using a compensated summation to solve the problem).

The downside is that software-based stochastic arithmetic is significantly slower than normal floating-point arithmetic. It is thus viable for things like the weight update (when using the output of an Optax optimizer for example) but would not be appropriate in a hot loop.

Do not hesitate to submit an issue or a pull request if you need additional functionality!

Usage

This repository introduces the add and tree_add operations. They take a PRNGKey and two tensors (or pytrees, respectively) to be added, and round the result up or down randomly:

import jax
import jax.numpy as jnp
import jochastic

# problem definition
size = 10
dtype = jnp.bfloat16
key = jax.random.PRNGKey(1993)

# deterministic addition
key, keyx, keyy = jax.random.split(key, num=3)
x = jax.random.normal(keyx, shape=(size,), dtype=dtype)
y = jax.random.normal(keyy, shape=(size,), dtype=dtype)
result = x + y
print(f"deterministic addition: {result}")

# stochastic addition
result_sto = jochastic.add(key, x, y)
print(f"stochastic addition: {result_sto} ({result_sto.dtype})")
difference = result - result_sto
print(f"difference: {difference}")

Both functions take an optional is_biased boolean parameter. If is_biased is True (the default), the random number generator is biased according to the relative error of the operation; otherwise, it will round up half of the time on average.

Jitting the functions is left to the user's discretion (you will need to indicate that is_biased is static).
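
For instance, a minimal sketch of jitting add (assuming is_biased can be passed by keyword) might look like this:

# mark is_biased as a static (compile-time) argument
jitted_add = jax.jit(jochastic.add, static_argnames='is_biased')
result_jit = jitted_add(key, x, y, is_biased=False)
print(f"jitted, unbiased addition: {result_jit}")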

NOTE: Very low precision (16-bit floating-point arithmetic or less) is extremely brittle. We recommend using higher precision locally (such as 32-bit floating-point arithmetic to compute the optimizer's update), then casting down to 16 bits at summing / storage time (something that Pytorch does transparently when using their addcdiv in low precision). Both functions accept mixed-precision inputs (adding a high-precision number to a low-precision one), use that information for the rounding, then return an output in the lowest precision of their inputs (contrary to most casting conventions).
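
As an illustration (hypothetical variable names, reusing the key and size defined above), one might keep the optimizer's update in float32 and only cast down to bfloat16 at summing time:

# bfloat16 weights, float32 update computed by the optimizer
key, subkey = jax.random.split(key)
weights = jnp.ones(size, dtype=jnp.bfloat16)
update = jnp.full(size, 1e-3, dtype=jnp.float32)

# the addition uses the high-precision information for the rounding
# but returns a result in the lowest input precision (bfloat16 here)
new_weights = jochastic.add(subkey, weights, update)
print(f"new weights dtype: {new_weights.dtype}")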

Implementation details

We use TwoSum to measure the numerical error introduced by the addition; our tests show that it behaves as needed on bfloat16 (some edge cases might be invalid, leading to an inexact computation of the numerical error, but it is reliable enough for our purpose).
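
For reference, here is a minimal sketch of the standard TwoSum error-free transformation (not necessarily the exact code used in this repository):

def two_sum(a, b):
    # returns the rounded sum and an estimate of the rounding error
    result = a + b
    a2 = result - b
    b2 = result - a2
    error_a = a - a2
    error_b = b - b2
    return result, error_a + error_b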

This and the nextafter function let us emulate various rounding modes in software (this is inspired by Verrou's backend).
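
The idea can be sketched as follows (hypothetical helper, assuming the error term comes from the TwoSum sketch above):

def round_away(result, error):
    # next representable float in the direction pointed to by the error term
    direction = jnp.where(error > 0, jnp.inf, -jnp.inf).astype(result.dtype)
    alternative = jnp.nextafter(result, direction)
    # keep the rounded-to-nearest result when the addition was exact
    return jnp.where(error == 0, result, alternative)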

Crediting this work

You can use this BibTeX reference if you use Jochastic within a published work:

@misc{Jochastic,
  author = {Demeure, Nestor},
  title = {Jochastic: stochastically rounded operations between JAX tensors.},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/nestordemeure/jochastic}}
}

You will find a Pytorch implementation called StochasTorch here.
