Automatic Differentiation Systems#

Pyadjoint#

Pyadjoint is an extensible automatic differentiation framework written in Python. Each Pyadjoint computation is represented as a Block, and each Block knows how to recompute itself and how to compute its first and second derivatives. Computing the derivative of a sequence of computations then amounts to combining each Block's derivative via the chain rule.
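
As an illustration (a minimal sketch, not CRIKit code), a custom Block wrapping y = sin(x) for scalar AdjFloat inputs might look like the following; it assumes the public pyadjoint Block hooks (add_dependency, recompute_component, evaluate_adj_component) and the overload_function helper.

import math

from pyadjoint import AdjFloat, Block, Control, ReducedFunctional, Tape, set_working_tape
from pyadjoint.overloaded_function import overload_function

def backend_sin(x):
    return math.sin(x)

class SinBlock(Block):
    def __init__(self, x, **kwargs):
        super().__init__()
        self.add_dependency(x)

    def recompute_component(self, inputs, block_variable, idx, prepared):
        # Re-run the forward computation from the saved input value.
        return backend_sin(inputs[0])

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        # Chain rule: multiply the incoming adjoint value by d(sin)/dx = cos(x).
        return adj_inputs[0] * math.cos(inputs[0])

sin = overload_function(backend_sin, SinBlock)

# Record y = sin(x) on a tape and evaluate dy/dx through the Block above.
set_working_tape(Tape())
x = AdjFloat(0.5)
y = sin(x)
dydx = ReducedFunctional(y, Control(x)).derivative()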

Because Pyadjoint is written to be extensible, CRIKit uses it as its main differentiation tool and takes advantage of other automatic differentiation tools by wrapping them into the Pyadjoint framework. For example, CRIKit has a class that interfaces between Pyadjoint and a TensorFlow graph so that Pyadjoint can drive the derivative calculations without knowing the specifics of TensorFlow. The CRIKit extensions to Pyadjoint are documented here: Pyadjoint extension.

TensorFlow#

TensorFlow supports calculations on tensors and is best known for running neural networks and training them with backpropagation. The general workflow is to set up a static computation graph once and then evaluate it as many times as necessary. (This workflow changed in TensorFlow 2.0, but CRIKit currently supports only versions 1.13 and 1.14.)

Below is some example code for setting up a basic neural network graph.

import tensorflow.compat.v1 as tf
import numpy as np
from tensorflow.keras import layers

g = tf.Graph()
with g.as_default():
    # Placeholder for a batch of 2-component input vectors.
    g_input = tf.placeholder(dtype=tf.float64, shape=(None, 2), name='input_data')

    # Three hidden layers with tanh activations.
    z = g_input
    for layer_size in [40, 30, 20]:
        z = layers.Dense(layer_size, activation='tanh')(z)

    # Single linear output.
    g_output = layers.Dense(1, activation='linear')(z)

    # Initialize the layer weights in a session tied to this graph.
    init = tf.global_variables_initializer()

    sess = tf.Session()
    sess.run(init)

Then the outputs of the graph can be evaluated by running the Session.

# Run the graph on some random input.
input = np.random.rand(15, 2)
output = sess.run(g_output, {g_input: input})

The function pyadjoint_utils.tensorflow_adjoint.run_tensorflow_graph() does the same computation as the sess.run() call above, but it also creates a Block for Pyadjoint to use to evaluate the derivative. The Block it creates simply calls the TensorFlow derivative routines whenever Pyadjoint needs the derivative.

from pyadjoint_utils.tensorflow_adjoint import run_tensorflow_graph
from pyadjoint.overloaded_type import create_overloaded_object as coo

with sess.as_default():
    pyadjoint_input = coo(input)
    output = run_tensorflow_graph(g_output, {g_input: pyadjoint_input})
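
Derivatives of the wrapped graph can then be requested through Pyadjoint. The sketch below is hedged: it assumes the ReducedFunction, Control, and taylor_test interface from pyadjoint_utils (shown in the JAX example below) also applies to outputs of run_tensorflow_graph().

from pyadjoint_utils import Control, ReducedFunction, taylor_test

# Treat the graph input as the control and the graph output as the function value.
rf = ReducedFunction(output, [Control(pyadjoint_input)])

# Taylor test of the adjoint-computed derivative; a rate near 2 indicates correctness.
h = [coo(1.0e-1 * np.random.rand(*input.shape))]
assert taylor_test(rf, [pyadjoint_input], h) >= 1.9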

JAX#

JAX provides “composable transformations of Python+NumPy programs”, which include differentiation, vectorization, parallelization, and JIT compilation (including to GPU/TPU). CRIKit exposes JAX functionality to users via the pyadjoint_utils.jax_adjoint.overload_jax() function, whose usage looks like this:

from jax import numpy as np
import numpy as onp
from pyadjoint_utils import array, overload_jax, ReducedFunction, taylor_test, Control
from jax.tree_util import Partial as partial # or, from functools import partial

# Differentiate with respect to both arguments, and have this function
# recompute its internal linearization points when differentiated
# (checkpoint=True) instead of storing them.
@partial(overload_jax, argnums=(0, 1), checkpoint=True)
def my_func(x, y):
    return np.exp(np.trace(x @ y))


inputs = (array(onp.random.randn(3, 3)), array(onp.random.randn(3, 3)))
controls = [Control(x) for x in inputs]
output = my_func(*inputs)
rf = ReducedFunction(output, controls)
h = [1.0e-1 * array(onp.random.randn(*x.shape)) for x in inputs]
assert taylor_test(rf, inputs, h) >= 1.9

Building From Source#

If efficiency is important, we highly recommend you build JAX from source. XLA tries to pre-compile as many ops as possible (e.g. matrix multiplication) so it can link the compiled code into JIT-ed functions at runtime for faster compilation. However, this means that if your machine supports vector instructions that the jaxlib binary was not pre-compiled for, you will not get the best possible performance out of a jaxlib installed through pip. We have observed that code generated by an XLA built locally on a machine can be more than twice as fast in CPU-only mode as the XLA installed through pip. Furthermore, if you have a GPU or TPU available, a pip-installed JAX may not support compiling code for that device, so in order to use it you'll have to build JAX from source.

To do that, you first need to download JAX:

git clone https://github.com/google/jax
cd jax

You’ll also need its dependencies: a C++ compiler (any of g++, clang++, or MSVC), a working Python 3 installation and development files, and several Python packages. On a Debian system, the dependencies can be installed with

sudo apt install g++ python3 python3-dev

And on any system, the Python dependencies can be installed with

python -m pip install numpy scipy six wheel

To build CPU-only jaxlib, run

python build/build.py
python -m pip install dist/*.whl

To build a CUDA-enabled jaxlib, add the --enable_cuda flag to the build line (i.e. python build/build.py --enable_cuda) and otherwise proceed as above. To build with ROCm support, first ensure the following packages are installed:

rocm-dev miopen-hip rocfft rocblas hipsparse rocrand rocsolver hipblas rccl

then do

python build/build.py --enable_rocm --rocm_path /path/to/rocm-x.y.z
python -m pip install dist/*.whl

If you’re building a GPU-enabled jaxlib on Windows, we recommend you also read the official JAX documentation on building from source.

GPU Memory Allocation#

By default, JAX preallocates 90% of the currently available GPU memory as soon as an op executes on the GPU, in order to minimize overhead from allocations and memory fragmentation. However, this can cause out-of-memory errors; if your GPU-enabled JAX process crashes with a segmentation fault shortly after starting, this is likely the reason. You can change this behavior by modifying the following environment variables, described in more detail here.

Setting XLA_PYTHON_CLIENT_PREALLOCATE=false will disable JAX’s preallocation, instead allocating GPU memory as needed. Setting XLA_PYTHON_CLIENT_MEM_FRACTION=.XX sets the preallocation fraction to XX% instead of 90%. Setting XLA_PYTHON_CLIENT_ALLOCATOR=platform causes JAX to not only allocate GPU memory on demand, but also deallocate it when done (instead of re-using it for future computations).

A common cause of JAX running out of memory is the case when you are running multiple MPI processes per available GPU. This can be fixed by setting XLA_PYTHON_CLIENT_MEM_FRACTION=.XX, where .XX = .90 / N and N is the number of MPI processes per GPU.
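
As a minimal sketch (the four-ranks-per-GPU count is an assumed example), the fraction can be set in Python before JAX initializes its GPU backend, or equivalently exported in the shell before launching the processes:

import os

# Assumed example: 4 MPI ranks share each GPU, so give each rank ~0.90 / 4
# of the device memory. This must run before JAX allocates GPU memory,
# so set it before importing jax.
n_ranks_per_gpu = 4
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(0.90 / n_ranks_per_gpu)

import jax  # imported only after the environment is configured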

Autograd#

Autograd overloads NumPy functions so that their derivatives can be computed. Simply replace import numpy with import autograd.numpy to get access to derivatives through Autograd. As with Pyadjoint and TensorFlow, the Autograd NumPy functions record all computations as nodes on a graph to facilitate calculating derivatives.

Below is example code that defines a function doing some NumPy calculations.

import autograd.numpy as np

def f(u, g, p):
    mu = (np.sum(g * g, axis=1) + 1.e-12) ** ((p - 2.) / 2.)
    return np.reshape(mu, mu.shape + (1,)) * g
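
For instance, the derivative of f can then be evaluated directly with Autograd's jacobian transform; the input shapes below are arbitrary example values:

from autograd import jacobian

u = np.zeros(4)      # u is not used by f, but is kept for the signature
g = np.ones((4, 2))
p = 3.0

df_dg = jacobian(f, 1)        # differentiate f with respect to its second argument, g
print(df_dg(u, g, p).shape)   # (4, 2, 4, 2)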

The crikit.cr.point_map decorator can be used to make a function that uses Autograd compatible with Pyadjoint.

from crikit.cr.autograd import point_map

# Arguments: the shapes of the inputs (u, g, p) and the shape of the output.
@point_map(((-1,), (-1, 2), ()), (-1, 2), bare=True)
def f(u, g, p):
    mu = (np.sum(g * g, axis=1) + 1.e-12) ** ((p - 2.) / 2.)
    return np.reshape(mu, mu.shape + (1,)) * g