import jax
import jax.numpy as jnp
from pyadjoint_utils.jax_adjoint import ndarray, array
from pyadjoint import AdjFloat
from pyadjoint.enlisting import Enlist
from pyadjoint_utils import Block, JacobianIdentity
from pyadjoint import Control
from pyadjoint.overloaded_type import (
OverloadedType,
register_overloaded_type,
create_overloaded_object,
)
from pyadjoint.overloaded_function import overload_function
from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating
from functools import wraps
from jax.tree_util import Partial as partial # JAX-friendlier functools.partial
from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, tree_map
from itertools import product
from typing import Tuple, Optional, Callable, Iterable, Sequence, Union, Any, TypeVar
from .array import convert_arg
import logging
# Module-level logger; messages from this module are emitted under the "CRIKit" logger name.
logger = logging.getLogger("CRIKit")
# These TypeVars express the fact that overload_jax() is a transformation
# from a general callable to a pyadjoint-overloaded callable, i.e. one that
# can take overloaded objects as inputs and return them as outputs while
# recording the operation on the tape. Both are bound to Callable and are
# used only as annotations on overload_jax().
Function = TypeVar("Function", bound=Callable)
OverloadedFunction = TypeVar("OverloadedFunction", bound=Callable)
class JAXBlock(Block):
    """Tape block that records one call of a JAX-traceable function.

    The block stores the wrapped function together with the flattened
    arguments it was called with, registers the differentiated arguments as
    dependencies and the overloaded results as outputs, and evaluates
    recomputation, adjoint (``jax.vjp``), tangent-linear (``jax.linearize``)
    and Jacobian-matrix (``jax.jacrev``, optionally ``jax.vmap``-ed for
    pointwise data) quantities on request.

    NOTE: evaluating the function through :meth:`_diff_func` writes the given
    values into ``self._args``; the block is therefore stateful.
    """

    def __init__(
        self,
        func,
        args,
        outputs,
        argnums=None,
        nojit=False,
        jit_kwargs=None,
        pointwise=False,
        out_pointwise=None,
        **kwargs,
    ):
        """Build the block and register its dependencies and outputs.

        :param func: the function whose call is being recorded
        :param args: the positional arguments ``func`` was called with
        :param outputs: the overloaded outputs of the call; each must provide
            ``create_block_variable()``
        :param argnums: index or indices of the arguments to differentiate
            with respect to. ``None`` (or an empty collection) selects every
            argument.
        :param nojit: if True, derivative functions are not wrapped in
            ``jax.jit``
        :param jit_kwargs: keyword arguments forwarded to ``jax.jit``
        :param pointwise: a bool, or one bool per input, marking inputs whose
            first axis indexes points
        :param out_pointwise: a bool, or one bool per output; defaults to
            "pointwise if any input is pointwise"
        :param kwargs: ``func_name`` and ``holomorphic`` are read from here;
            other entries are ignored
        :raises ValueError: if ``argnums`` is inconsistent with the arguments
            or the pointwise specifications are inconsistent
        """
        super().__init__()
        # Callables without a __name__ (e.g. partial objects) fall back to their type name.
        self._func_name = kwargs.get("func_name", None) or getattr(
            func, "__name__", type(func).__name__
        )
        self._holomorphic = kwargs.get("holomorphic", False)
        self._func = func
        self._differentiable_func = self._diff_func
        self._nojit = nojit
        self._jit_kwargs = jit_kwargs if jit_kwargs is not None else {}

        num_inputs = len(args)
        num_outputs = len(outputs)

        # Normalise argnums without relying on truthiness: a bare index 0 is a
        # valid request and must not be mistaken for "not given".
        if argnums is None:
            explicit_argnums = ()
        elif isinstance(argnums, int):
            explicit_argnums = (argnums,)
        else:
            explicit_argnums = tuple(argnums)

        if explicit_argnums:
            self._tlm = jax.jacrev(
                self._differentiable_func,
                argnums=explicit_argnums,
                holomorphic=self._holomorphic,
            )
        else:
            self._tlm = jax.jacrev(
                self._differentiable_func, holomorphic=self._holomorphic
            )
        self._jvp_maker = _pushforward(self._differentiable_func)

        argnum_range = explicit_argnums or range(num_inputs)
        self._tlm_argnums = argnum_range
        self._tlm_outnums = tuple(range(num_outputs))
        self._tlm_matrix_func = self._jit(
            jax.jacrev(self._diff_func, argnums=self._tlm_argnums)
        )
        self._pw_tlm_matrix_func = None
        self._vjpfuncs = [
            vector_jacobian_product(self._func, i, jitfun=self._jit)
            for i in argnum_range
        ]
        self._hesfuncs = [
            vector_jacobian_product(self._vjpfuncs[i], j, False, jitfun=self._jit)
            for i, j in product(range(len(argnum_range)), argnum_range)
        ]
        self._make_hvpfun = _make_hvpfun(self._differentiable_func)
        self._hessian = jax.hessian(
            self._differentiable_func, argnums=tuple(range(len(argnum_range)))
        )
        self._dependencies = []
        self._outputs = []
        self._param_range = argnum_range

        # Flatten args one level and add the dependencies. The input signature
        # of the function is typically something like
        # ((arg1, arg2, ...), param1, param2, ...).
        arg_leaves = []
        for a in args:
            if isinstance(a, (tuple, list)):
                arg_leaves.extend(a)
            else:
                arg_leaves.append(a)

        if len(self._param_range) > len(arg_leaves):
            raise ValueError(
                f"There are more argument indices in argnums than there are arguments ({len(self._param_range)} > {len(arg_leaves)})"
            )
        # Guard the max() call: with no arguments at all the range is empty.
        if self._param_range and max(self._param_range) >= len(arg_leaves):
            raise ValueError(
                f"There is an argument index in argnums that is higher than the number of arguments ({max(self._param_range)} >= {len(arg_leaves)})"
            )

        # Record dependencies only for the specified arguments.
        for idx in self._param_range:
            self.add_dependency(arg_leaves[idx])

        leaves, self._treedef = tree_flatten(tuple(args))
        self._args = [convert_arg(x) for x in leaves]

        for out in outputs:
            self.add_output(out.create_block_variable())
        self._saved_outs = outputs
        self._single_output = num_outputs == 1

        # Handle pointwise specifications.
        self._in_pointwise = (
            (pointwise,) * num_inputs if isinstance(pointwise, bool) else pointwise
        )
        self._any_pointwise = any(self._in_pointwise)
        if out_pointwise is None:
            out_pointwise = self._any_pointwise
        self._out_pointwise = (
            (out_pointwise,) * num_outputs
            if isinstance(out_pointwise, bool)
            else out_pointwise
        )

        if len(self._in_pointwise) != num_inputs:
            raise ValueError(
                f"len(pointwise) != num_inputs. For each input, you must specify if it is defined pointwise. ({len(self._in_pointwise)} != {num_inputs})"
            )
        if len(self._out_pointwise) != num_outputs:
            raise ValueError(
                f"len(out_pointwise) != num_outputs. For each output, you must specify if it is defined pointwise. ({len(self._out_pointwise)} != {num_outputs})"
            )
        if any(self._in_pointwise) and not any(self._out_pointwise):
            raise ValueError(
                "There is a pointwise input and no pointwise outputs. If an input is specified as pointwise, then at least one output must be specified as pointwise."
            )
        if any(self._out_pointwise) and not any(self._in_pointwise):
            raise ValueError(
                "There is a pointwise output and no pointwise inputs. If an output is specified as pointwise, then at least one input must be specified as pointwise."
            )

    def _jit(self, func):
        """Return ``func`` unchanged if jitting is disabled, else ``jax.jit`` it with the stored kwargs."""
        return func if self._nojit else jax.jit(func, **self._jit_kwargs)

    def _diff_func(self, *args):
        """Evaluate the wrapped function with ``args`` substituted for the differentiated arguments.

        The flattened ``args`` overwrite ``self._args`` at the positions in
        ``self._param_range``; the full argument tree is then rebuilt and
        passed to the wrapped function.
        """
        self._replace_params(tree_leaves(tuple(args)))
        argtree = tree_unflatten(self._treedef, self._args)
        call_args = tuple(convert_arg(x) for x in argtree)
        return self._func(*call_args)

    def _make_partial_diff_func(self, idxs):
        """Return a function like :meth:`_diff_func` that only replaces the dependencies listed in ``idxs``.

        :param idxs: positions into ``self._param_range`` of the dependencies
            that the returned function takes as positional arguments
        """

        def _pdiff_func(*args):
            param_idxs = [self._param_range[idx] for idx in idxs]
            self._replace_params_partial(args, param_idxs)
            argtree = tree_unflatten(self._treedef, self._args)
            call_args = tuple(convert_arg(x) for x in argtree)
            return self._func(*call_args)

        return _pdiff_func

    def _replace_params(self, new_params):
        """Overwrite the stored differentiated arguments, in order, with ``new_params``."""
        for new_param, idx in zip(new_params, self._param_range):
            self._args[idx] = convert_arg(new_param)

    def _replace_params_partial(self, new_params, idxs):
        """Overwrite ``self._args[idxs[i]]`` with ``new_params[i]`` for every ``i``."""
        for i, idx in enumerate(idxs):
            self._args[idx] = convert_arg(new_params[i])

    def _get_reduced_output_diff_func(self, output_ids):
        """Return a version of :meth:`_diff_func` restricted to the outputs in ``output_ids``.

        For a single-output function :meth:`_diff_func` itself is returned.
        """
        if self._single_output:
            return self._diff_func

        def diff_func_relevant(*args):
            outputs = self._diff_func(*args)
            return tuple(outputs[idx] for idx in output_ids)

        return diff_func_relevant

    def _add_params(self, deltas):
        """Add ``deltas[i]`` to the i-th differentiated argument stored on the block."""
        for i, idx in enumerate(self._param_range):
            self._args[idx] += convert_arg(deltas[i])

    def __repr__(self):
        return f"JAXBlock({self._func_name})"

    def prepare_recompute_component(self, inputs, relevant_outputs):
        """Re-evaluate the function at ``inputs`` and return the outputs as a list."""
        return Enlist(self._diff_func(*inputs))

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Return output ``idx``, preferring ``prepared``, then the saved outputs, then a fresh evaluation."""
        if prepared:
            return prepared[idx]
        elif self._saved_outs:
            return self._saved_outs[idx]
        else:
            # _diff_func expects only the dependency values. Passing the full
            # self._args list would shift values into the wrong slots whenever
            # the differentiated arguments are not exactly 0..n-1.
            self._saved_outs = Enlist(self._diff_func(*inputs))
            return self._saved_outs[idx]

    def prepare_evaluate_hessian(
        self, inputs, hessian_inputs, adj_inputs, relevant_dependencies
    ):
        """Hessian evaluation is not implemented for this block.

        :raises NotImplementedError: always
        """
        raise NotImplementedError("JAXBlock does not implement Hessian evaluation.")

    def evaluate_hessian_component(
        self,
        inputs,
        hessian_inputs,
        adj_inputs,
        block_variable,
        idx,
        relevant_dependencies,
        prepared=None,
    ):
        """Return entry ``idx`` of the prepared Hessian result."""
        return prepared[idx]

    def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
        """Push the non-``None`` tangent inputs through the linearized function.

        Only dependencies whose tangent is not ``None`` are treated as
        variable; the function is linearized with ``jax.linearize`` at the
        corresponding input values.

        :return: the output tangents as a list
        """
        active_idx = []
        primals = []
        tangents = []
        for i, (ip, ti) in enumerate(zip(inputs, tlm_inputs)):
            if ti is not None:
                active_idx.append(i)
                primals.append(convert_arg(ip))
                tangents.append(convert_arg(ti))
        jvp_maker = _pushforward(self._make_partial_diff_func(active_idx))
        jvpfun = jvp_maker(primals)
        return Enlist(jvpfun(*tangents))

    def evaluate_tlm_component(
        self, inputs, tlm_inputs, block_variable, idx, prepared=None
    ):
        """Return entry ``idx`` of the prepared tangent-linear result."""
        return prepared[idx]

    def prepare_evaluate_tlm_matrix(self, inputs, tlm_inputs, relevant_outputs):
        """Compute Jacobian blocks of the relevant outputs w.r.t. the active inputs.

        An input is active when its entry in ``tlm_inputs`` is not ``None``;
        that entry must be an identity Jacobian w.r.t. itself and ``None``
        w.r.t. every other input.

        :return: a sequence indexed by output; each relevant entry holds one
            Jacobian block per active input (for a single-output,
            non-pointwise function the raw ``jax.jacrev`` result is wrapped in
            a 1-tuple)
        :raises NotImplementedError: if a non-identity input Jacobian is given
        :raises ValueError: if a pointwise input's first axis has an
            unexpected size
        """
        # Standardize argnums as collection to ensure return type from jacrev is also a collection
        argnums = []
        for i, x in enumerate(tlm_inputs):
            if x is not None:
                argnums.append(i)
                for j, di_dj in enumerate(x):
                    if (i != j and di_dj is not None) or (
                        i == j and not isinstance(di_dj, JacobianIdentity)
                    ):
                        raise NotImplementedError(
                            "Non-identity inputs cannot be handled yet."
                        )
        outnums = [idx for idx, bv in relevant_outputs]

        recompute_jac = (tuple(argnums) != tuple(self._tlm_argnums)) or (
            tuple(outnums) != tuple(self._tlm_outnums)
        )
        if recompute_jac:
            self._tlm_argnums = tuple(argnums)
            # Remember which outputs the cached function was built for, so a
            # later call with a different output set rebuilds it instead of
            # reusing a function with the wrong output structure.
            self._tlm_outnums = tuple(outnums)
            diff_func_relevant = self._get_reduced_output_diff_func(outnums)
            self._tlm_matrix_func = self._jit(
                jax.jacrev(diff_func_relevant, argnums=tuple(argnums))
            )

        cargs = tuple(convert_arg(x) for x in inputs)
        if not self._any_pointwise:
            val = self._tlm_matrix_func(*cargs)
            if self._single_output:
                return (val,)
            # Expand reduced outputs back to the expected length.
            rv = [None] * len(self.get_outputs())
            for i, idx in enumerate(outnums):
                rv[idx] = val[i]
            return rv

        # Set up all arguments into a standard format that can be looped through pointwise.
        outputs = [bv.saved_output for bv in self.get_outputs()]
        out_shapes = [
            a.shape[1:] if pw else a.shape
            for a, pw in zip(outputs, self._out_pointwise)
        ]
        n = max(
            [a.shape[0] for a, pw in zip(cargs, self._in_pointwise) if pw]
            + [a.shape[0] for a, pw in zip(outputs, self._out_pointwise) if pw]
        )

        # Pad extra axis for cases where internal func is already vmapped
        standard_args = []
        for i, (a, pw) in enumerate(zip(cargs, self._in_pointwise)):
            if pw:
                if a.shape[0] not in (1, n):
                    raise ValueError(
                        f"Argument {i} is marked as pointwise but first axis doesn't match expected size ({a.shape[0]} != {n})"
                    )
                s = jnp.reshape(a, (a.shape[0], 1, *a.shape[1:]))
                s = jnp.broadcast_to(s, (n, *s.shape[1:]))
            else:
                s = a
            standard_args.append(s)

        # v is a tuple of size n. Each one is a tuple of size num_relevant_outputs. Each entry in that tuple is a tuple of size num_relevant_inputs.
        doutput_dinput = [None] * len(self.get_outputs())
        squeeze_didj = jax.jit(partial(jnp.squeeze, axis=1))

        # vmap spec: one entry per positional argument, mapping axis 0 of the
        # pointwise arguments and broadcasting the others.
        in_axes = tuple(0 if pw else None for pw in self._in_pointwise)
        # It doesn't seem possible to use a more fine-grained vmap for un-vmapped outputs since the Jacobian is
        # defined for any vmapped inputs even if the Jacobian is always going to be zero.
        out_axes = tuple(tuple(0 for j in argnums) for i in outnums)
        if self._single_output:
            out_axes = out_axes[0]

        # Generate block-wise jacobians in a batch then post-process
        if self._pw_tlm_matrix_func is None or recompute_jac:
            self._pw_tlm_matrix_func = self._jit(
                jax.vmap(self._tlm_matrix_func, in_axes=in_axes, out_axes=out_axes)
            )
        jac_blocks = self._pw_tlm_matrix_func(*standard_args)

        # Singleton outputs need to be padded
        if self._single_output:
            jac_blocks = [jac_blocks]

        # Iterate through and squeeze out unnecessary axes generated by axis padding to match other backends.
        # It's probably possible to standardize this so we neither pad or squeeze,
        # but it's hard to see how without requiring the user to specify whether the inner function is already vmapped
        for i, out_idx in enumerate(outnums):
            out_pw = self._out_pointwise[out_idx]
            out_rank = len(out_shapes[out_idx])
            squeeze_didj_pw = jax.jit(partial(jnp.squeeze, axis=(1, 2 + out_rank)))
            squeeze_didj_in_pw = jax.jit(partial(jnp.squeeze, axis=1 + out_rank))
            di_dinput = [None] * len(argnums)
            for j, in_idx in enumerate(argnums):
                in_pw = self._in_pointwise[in_idx]
                di_dj = jac_blocks[i][j]
                if out_pw:
                    if in_pw:
                        # out_i: (n, *output_shape)
                        # arg_j: (n, *input_shape)
                        # di_dj: (n, 1, *output_shape, 1, *input_shape)
                        # di_dj': (n, *output_shape, *input_shape)
                        di_dj = squeeze_didj_pw(di_dj)
                    else:
                        # out_i: (n, *output_shape)
                        # arg_j: (*input_shape)
                        # di_dj: (n, 1, *output_shape, *input_shape)
                        # di_dj': (n, *output_shape, *input_shape)
                        di_dj = squeeze_didj(di_dj)
                elif in_pw:
                    # out_i: (*output_shape)
                    # arg_j: (n, *input_shape)
                    # di_dj: (*output_shape, n, 1, *input_shape)
                    # di_dj': (*output_shape, n, *input_shape)
                    di_dj = squeeze_didj_in_pw(di_dj)
                else:
                    # If neither is defined pointwise, the shapes should be this:
                    # out_i: (*output_shape)
                    # arg_j: (*input_shape)
                    # di_dj: (n, *output_shape, *input_shape)
                    # di_dj': (*output_shape, *input_shape)
                    di_dj = di_dj[0, ...]
                di_dinput[j] = di_dj
            doutput_dinput[out_idx] = di_dinput
        return doutput_dinput

    def evaluate_tlm_matrix_component(
        self, inputs, tlm_inputs, block_variable, idx, prepared=None
    ):
        """Return entry ``idx`` of the prepared Jacobian-matrix result."""
        return prepared[idx]

    def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
        """Pull the non-``None`` adjoint inputs back through the function with ``jax.vjp``.

        Outputs whose adjoint input is ``None`` are dropped from the function
        before it is differentiated.

        :return: the adjoint contributions to the dependencies, as a list
        """
        relevant_outputs_idxs = [
            idx for idx, adj in enumerate(adj_inputs) if adj is not None
        ]
        diff_func_relevant = self._get_reduced_output_diff_func(relevant_outputs_idxs)
        outs, fvjp = jax.vjp(diff_func_relevant, *tuple(convert_arg(x) for x in inputs))
        fvjp = self._jit(fvjp)
        adj_inputs = tuple(
            convert_arg(adj_inputs[idx]) for idx in relevant_outputs_idxs
        )
        # A single-output function returns a bare value, so its cotangent must be bare too.
        if self._single_output:
            adj_inputs = adj_inputs[0]
        return Enlist(fvjp(adj_inputs))

    def evaluate_adj_component(
        self, inputs, adj_inputs, block_variable, idx, prepared=None
    ):
        """Return entry ``idx`` of the prepared adjoint result."""
        return prepared[idx]
def _pushforward(func):
    """Return a factory that linearizes ``func`` at a point.

    The returned ``prepare_jvp(inputs)`` calls ``jax.linearize(func, *inputs)``
    and hands back the resulting linear (tangent) map wrapped in ``jax.jit``.
    The primal outputs computed by ``jax.linearize`` are discarded.
    """

    def prepare_jvp(inputs):
        linearized = jax.linearize(func, *inputs)
        # linearized == (primal outputs, tangent map); only the map is needed.
        tangent_map = linearized[1]
        return jax.jit(tangent_map)

    return prepare_jvp
def _pullback(func, argnum):
    """Return a ``(prepare, evaluate)`` pair for the VJP of ``func``.

    ``prepare(inputs, adj_inputs)`` builds the VJP of ``func`` at ``inputs``
    with ``jax.vjp``, jits it, and applies it to ``*adj_inputs``.
    ``evaluate(prepared)`` picks entry ``argnum`` out of that result.
    """

    def prepare_vjp(inputs, adj_inputs):
        _, cotangent_map = jax.vjp(func, *inputs)
        jitted_map = jax.jit(cotangent_map)
        return jitted_map(*adj_inputs)

    def eval_vjp_component(prepared):
        component = prepared[argnum]
        return component

    return prepare_vjp, eval_vjp_component
def _make_hvpfun(f):
    """Return a factory that linearizes the VJP of ``f``.

    ``make_hvp(inputs, adj_inputs)`` builds the VJP of ``f`` at the converted
    ``inputs``, linearizes the jitted VJP at the converted ``adj_inputs`` with
    ``jax.linearize``, and returns that linear map wrapped in ``jax.jit``.
    """

    def make_hvp(inputs, adj_inputs):
        primals = tuple(convert_arg(x) for x in inputs)
        _, pullback = jax.vjp(f, *primals)
        cotangents = tuple(convert_arg(x) for x in adj_inputs)
        # Only the linear map is kept; the value of the VJP itself is dropped.
        _, linear_map = jax.linearize(jax.jit(pullback), *cotangents)
        return jax.jit(linear_map)

    return make_hvp
def vector_jacobian_product(fun, argnum, reverse=True, jitfun=jax.jit):
    """Differentiate the contraction of ``fun``'s output with a trailing vector argument.

    The returned function takes the arguments of ``fun`` followed by one extra
    array ``x_dot``. It differentiates
    ``tensordot(x_dot, fun(*args), axes=ndim(x_dot))`` with respect to
    positional argument ``argnum``.

    :param fun: the function to differentiate
    :param argnum: positional argument to differentiate with respect to
    :param reverse: if True use ``jax.jacrev`` (returned un-jitted);
        otherwise use ``jax.jacfwd`` wrapped in ``jitfun``
    :param jitfun: wrapper applied to the forward-mode result only
    """
    # based on the old autograd implementation
    def vec_prod(*args, **kwargs):
        x_dot = args[-1]
        primals = args[:-1]
        value = fun(*primals, **kwargs)
        return jnp.tensordot(x_dot, value, axes=jnp.ndim(x_dot))

    if not reverse:
        return jitfun(jax.jacfwd(vec_prod, argnums=argnum))
    return jax.jacrev(vec_prod, argnums=argnum)
def is_tracer(x):
    """Return True when ``x`` is an instance of ``jax.core.Tracer``."""
    tracer_type = jax.core.Tracer
    return isinstance(x, tracer_type)
def get_overloaded(x: Any) -> OverloadedType:
    """Return ``x`` as an overloaded object.

    Instances of ``ndarray`` and ``tuple`` are returned unchanged; anything
    else is passed through ``create_overloaded_object``.
    """
    already_usable = isinstance(x, (ndarray, tuple))
    return x if already_usable else create_overloaded_object(x)
def overload_jax(
    func: Function,
    static_argnums: Optional[Union[int, Iterable[int]]] = None,
    argnums: Optional[Union[int, Iterable[int]]] = None,
    jit: Optional[bool] = True,
    function_name: Optional[str] = None,
    checkpoint: bool = False,
    concrete: bool = False,
    backend: Optional[str] = None,
    donate_argnums: Optional[Union[int, Iterable[int]]] = None,
    pointwise: Union[bool, Sequence[bool]] = False,
    out_pointwise: Optional[Union[bool, Sequence[bool]]] = None,
    **jax_kwargs,
) -> OverloadedFunction:
    """
    Creates a pyadjoint-overloaded version of a JAX-traceable function.

    :param func: The function to JIT compile and make differentiable
    :type func: Function
    :param static_argnums: The static_argnums parameter of jax.jit (e.g. numbers
        of arguments that, if changed, should trigger recompilation)
    :type static_argnums: Union[int,Iterable[int]], optional
    :param argnums: The numbers of the arguments you want to differentiate
        with respect to. For example, if you have a function f(x,p,w) and
        want the derivative with respect to p and w, pass argnums=(1,2).
    :type argnums: Union[int,Iterable[int]], optional
    :param jit: If True, do JIT compile the function, defaults to True
    :type jit: bool, optional
    :param function_name: if you want the function's name on the JAXBlock recorded as
        something other than func.__name__, use this parameter
    :type function_name: str, optional
    :param checkpoint: if True, make `func` recompute internal linearization
        points when differentiated (as opposed to computing these in the forward
        pass and storing the results). This increases total FLOPs in exchange for
        less memory usage/fewer accesses, defaults to False
    :type checkpoint: bool, optional
    :param concrete: if True, indicates that the function requires value-dependent
        Python control flow, defaults to False. Only used when `checkpoint` is True.
    :type concrete: bool, optional
    :param backend: String representing the XLA backend to use (e.g. 'cpu', 'gpu',
        'tpu'). (Note that this is an experimental JAX feature, and its API is
        likely to change), defaults to None
    :type backend: str, optional
    :param donate_argnums: Which arguments are 'donated' to the computation? In other
        words, you cannot reuse these arguments after calling the function. This
        lets XLA more aggressively re-use donated buffers, defaults to None
    :type donate_argnums: Union[int,Iterable[int]], optional
    :param pointwise: By default, this is false. True means the function performs
        operations on a batch of points. This allows optimizing the Jacobian calculations
        by only computing the diagonal component. If a list, then there should be a bool
        for each argnum.
    :type pointwise: Union[bool,Sequence[bool]], optional
    :param out_pointwise: If any inputs are defined pointwise, this specifies which
        outputs are defined pointwise. If a list, then there should be a bool for each
        output. By default, all outputs will be assumed pointwise if any inputs are
        pointwise.
    :type out_pointwise: Union[bool,Sequence[bool]], optional
    :param jax_kwargs: extra keyword arguments forwarded to jax.jit (only used
        when `jit` is True)
    :raises TypeError: if the removed `nojit` keyword is passed
    :return: A function with the same signature as `func` that performs the same
        computation, but is differentiable by both JAX and Pyadjoint and possibly
        JIT-compiled by JAX
    :rtype: OverloadedFunction
    """
    # handle deprecated API
    if "nojit" in jax_kwargs:
        logger.critical(
            "the `nojit` parameter to overload_jax() has been "
            "replaced by `jit`, which takes the opposite value "
            "(i.e. `nojit=True` should be replaced with `jit=False` "
            "and `nojit=False` should be replaced with `jit=True`"
        )
        raise TypeError("overload_jax() got an unexpected keyword argument 'nojit'")

    if checkpoint:
        # Only forward `concrete` when it was requested: omitting it is
        # equivalent to the default and keeps working on JAX versions that no
        # longer accept the keyword.
        if concrete:
            func = jax.checkpoint(func, concrete=True)
        else:
            func = jax.checkpoint(func)

    # first, jit compile
    if jit:
        # Compare against None rather than using `or`: index 0 is a valid
        # value for both options and must not be replaced by ().
        jit_kwargs = dict(
            static_argnums=() if static_argnums is None else static_argnums,
            donate_argnums=() if donate_argnums is None else donate_argnums,
            **jax_kwargs,
        )
        # Leaving `backend` out when unset is equivalent to passing None.
        if backend is not None:
            jit_kwargs["backend"] = backend
        func = jax.jit(func, **jit_kwargs)
    else:
        jit_kwargs = None

    # now overload
    @wraps(func)
    def _overloaded_func(*args, **kwargs):
        annotate = annotate_tape(kwargs)
        with stop_annotating():
            cargs = tuple(convert_arg(x) for x in args)
            out = func(*cargs, **kwargs)

        # Inside a JAX transformation the values are tracers, which cannot be
        # recorded on the tape; return the raw result. Leaves are inspected so
        # that tracers nested in tuple outputs/arguments are caught as well.
        if any(is_tracer(leaf) for leaf in tree_leaves((out, cargs))):
            return out

        out = Enlist(out)
        overloads = [create_overloaded_object(arr) for arr in out]
        if annotate:
            tape = get_working_tape()
            kwargs["func_name"] = function_name
            block = JAXBlock(
                func,
                args,
                overloads,
                argnums=argnums,
                nojit=(not jit),
                jit_kwargs=jit_kwargs,
                pointwise=pointwise,
                out_pointwise=out_pointwise,
                **kwargs,
            )
            tape.add_block(block)
        return out.delist(overloads)

    return _overloaded_func