(crikit-automatic-differentiation)=

# Automatic Differentiation Systems

## Pyadjoint

[Pyadjoint][pyadjoint] is an extensible automatic differentiation framework written in Python. Each Pyadjoint computation is represented as a Block, and each Block knows how to recompute itself and how to compute its first and second derivatives. Differentiating a sequence of computations then amounts to combining the derivatives of the individual Blocks with the chain rule.

Because Pyadjoint is designed to be extended, CRIKit uses it as its main differentiation tool and takes advantage of other automatic differentiation tools by wrapping them in the Pyadjoint framework. For example, CRIKit has a class that interfaces between Pyadjoint and a TensorFlow graph so that Pyadjoint can drive the derivative calculations without knowing the specifics of TensorFlow.

The CRIKit extensions to Pyadjoint are documented here: {ref}`Pyadjoint extension `.

## TensorFlow

[TensorFlow][tensorflow] supports calculations on tensors and is best known for running neural networks and training them with backpropagation. The general method of using TensorFlow is to set up a static computation graph once and then evaluate the graph as many times as necessary. (This workflow changed in TensorFlow 2.0, but CRIKit currently supports only versions 1.13 and 1.14.) Below is some example code for setting up a basic neural network graph.

```{eval-rst}
.. testcode::

    import tensorflow.compat.v1 as tf
    import numpy as np
    from tensorflow.keras import layers

    g = tf.Graph()
    with g.as_default():
        # Placeholder for a batch of two-component inputs.
        g_input = tf.placeholder(dtype=tf.float64, shape=(None, 2), name='input_data')
        z = g_input
        for layer_size in [40, 30, 20]:
            z = layers.Dense(layer_size, activation='tanh')(z)
        g_output = layers.Dense(1, activation='linear')(z)
        init = tf.global_variables_initializer()

        # Create the session while g is the default graph so that it runs g.
        sess = tf.Session()
        sess.run(init)
```

Then the outputs of the graph can be evaluated by running the Session.

```{eval-rst}
.. testcode::

    # Run the graph on some random input.
    input = np.random.rand(15, 2)
    output = sess.run(g_output, {g_input: input})
```

The function `pyadjoint_utils.tensorflow_adjoint.run_tensorflow_graph()` does the same computation as the `sess.run()` call above, but it also creates a Block that Pyadjoint uses to evaluate the derivative. The Block it creates simply calls the TensorFlow derivative routines whenever Pyadjoint needs the derivative.

```{eval-rst}
.. testcode::

    from pyadjoint_utils.tensorflow_adjoint import run_tensorflow_graph
    from pyadjoint.overloaded_type import create_overloaded_object as coo

    with sess.as_default():
        # Wrap the NumPy input so that Pyadjoint can track it.
        pyadjoint_input = coo(input)
        output = run_tensorflow_graph(g_output, {g_input: pyadjoint_input})
```

## JAX

[JAX][jax] provides "composable transformations of Python+NumPy programs", which include differentiation, vectorization, parallelization, and JIT compilation (including to GPU/TPU). CRIKit exposes JAX functionality to users via the {func}`pyadjoint_utils.jax_adjoint.overload_jax` function, whose usage looks like:

```{code-block} python
from jax import numpy as np
import numpy as onp
from pyadjoint_utils import array, overload_jax, ReducedFunction, taylor_test, Control
from jax.tree_util import Partial as partial  # or: from functools import partial

# Differentiate with respect to both arguments.
# checkpoint=True makes the function recompute its internal linearization
# points when differentiated, instead of storing them.
@partial(overload_jax, argnums=(0, 1), checkpoint=True)
def my_func(x, y):
    return np.exp(np.trace(x @ y))

inputs = (array(onp.random.randn(3, 3)), array(onp.random.randn(3, 3)))
controls = [Control(x) for x in inputs]
output = my_func(*inputs)
rf = ReducedFunction(output, controls)
# One random perturbation direction per input.
h = [1.0e-1 * array(onp.random.randn(*x.shape)) for x in inputs]
assert taylor_test(rf, inputs, h) >= 1.9
```

### Building From Source

If efficiency is important, we highly recommend you build JAX from source.
XLA tries to pre-compile as many ops as possible (e.g. matrix multiplication) so it can link the compiled code into JIT-ed functions at runtime for faster compilation. However, this means that if your machine supports vector instructions that the `jaxlib` binary was not pre-compiled to use, you will not get the best possible performance out of the `jaxlib` installed through `pip`. We have observed that code generated by a locally built XLA can be more than twice as fast as code generated by the XLA installed through `pip` in CPU-only mode. Furthermore, if you have a GPU or TPU available, `pip`-installed JAX may not support compiling code for that GPU or TPU, so in order to use it, you'll have to build JAX from source. To do that, you first need to download JAX:

```{code-block} bash
git clone https://github.com/google/jax
cd jax
```

You'll also need its dependencies: a C++ compiler (one of `g++`, `clang++`, or `MSVC`), a working Python 3 installation with its development files, and several Python packages. On a Debian system, the dependencies can be installed with

```{code-block} bash
sudo apt install g++ python3 python3-dev
```

On any system, the Python dependencies can be installed with

```{code-block} bash
python -m pip install numpy scipy six wheel
```

To build CPU-only `jaxlib`, run

```{code-block} bash
python build/build.py
python -m pip install dist/*.whl
```

To build CUDA-enabled `jaxlib`, add the `--enable_cuda` flag to your build line (i.e. `python build/build.py --enable_cuda`), but otherwise do the same.

To build with ROCm support, first ensure the following packages are installed:

```{code-block} bash
rocm-dev miopen-hip rocfft rocblas hipsparse rocrand rocsolver hipblas rccl
```

Then run

```{code-block} bash
python build/build.py --enable_rocm --rocm_path /path/to/rocm-x.y.z
python -m pip install dist/*.whl
```

If you're building a GPU-enabled `jaxlib` on Windows, we recommend you also read the [official JAX documentation on building from source.]()

### GPU Memory Allocation

By default, JAX preallocates 90% of the currently available GPU memory as soon as an op executes on the GPU, in order to minimize overhead from allocations and memory fragmentation. However, this can cause out-of-memory errors; if your GPU-enabled JAX process crashes with a segmentation fault shortly after starting, this is likely the reason. You can change this behavior by modifying the following environment variables, [described in more detail here.]()

Setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` disables JAX's preallocation, so GPU memory is allocated as needed instead.

Setting `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` sets the preallocation fraction to XX% instead of 90%.

Setting `XLA_PYTHON_CLIENT_ALLOCATOR=platform` causes JAX not only to allocate GPU memory on demand, but also to deallocate it when done (instead of reusing it for future computations).

A common cause of JAX running out of memory is running multiple MPI processes per available GPU. This can be fixed by setting `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX`, where `.XX = .90 / N` and `N` is the number of MPI processes per GPU (for example, `.45` with two processes per GPU).

## Autograd

[Autograd][autograd] overloads [NumPy][numpy] functions so that their derivatives can be computed. Simply replace `import numpy` with `import autograd.numpy` to get access to derivatives through Autograd. Like Pyadjoint and TensorFlow, the Autograd NumPy functions record all computations as nodes on a graph to facilitate calculating derivatives.

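The idea of recording computations as nodes on a graph and then combining local derivatives with the chain rule can be illustrated with a toy, dependency-free sketch. This is not how Autograd or Pyadjoint is implemented; the `Node` class and the `sin` and `backward` helpers are hypothetical names chosen only for this illustration, and the sketch handles scalars only.

```python
import math


class Node:
    """One recorded scalar value in a toy computation graph (hypothetical)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent_node, local derivative d(self)/d(parent)).
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value,
                    ((self, other.value), (other, self.value)))


def sin(x):
    """Record sin(x) together with its local derivative cos(x)."""
    return Node(math.sin(x.value), ((x, math.cos(x.value)),))


def backward(output):
    """Accumulate d(output)/d(node) into node.grad for every recorded node."""
    order, seen = [], set()

    def visit(node):
        # Post-order traversal: parents are appended before their consumers.
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    # Walk from the output back to the inputs, applying the chain rule.
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local


x = Node(2.0)
y = Node(3.0)
z = x * y + sin(x)  # records three nodes: x*y, sin(x), and their sum
backward(z)
print(x.grad, y.grad)  # dz/dx = y + cos(x), dz/dy = x
```

Each operation stores only its local derivative, so supporting a new operation means adding one more recording function. A real framework does the same bookkeeping for array-valued operations.
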
Below is example code that defines a function doing some NumPy calculations.

```{eval-rst}
.. testcode::

    import autograd.numpy as np

    def f(u, g, p):
        mu = (np.sum(g*g, axis=1) + 1.e-12)**((p-2.)/2.)
        return np.reshape(mu, mu.shape+(1,)) * g
```

The `crikit.cr.point_map` decorator makes a function that uses Autograd compatible with Pyadjoint.

```{code}
from crikit.cr.autograd import point_map

@point_map(((-1,), (-1, 2), ()), (-1, 2), bare=True)
def f(u, g, p):
    mu = (np.sum(g*g, axis=1) + 1.e-12)**((p-2.)/2.)
    return np.reshape(mu, mu.shape+(1,)) * g
```

```{eval-rst}
.. todo::

    1. The block above is a "code" block instead of a "testcode" block because this branch doesn't actually have the autograd stuff in it, so the test would fail.
    2. We should probably add support for using Autograd without having to use the CRIKit PointMap framework.
```

[pyadjoint]: http://www.dolfin-adjoint.org/en/release/
[tensorflow]: https://www.tensorflow.org/
[jax]: https://jax.readthedocs.io/en/latest
[autograd]: https://github.com/HIPS/autograd
[numpy]: https://numpy.org/