# Jochastic: stochastically rounded operations between JAX tensors.

This repository contains a JAX software-based implementation of some stochastically rounded operations.

When the weights of a neural network are encoded in low precision (such as `bfloat16`), one runs into stagnation problems: updates end up being too small, relative to the weights, to be representable at the precision of the encoding. This leads to weights becoming stuck and the model's accuracy being significantly reduced.
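As a small illustration of stagnation, the sketch below repeatedly adds a small update to a `bfloat16` weight using ordinary (round-to-nearest) addition. The values `1.0` and `1e-3` are arbitrary choices made for the example.

```python
import jax.numpy as jnp

weight = jnp.asarray(1.0, dtype=jnp.bfloat16)
update = jnp.asarray(1e-3, dtype=jnp.bfloat16)

# Near 1.0, consecutive bfloat16 values are 2**-7 = 0.0078125 apart,
# so 1.0 + 0.001 rounds back to 1.0 under round-to-nearest.
for _ in range(100):
    weight = weight + update

print(weight)  # still 1, although the updates sum to about 0.1

# The same accumulation in float32 does move the weight.
reference = jnp.asarray(1.0, dtype=jnp.float32)
for _ in range(100):
    reference = reference + update.astype(jnp.float32)
print(reference)
```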

Stochastic arithmetic lets you perform the operations in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem (see figure 4 of "Revisiting BFloat16 Training") without increasing the memory usage (as might happen if one were using a compensated summation to solve the problem).

The downside is that software-based stochastic arithmetic is significantly slower than normal floating-point arithmetic. It is thus viable for things like the weight update (when using the output of an Optax optimizer for example) but would not be appropriate in a hot loop.

Do not hesitate to submit an issue or a pull request if you need additional functionality!

## Usage

This repository introduces the `add` and `tree_add` operations. They take a `PRNGKey` and two tensors (or two pytrees, respectively) to be added, but round the result up or down randomly:

```
import jax
import jax.numpy as jnp
import jochastic
# problem definition
size = 10
dtype = jnp.bfloat16
key = jax.random.PRNGKey(1993)
# deterministic addition
key, keyx, keyy = jax.random.split(key, num=3)
x = jax.random.normal(keyx, shape=(size,), dtype=dtype)
y = jax.random.normal(keyy, shape=(size,), dtype=dtype)
result = x + y
print(f"deterministic addition: {result}")
# stochastic addition
result_sto = jochastic.add(key, x, y)
print(f"stochastic addition: {result_sto} ({result_sto.dtype})")
difference = result - result_sto
print(f"difference: {difference}")
```

Both functions take an optional `is_biased` boolean parameter. If `is_biased` is `True` (the default value), the random number generator is biased according to the relative error of the operation; otherwise, it will round up half of the time on average.

Jitting the functions is left to the user's discretion (you will need to indicate that `is_biased` is static).
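One way to mark a boolean flag as static is the `static_argnames` argument of `jax.jit`. The sketch below uses a hypothetical placeholder, `add_stand_in`, that only mimics the call shape described above (key, two tensors, optional `is_biased`) so that the example runs on its own. It is not the library's implementation, and its rounding probabilities are arbitrary.

```python
import jax
import jax.numpy as jnp

def add_stand_in(key, x, y, is_biased=True):
    """Hypothetical placeholder with the call shape of `jochastic.add`; not its implementation."""
    total = x + y
    one_up = jnp.nextafter(total, jnp.full_like(total, jnp.inf))
    # Python-level branch on the flag: under jit this needs a concrete bool,
    # which is what marking the argument static provides.
    # The probabilities are arbitrary and only serve the illustration.
    prob_up = 0.25 if is_biased else 0.5
    go_up = jax.random.bernoulli(key, prob_up, total.shape)
    return jnp.where(go_up, one_up, total)

# Each distinct value of `is_biased` triggers its own compilation.
jitted_add = jax.jit(add_stand_in, static_argnames=("is_biased",))

key = jax.random.PRNGKey(0)
x = jnp.ones(4, dtype=jnp.bfloat16)
y = jnp.ones(4, dtype=jnp.bfloat16)
out = jitted_add(key, x, y, is_biased=False)
print(out.dtype)
```

With the real library, the equivalent wrapper would presumably be `jax.jit(jochastic.add, static_argnames=("is_biased",))`, assuming the parameter is accepted by keyword.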

**NOTE:** Very low precision (16-bit floating-point arithmetic or less) is *extremely* brittle. We recommend using higher precision locally (such as using 32-bit floating-point arithmetic to compute the optimizer's update) *then* casting down to 16 bits at summing / storage time (something that PyTorch does transparently when using its `addcdiv` in low precision). Both functions will accept mixed-precision inputs (adding a high-precision number to a low-precision one), use that information for the rounding, then return an output in the *lowest* precision of their inputs (contrary to most casting conventions).
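For contrast, here is what the usual casting convention does in plain JAX, without this library: a mixed `bfloat16` / `float32` addition is promoted to `float32`, and a deterministic cast back to `bfloat16` discards a small update. The array sizes and values are illustrative assumptions.

```python
import jax.numpy as jnp

weights = jnp.ones(3, dtype=jnp.bfloat16)
# e.g. an optimizer update kept in float32
update = jnp.full(3, 1e-3, dtype=jnp.float32)

# Standard promotion picks the wider type.
promoted = weights + update
print(promoted.dtype)  # float32

# Casting down with round-to-nearest lands back on the original weights.
stored = promoted.astype(jnp.bfloat16)
print(stored)
```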

## Implementation details

We use `TwoSum` to measure the numerical error introduced by the addition. Our tests show that it behaves as needed on `bfloat16` (some edge cases might be invalid, leading to an inexact computation of the numerical error, but it is reliable enough for our purpose).

This and the `nextafter` function let us emulate various rounding modes in software (this is inspired by Verrou's backend).
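To make the idea concrete, here is a minimal sketch of how `TwoSum` and `nextafter` can be combined into a stochastically rounded addition. It is an illustration written under assumptions, not the code of this repository: the names `two_sum` and `sto_add_sketch` are hypothetical, the rounding probability is taken as the error divided by the gap to the neighbouring value, and the sketch is run without `jit`.

```python
import jax
import jax.numpy as jnp

def two_sum(a, b):
    """Knuth's TwoSum: returns the rounded sum and an estimate of its rounding error."""
    s = a + b
    b_virtual = s - a
    a_virtual = s - b_virtual
    error = (a - a_virtual) + (b - b_virtual)
    return s, error

def sto_add_sketch(key, a, b):
    """Illustrative stochastic addition; not the library's actual code."""
    s, error = two_sum(a, b)
    # Neighbouring representable value on the side where the exact sum lies.
    toward = jnp.where(error > 0, jnp.inf, -jnp.inf).astype(s.dtype)
    neighbour = jnp.nextafter(s, toward)
    # Switch to the neighbour with probability |error| / gap (computed in float32).
    gap = jnp.abs(neighbour - s).astype(jnp.float32)
    prob = jnp.abs(error).astype(jnp.float32) / gap
    draw = jax.random.uniform(key, s.shape, dtype=jnp.float32)
    return jnp.where(draw < prob, neighbour, s)

# Many copies of the same addition: each result is one of the two bfloat16
# neighbours of the exact sum, and their average approaches that exact sum.
a = jnp.ones(10_000, dtype=jnp.bfloat16)
b = jnp.full(10_000, 1e-3, dtype=jnp.bfloat16)
out = sto_add_sketch(jax.random.PRNGKey(0), a, b)
print(out.astype(jnp.float32).mean())
```

Rounding with a probability proportional to the error keeps the result correct on average, which is what lets small updates accumulate instead of being lost.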

## Crediting this work

You can use this BibTeX reference if you use Jochastic within a published work:

```
@misc{Jochastic,
author = {Nestor, Demeure},
title = {Jochastic: stochastically rounded operations between JAX tensors.},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/nestordemeure/jochastic}}
}
```

A PyTorch implementation called StochasTorch is also available.