Pyadjoint Utils API
Core classes
- class pyadjoint_utils.ReducedFunction(outputs, controls, tape=None, eval_cb_pre=None, eval_cb_post=None, jac_action_cb_pre=None, jac_action_cb_post=None, adj_jac_action_cb_pre=None, adj_jac_action_cb_post=None, hess_action_cb_pre=None, hess_action_cb_post=None)
Bases: object
- class pyadjoint_utils.ReducedFunctionNumPy(reduced_function)
Bases: pyadjoint_utils.reduced_function.ReducedFunction
This class implements the reduced function for a given function and controls based on NumPy data structures.
This “NumPy version” of the ReducedFunction is created from an existing ReducedFunction object: rf_np = ReducedFunctionNumPy(rf)
- __call__(m_array)
An implementation of the reduced function evaluation that accepts the control values as an array of scalars.
- jac_action(m_array)
An implementation of the reduced function jac_action evaluation that accepts the controls as an array of scalars.
- adj_jac_action(adj_inputs)
An implementation of the reduced function adjoint evaluation that returns the derivatives as an array of scalars.
- jac_matrix(m_jac=None)
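The central idea of the NumPy version is that all controls are exposed as one flat array of scalars. The following plain-NumPy toy sketches that idea. It is not the ReducedFunctionNumPy implementation. The class name `FlatArrayWrapper` and its `pack`/`unpack` helpers are hypothetical, and the wrapped function is assumed to take a list of NumPy arrays.

```python
import numpy as np


class FlatArrayWrapper:
    """Hypothetical toy: evaluate func(list_of_arrays) from one flat array of scalars."""

    def __init__(self, func, control_shapes):
        self.func = func
        self.shapes = [tuple(s) for s in control_shapes]
        self.sizes = [int(np.prod(s)) for s in self.shapes]

    def pack(self, controls):
        # Concatenate every control into one 1-d array of scalars.
        return np.concatenate([np.ravel(c) for c in controls])

    def unpack(self, m_array):
        # Split the flat array back into arrays with the original shapes.
        m_array = np.asarray(m_array, dtype=float)
        if m_array.size != sum(self.sizes):
            raise ValueError("m_array has the wrong number of entries")
        pieces = np.split(m_array, np.cumsum(self.sizes)[:-1])
        return [p.reshape(s) for p, s in zip(pieces, self.shapes)]

    def __call__(self, m_array):
        return self.func(self.unpack(m_array))


# Toy function of two controls: a 2x2 matrix A and a 2-vector b.
def f(controls):
    A, b = controls
    return float(np.sum(A @ b))


wrapped = FlatArrayWrapper(f, [(2, 2), (2,)])
m = wrapped.pack([np.eye(2), np.array([1.0, 2.0])])
print(m.shape)     # (6,)
print(wrapped(m))  # 3.0
```

Keeping the flat layout in one place (`pack`/`unpack`) means optimizers that only understand 1-d arrays can drive a function whose natural inputs are several differently shaped arrays.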
- class pyadjoint_utils.Tape(*args, **kwargs)
Bases: pyadjoint.tape.Tape
- class pyadjoint_utils.Block(ad_block_tag=None)
Bases: pyadjoint.block.Block
- evaluate_tlm_matrix(markings=False)
Computes the tangent linear action and stores the result in the tlm_matrix attribute of the outputs. By default, this method calls the evaluate_tlm_matrix_component() method for each output.
- Parameters
markings (bool) – If True, then each block_variable will have its marked_in_path attribute set, indicating whether its tlm components are relevant for computing the final target tlm values. Default is False.
- evaluate_tlm_matrix_component(inputs, tlm_inputs, block_variable, idx, prepared=None)
This method should be overridden.
The method should implement a routine for computing the Jacobian of the block that corresponds to one output. Consider evaluating the Jacobian of a tape with n inputs and m outputs. A block on the tape has n’ inputs and m’ outputs. That block should take an n’ x n Jacobian as input, and it should output an m’ x n Jacobian. This function should return the Jacobian for just one output applied to the input Jacobian. The return value should be a list with n entries.
- Parameters
inputs (list) – A list of the saved input values, determined by the dependencies list.
tlm_inputs (list[list]) – The Jacobian of the inputs, determined by the dependencies list.
block_variable (pyadjoint_utils.BlockVariable) – The block variable of the output corresponding to index idx.
idx (int) – The index of the component to compute.
prepared (object) – Anything returned by the prepare_evaluate_tlm_matrix() method. Default is None.
- Returns
The resulting product.
- Return type
A list with n entries, each of the same type as block_variable.saved_output
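The shape convention above is the chain rule applied one block at a time. The NumPy sketch below uses made-up dense matrices rather than any pyadjoint_utils object. It describes a block with n′ = 3 inputs and m′ = 2 outputs on a tape with n = 4 inputs. For each output, the block combines its local derivatives with the incoming n′ × n Jacobian to produce one row of n entries. The helper `block_tlm_component` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

n = 4      # tape inputs
n_in = 3   # block inputs (n')
m_out = 2  # block outputs (m')

# Jacobian of the block's inputs w.r.t. the tape inputs: shape (n', n).
tlm_inputs = rng.standard_normal((n_in, n))

# Local Jacobian of the block: d(block outputs)/d(block inputs), shape (m', n').
local_jac = rng.standard_normal((m_out, n_in))


def block_tlm_component(idx):
    """Hypothetical helper: row idx of the block's output Jacobian, as a list of n entries."""
    # Chain rule for one output: sum_k d(out_idx)/d(in_k) * d(in_k)/d(tape inputs).
    row = sum(local_jac[idx, k] * tlm_inputs[k] for k in range(n_in))
    return list(row)


rows = [block_tlm_component(i) for i in range(m_out)]
print(len(rows), len(rows[0]))  # 2 4

# Stacking the per-output rows reproduces the full (m', n) product.
print(np.allclose(np.array(rows), local_jac @ tlm_inputs))  # True
```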
- prepare_evaluate_tlm_matrix(inputs, tlm_inputs, relevant_outputs)
Runs preparations before evaluate_tlm_matrix_component() is run. The return value is supplied to each of the subsequent evaluate_tlm_matrix_component() calls. This method is intended to be overridden for blocks that require such preparations; by default, there are none.
- Parameters
inputs – The values of the inputs
tlm_inputs – The tlm inputs
relevant_outputs – A list of the relevant block variables for evaluate_tlm_matrix_component().
- Returns
Anything. The returned value is supplied to evaluate_tlm_matrix_component().
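The prepare-then-evaluate pattern can be illustrated with a stand-alone toy class. The class `ToyExpBlock` and the driver `run_block` are hypothetical. They do not subclass pyadjoint_utils.Block and only mimic the calling convention described above: one preparation call, whose result is passed as prepared to every per-output component call. The toy block computes y_i = c_i · exp(x) for a scalar input x, so the shared factor exp(x) is computed once during preparation.

```python
import math


class ToyExpBlock:
    """Hypothetical block computing y_i = c_i * exp(x) for one scalar input x."""

    def __init__(self, coeffs):
        self.coeffs = coeffs
        self.prepare_calls = 0

    def prepare_evaluate_tlm_matrix(self, inputs, tlm_inputs, relevant_outputs):
        # Shared work: d/dx exp(x) = exp(x), computed once for all outputs.
        self.prepare_calls += 1
        return math.exp(inputs[0])

    def evaluate_tlm_matrix_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        # tlm_inputs[0] holds dx/dp_j for each of the n tape inputs p_j.
        scale = self.coeffs[idx] * prepared
        return [scale * dx for dx in tlm_inputs[0]]


def run_block(block, inputs, tlm_inputs, n_outputs):
    """Hypothetical driver mirroring: prepare once, then one component call per output."""
    outputs = list(range(n_outputs))  # stand-ins for block variables
    prepared = block.prepare_evaluate_tlm_matrix(inputs, tlm_inputs, outputs)
    return [
        block.evaluate_tlm_matrix_component(inputs, tlm_inputs, out, idx, prepared)
        for idx, out in enumerate(outputs)
    ]


block = ToyExpBlock(coeffs=[1.0, 2.0])
jac = run_block(block, inputs=[0.0], tlm_inputs=[[1.0, 0.5, 0.0]], n_outputs=2)
print(jac)                  # [[1.0, 0.5, 0.0], [2.0, 1.0, 0.0]]
print(block.prepare_calls)  # 1
```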
- class pyadjoint_utils.BlockVariable(*args, **kwargs)
Bases: pyadjoint.block_variable.BlockVariable
References a block output variable.
- class pyadjoint_utils.Control(control)
Bases: pyadjoint.control.Control
- class pyadjoint_utils.OverloadedType(*args, **kwargs)
Bases: object
Base class for OverloadedType types.
The purpose of each OverloadedType is to extend a type such that it can be referenced by blocks as well as overload basic mathematical operations such as __mul__, __add__, where they are needed.
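The following toy illustrates the purpose described above. Overloading __add__ and __mul__ lets a wrapped value record each operation on a tape, and those records are what later let derivatives be propagated. `TracedFloat` and the module-level `TAPE` list are hypothetical and far simpler than pyadjoint's OverloadedType, Block, and Tape. The sketch also runs a reverse sweep over the recorded operations to show why the recording is useful.

```python
TAPE = []  # hypothetical global tape: a list of recorded operations


class TracedFloat:
    """Toy overloaded float that records + and * on TAPE."""

    def __init__(self, value):
        self.value = float(value)
        self.adj = 0.0  # adjoint accumulated during the reverse sweep

    def _wrap(self, other):
        return other if isinstance(other, TracedFloat) else TracedFloat(other)

    def __add__(self, other):
        other = self._wrap(other)
        out = TracedFloat(self.value + other.value)
        # Record (output, [(input, local partial derivative), ...]).
        TAPE.append((out, [(self, 1.0), (other, 1.0)]))
        return out

    def __mul__(self, other):
        other = self._wrap(other)
        out = TracedFloat(self.value * other.value)
        TAPE.append((out, [(self, other.value), (other, self.value)]))
        return out

    __radd__ = __add__
    __rmul__ = __mul__


def gradient(output, controls):
    # Reverse sweep over the tape, applying the chain rule record by record.
    output.adj = 1.0
    for out, deps in reversed(TAPE):
        for inp, partial in deps:
            inp.adj += partial * out.adj
    return [c.adj for c in controls]


x = TracedFloat(3.0)
y = TracedFloat(2.0)
J = x * y + x               # recorded as two operations
print(len(TAPE))            # 2
print(gradient(J, [x, y]))  # [3.0, 3.0]
```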
- class pyadjoint_utils.AdjFloat(*args, **kwargs)
Bases: pyadjoint.adjfloat.AdjFloat
Core functions
- pyadjoint_utils.push_tape()
Creates a new tape in its scope that is a sub-tape of the current working tape.
- pyadjoint_utils.compute_gradient(J: Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType], m: Union[List[pyadjoint_utils.control.Control], pyadjoint_utils.control.Control], options: Optional[dict] = None, tape: Optional[pyadjoint_utils.tape.Tape] = None, adj_value: float = 1.0) -> Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType]
Compute the gradient of J with respect to the initialisation value of m, that is, the value of m at its creation.
- Parameters
J (OverloadedType, list[OverloadedType]) – The objective functional.
m (list[pyadjoint_utils.Control] or pyadjoint_utils.Control) – The (list of) controls.
options (dict) – A dictionary of options. To find a list of available options, have a look at the specific control type.
tape (Tape) – The tape to use. Default is the current tape.
adj_value (float) – The adjoint value used to seed the output J. Default is 1.0.
- Returns
The derivative with respect to the control. Should be an instance of the same type as the control.
- Return type
OverloadedType or list[OverloadedType]
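A common way to validate a gradient is to compare it against central finite differences. Reverse mode is linear in its seed, so for a scalar objective the result should also scale with adj_value. The NumPy sketch below checks both properties for a hand-written objective with a known derivative. It does not call pyadjoint_utils. `J`, `dJdm`, and `reverse_mode_gradient` are hypothetical stand-ins.

```python
import numpy as np


def J(m):
    # Scalar objective J(m) = sum(m**3) / 3, so dJ/dm = m**2.
    return np.sum(m ** 3) / 3.0


def dJdm(m):
    return m ** 2


def reverse_mode_gradient(m, adj_value=1.0):
    # Hypothetical stand-in for a reverse sweep: seed the output with
    # adj_value and propagate it back through the known local derivative.
    return adj_value * dJdm(m)


def finite_difference_gradient(m, h=1e-6):
    # Central differences, one control entry at a time.
    g = np.zeros_like(m)
    for i in range(m.size):
        e = np.zeros_like(m)
        e[i] = h
        g[i] = (J(m + e) - J(m - e)) / (2 * h)
    return g


m = np.array([1.0, -2.0, 0.5])
print(np.allclose(reverse_mode_gradient(m), finite_difference_gradient(m)))  # True
print(np.allclose(reverse_mode_gradient(m, adj_value=2.0), 2.0 * dJdm(m)))   # True
```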
- pyadjoint_utils.compute_jacobian_matrix(J: Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType], m: Union[List[pyadjoint_utils.control.Control], pyadjoint_utils.control.Control], m_jac: Optional[Any] = None, tape: Optional[pyadjoint_utils.tape.Tape] = None) -> Any
Compute the Jacobian matrix dJ/dm.
- Parameters
J (OverloadedType) – The outputs of the function.
m (list[pyadjoint_utils.Control] or pyadjoint_utils.Control) – The (list of) controls.
m_jac – An input Jacobian to multiply with. By default, this will be an identity Jacobian. If m is a list, this should be a list of lists with len(m_jac) == len(m) and len(m_jac[i]) == len(m) for each i-th entry in m_jac.
tape – The tape to use. Default is the current tape.
- Returns
The Jacobian with respect to the control. Should be an instance of the same type as the control.
- Return type
Any
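The role of m_jac can be sketched with dense NumPy blocks. The sketch assumes the result for each control block is the sum over i of dJ/dm_i · m_jac[i][j], and that the default identity m_jac has identity blocks on the diagonal and zero blocks elsewhere. Both are assumptions about what "identity Jacobian" means here, not code from pyadjoint_utils. Under them, the default leaves the Jacobian blocks unchanged, while a different m_jac multiplies them. The control sizes, the function F, and all helper names are hypothetical.

```python
import numpy as np

# Two hypothetical controls: m0 in R^2, m1 in R^3.
sizes = [2, 3]


def F_jacobian(m0, m1):
    # F(m0, m1) = [m0[0] * m1[0], m0[1] + m1[1] * m1[2]]; analytic Jacobian blocks.
    dF_dm0 = np.array([[m1[0], 0.0], [0.0, 1.0]])
    dF_dm1 = np.array([[m0[0], 0.0, 0.0], [0.0, m1[2], m1[1]]])
    return [dF_dm0, dF_dm1]


def default_m_jac(sizes):
    # len(m) x len(m) list of lists: identity on the diagonal, zeros elsewhere.
    return [[np.eye(si) if i == j else np.zeros((si, sj))
             for j, sj in enumerate(sizes)]
            for i, si in enumerate(sizes)]


def apply_m_jac(jac_blocks, m_jac):
    # Block product: result[j] = sum_i dF/dm_i @ m_jac[i][j].
    n = len(m_jac)
    return [sum(jac_blocks[i] @ m_jac[i][j] for i in range(n)) for j in range(n)]


m0, m1 = np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])
blocks = F_jacobian(m0, m1)

# With the identity input Jacobian, the blocks come back unchanged.
same = apply_m_jac(blocks, default_m_jac(sizes))
print(all(np.allclose(a, b) for a, b in zip(same, blocks)))  # True

# Scaling the input Jacobian scales the result accordingly.
scaled = [[2.0 * b for b in row] for row in default_m_jac(sizes)]
doubled = apply_m_jac(blocks, scaled)
print(all(np.allclose(a, 2.0 * b) for a, b in zip(doubled, blocks)))  # True
```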
- pyadjoint_utils.compute_jacobian_action(J: Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType], m: Union[List[pyadjoint_utils.control.Control], pyadjoint_utils.control.Control], m_dot: Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType], options: Optional[dict] = None, tape: Optional[pyadjoint_utils.tape.Tape] = None) -> Union[List[pyadjoint.overloaded_type.OverloadedType], pyadjoint.overloaded_type.OverloadedType]
Compute the action of the Jacobian of J on m_dot with respect to the initialisation value of m, that is the value of m at its creation.
- Parameters
J (OverloadedType) – The outputs of the function.
m (list[pyadjoint_utils.Control] or pyadjoint_utils.Control) – The (list of) controls.
options (dict) – A dictionary of options. To find a list of available options, have a look at the specific control type.
tape – The tape to use. Default is the current tape.
m_dot (OverloadedType) – The variation, of the same overloaded type as m.
- Returns
The action on m_dot of the Jacobian of J with respect to the control. Should be an instance of the same type as the output of J.
- Return type
OverloadedType or list[OverloadedType]
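The Jacobian action is a directional derivative. dJ/dm · m_dot is the rate of change of J(m + ε m_dot) at ε = 0, so it can be checked with a finite difference without ever forming the matrix. The NumPy sketch below does this for a hypothetical function F with a hand-written Jacobian. It illustrates the quantity compute_jacobian_action returns but does not use pyadjoint_utils.

```python
import numpy as np


def F(m):
    # Hypothetical vector-valued function of a 3-vector control.
    return np.array([m[0] * m[1], np.sin(m[2]), m[0] ** 2])


def F_jacobian(m):
    return np.array([
        [m[1], m[0], 0.0],
        [0.0, 0.0, np.cos(m[2])],
        [2 * m[0], 0.0, 0.0],
    ])


def jacobian_action_fd(m, m_dot, eps=1e-6):
    # Central difference along the direction m_dot (no matrix needed).
    return (F(m + eps * m_dot) - F(m - eps * m_dot)) / (2 * eps)


m = np.array([1.0, 2.0, 0.3])
m_dot = np.array([0.5, -1.0, 2.0])

exact = F_jacobian(m) @ m_dot
print(np.allclose(exact, jacobian_action_fd(m, m_dot)))  # True
```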
Callbacks
- class Callback
NumPy Backends
- pyadjoint_utils.numpy_backend.get_default_backend()
Returns the default CRIKit NumPy backend.
- Returns
The default NumPy backend
- Return type
NumpyBackend
- pyadjoint_utils.numpy_backend.get_backend(which: Optional[str] = None)
Returns the NumPy backend corresponding to the string ‘numpy’, ‘jax’, or ‘torch’, or the default backend if which is None.
- Parameters
which (str, optional) – Which backend to return, defaults to None
- Returns
The requested NumPy backend
- Return type
NumpyBackend
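A string-keyed lookup like the one described above can be sketched with importlib. Names map to module names, None falls back to a default, and modules are imported lazily, so optional dependencies such as jax or torch are only needed when requested. The function `pick_array_module` and its mapping are hypothetical and return plain modules, not the NumpyBackend objects that get_backend returns.

```python
import importlib

# Hypothetical mapping from backend names to the modules that provide them.
_BACKEND_MODULES = {
    "numpy": "numpy",
    "jax": "jax.numpy",
    "torch": "torch",
}
_DEFAULT_BACKEND = "numpy"


def pick_array_module(which=None):
    """Return the array module for 'numpy', 'jax', or 'torch'; None means the default."""
    name = _DEFAULT_BACKEND if which is None else which
    if name not in _BACKEND_MODULES:
        raise ValueError(f"unknown backend {name!r}; expected one of {sorted(_BACKEND_MODULES)}")
    # Import lazily, so jax/torch are only needed if actually requested.
    return importlib.import_module(_BACKEND_MODULES[name])


print(pick_array_module().__name__)  # numpy
```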
Optimization
- pyadjoint_utils.minimize.minimize(rf, method='L-BFGS-B', scale=1.0, callbacks=<pyadjoint_utils.minimize._default object>, **kwargs)
CRIKit’s wrapper around pyadjoint’s minimize function, which itself calls SciPy’s minimization routines. See pyadjoint’s documentation for information on kwargs, which are passed through to pyadjoint’s minimize function.
- Parameters
callbacks (Optional[Sequence[Union[Callable, Callback]]]) – A sequence of callback functions or Callback objects to combine into one callback, or a single such callback function, or None if no callback is desired. Defaults to a FileLoggerCallback with default arguments.
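The callbacks argument accepts either one callback or a sequence of callbacks to be combined into one. The following minimal sketch shows such a combination. It assumes each callback is a plain callable and that all callbacks receive the same arguments. The helper `combine_callbacks` and that call convention are hypothetical, not CRIKit's Callback interface.

```python
from typing import Callable, Optional, Sequence, Union

CallbackLike = Callable[..., object]


def combine_callbacks(
    callbacks: Optional[Union[CallbackLike, Sequence[CallbackLike]]],
) -> Optional[CallbackLike]:
    """Hypothetical helper: merge None, one callable, or a sequence into one callable."""
    if callbacks is None:
        return None
    if callable(callbacks):
        return callbacks
    callbacks = list(callbacks)

    def combined(*args, **kwargs):
        # Call every callback in order with the same arguments.
        for cb in callbacks:
            cb(*args, **kwargs)

    return combined


log = []
cb = combine_callbacks([lambda x: log.append(("first", x)),
                        lambda x: log.append(("second", x))])
cb(1.5)
print(log)  # [('first', 1.5), ('second', 1.5)]
```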
FEniCS adjoint
- pyadjoint_utils.fenics_adjoint.function_get_local(*args, **kwargs)
- pyadjoint_utils.fenics_adjoint.function_set_local(*args, **kwargs)
- pyadjoint_utils.fenics_adjoint.assemble(*args, **kwargs)
Uses the Firedrake assemble syntax, in which the ‘tensor’ kwarg can take a Function. If that is the case (or if Firedrake is the backend and the returned tensor is a Function), the result is converted to an overloaded function so that an AssembleBlock can be created, even if the output is not a scalar.
JAX adjoint
- class pyadjoint_utils.jax_adjoint.ndarray(obj: Any, *args, **kwargs)
Bases: pyadjoint.overloaded_type.OverloadedType
- property T
Returns the transpose of this ndarray.
- __init__(obj: Any, *args, **kwargs)
Note: you should not typically use this constructor directly in your code. Instead, call array() or asarray(), which will call the constructor of this class when appropriate.
- Parameters
obj (jax.interpreters.xla.DeviceArray) – the object to wrap; should be either a JAX ndarray or something that can be converted to one (such as a list or tuple of floats or ints, a float or int, or a numpy ndarray)
- Returns
An object that wraps a JAX array (so that it can be added to the JAX Pytree and thus used as an argument to a differentiable function) to be passed to a function wrapped with overload_jax(), while also inheriting from pyadjoint.OverloadedType (since you can’t inherit from a JAX array; see google/jax#4269).
- Return type
ndarray
- flatten() -> jax._src.numpy.lax_numpy.ndarray
Returns a flattened 1-d array (NOT an ndarray, but rather the array type it contains).
- property shape
The shape of the array
- property size
The number of elements in the array
- tree_flatten() -> Tuple[Tuple[jax._src.numpy.lax_numpy.ndarray, ...], None]
Flattens an ndarray in the JAX Pytree structure.
- Returns
A tuple containing any arrays (or other children) this ndarray holds, and an empty metadata field
- Return type
Tuple[Tuple[jax._src.numpy.lax_numpy.ndarray, ...], None]
- classmethod tree_unflatten(aux_data, children)
Constructs an ndarray from its flattened components.
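The tree_flatten/tree_unflatten pair follows JAX's custom pytree protocol. tree_flatten returns the children to traverse plus auxiliary data (here None), and tree_unflatten rebuilds an equivalent object from them. The sketch below mimics that round trip with a hypothetical `Wrapper` class in plain NumPy, so it runs without JAX. With JAX installed, such a class would typically be registered with the `jax.tree_util.register_pytree_node_class` decorator.

```python
import numpy as np


class Wrapper:
    """Hypothetical stand-in for an ndarray-like pytree node wrapping one array."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def tree_flatten(self):
        # Children are the leaves JAX would traverse; no static metadata.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the wrapper from its children; aux_data is unused (None).
        return cls(*children)


w = Wrapper([1.0, 2.0, 3.0])
children, aux = w.tree_flatten()
w2 = Wrapper.tree_unflatten(aux, children)
print(aux, np.array_equal(w2.value, w.value))  # None True
```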
- unwrap(to_jax: bool = True) -> jax._src.numpy.lax_numpy.ndarray
If this ndarray holds recursively nested ndarrays (e.g. its __repr__() is ndarray(ndarray(…))), unwrap until it holds the array data contained in the deepest-nested ndarray.
This is mostly a utility for use in jacfwd and jacrev in pyadjoint_utils/numpy_adjoint/jax.py.
- Parameters
to_jax (bool, optional) – whether to go one level further and return the raw JAX array (instead of ndarray, the OverloadedType wrapper), defaults to True
- Returns
unwrapped version of self
- Return type
jax.interpreters.xla.DeviceArray
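Unwrapping recursively nested wrappers is a loop that peels off layers until it reaches the innermost wrapper, then optionally returns that wrapper's raw payload. The sketch below uses a hypothetical `Box` class. The attribute name `inner` and the `to_raw` flag are assumptions that mirror to_jax, not the real ndarray internals.

```python
class Box:
    """Hypothetical wrapper that may contain another Box or a raw value."""

    def __init__(self, inner):
        self.inner = inner

    def __repr__(self):
        return f"Box({self.inner!r})"

    def unwrap(self, to_raw=True):
        box = self
        # Peel layers until the contained object is no longer a Box.
        while isinstance(box.inner, Box):
            box = box.inner
        # Either return the raw payload or the innermost Box itself.
        return box.inner if to_raw else box


nested = Box(Box(Box([1.0, 2.0])))
print(nested)                       # Box(Box(Box([1.0, 2.0])))
print(nested.unwrap())              # [1.0, 2.0]
print(nested.unwrap(to_raw=False))  # Box([1.0, 2.0])
```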
- pyadjoint_utils.jax_adjoint.array(obj: Any, **kwargs) -> pyadjoint_utils.jax_adjoint.array.ndarray
Converts the input to an ndarray. This function is NOT overloaded (it does not add any Block to the Tape). If you want to convert an AdjFloat to an ndarray or vice versa, use the functions to_jax() or to_adjfloat(), respectively.
- Parameters
obj (Union[jax.interpreters.xla.DeviceArray,Iterable[Union[int,float,jax.interpreters.xla.DeviceArray]]]) – the object to wrap; should be either a JAX ndarray or something that can be converted to one (such as a list or tuple of floats or ints, a float or int, or a numpy ndarray)
- Returns
An ndarray that wraps a JAX array (so that it can be added to the JAX Pytree and thus used as an argument to a differentiable function) to be passed to a function wrapped with overload_jax(), while also inheriting from pyadjoint.OverloadedType (since you can’t inherit from a JAX array; see google/jax#4269).
- Return type
pyadjoint_utils.jax_adjoint.array.ndarray
- pyadjoint_utils.jax_adjoint.asarray(obj: Any, **kwargs) -> pyadjoint_utils.jax_adjoint.array.ndarray
- pyadjoint_utils.jax_adjoint.overload_jax(func: pyadjoint_utils.jax_adjoint.jax.Function, static_argnums: Optional[Union[int, Iterable[int]]] = None, argnums: Optional[Union[int, Iterable[int]]] = None, jit: Optional[bool] = True, function_name: Optional[str] = None, checkpoint: bool = False, concrete: bool = False, backend: Optional[str] = None, donate_argnums: Optional[Union[int, Iterable[int]]] = None, pointwise: Union[bool, Sequence[bool]] = False, out_pointwise: Optional[Union[bool, Sequence[bool]]] = None, **jax_kwargs) -> pyadjoint_utils.jax_adjoint.jax.OverloadedFunction
Creates a pyadjoint-overloaded version of a JAX-traceable function.
- Parameters
func (Function) – The function to JIT compile and make differentiable
static_argnums (Union[int,Iterable[int]], optional) – The static_argnums parameter of jax.jit (i.e. the positions of arguments that, if changed, should trigger recompilation)
argnums (Union[int,Iterable[int]], optional) – The numbers of the arguments you want to differentiate with respect to. For example, if you have a function f(x,p,w) and want the derivative with respect to p and w, pass argnums=(1,2).
jit (bool, optional) – If True, JIT-compile the function, defaults to True
function_name (str, optional) – if you want the function’s name on the JAXBlock recorded as something other than func.__name__, use this parameter
checkpoint (bool, optional) – if True, make func recompute internal linearization points when differentiated (as opposed to computing these in the forward pass and storing the results). This increases total FLOPs in exchange for lower memory usage and fewer memory accesses, defaults to False
concrete (bool, optional) – if True, indicates that the function requires value-dependent Python control flow, defaults to False
backend (str, optional) – String representing the XLA backend to use (e.g. ‘cpu’, ‘gpu’, ‘tpu’). (Note that this is an experimental JAX feature, and its API is likely to change), defaults to None
donate_argnums (Union[int,Iterable[int]], optional) – The arguments that are ‘donated’ to the computation, meaning you cannot reuse them after calling the function. This lets XLA reuse donated buffers more aggressively, defaults to None
pointwise (Union[bool,Sequence[bool]], optional) – By default, this is False. True means the function performs operations on a batch of points. This allows optimizing the Jacobian calculations by only computing the diagonal component. If a list, then there should be a bool for each argnum.
out_pointwise (Union[bool,Sequence[bool]], optional) – If any inputs are defined pointwise, this specifies which outputs are defined pointwise. If a list, then there should be a bool for each output. By default, all outputs will be assumed pointwise if any inputs are pointwise.
- Returns
A function with the same signature as func that performs the same computation, but is differentiable by both JAX and Pyadjoint and possibly JIT-compiled by JAX
- Return type
OverloadedFunction
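The pointwise option relies on a simple fact. If a function acts independently on each point of a batch, its Jacobian with respect to that batch is diagonal, so only the n diagonal entries need to be computed instead of all n² entries. The NumPy sketch below checks this for a hypothetical elementwise function `g`. It does not call overload_jax. In JAX itself, the diagonal of an elementwise scalar function can be obtained in one pass with, for example, `jax.vmap(jax.grad(g_scalar))(x)`.

```python
import numpy as np


def g(x):
    # Pointwise: each output entry depends only on the matching input entry.
    return x ** 2 + np.sin(x)


def g_diag_derivative(x):
    # d g_i / d x_i for every point, computed without forming a matrix.
    return 2 * x + np.cos(x)


def full_jacobian_fd(f, x, h=1e-6):
    # Dense finite-difference Jacobian, one column per input entry.
    n = x.size
    jac = np.zeros((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = h
        jac[:, j] = (f(x + e) - f(x - e)) / (2 * h)
    return jac


x = np.linspace(-1.0, 1.0, 5)
jac = full_jacobian_fd(g, x)

# Off-diagonal entries vanish, and the diagonal matches the cheap formula.
print(np.allclose(jac, np.diag(np.diag(jac))))          # True
print(np.allclose(np.diag(jac), g_diag_derivative(x)))  # True
```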
NumPy adjoint
- pyadjoint_utils.numpy_adjoint.autograd.overload_autograd(func, pointwise)
Create an overloaded version of an Autograd function.
This method makes several assumptions:
1) The function is explicit, i.e. y = func(x), where y is the output of the operation.
2) All of y can be converted to an OverloadedType.
3) Unless annotation is turned off, the operation should always be annotated when calling the overloaded function.
After the overloaded function is called, the pointwise bool for each input is recorded in the function’s pointwise attribute.
- Parameters
func (function) – The target function for which to create an overloaded version.
pointwise (bool or list[bool]) – True means the function performs operations on a batch of points. This allows optimizing the Jacobian calculations by only computing the diagonal component. If a list, then there should be a bool for each input.
- Returns
An overloaded version of func
- Return type
function
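The pointwise argument may be a single bool or one bool per input, and the per-input values end up on the overloaded function's pointwise attribute. The following minimal decorator sketches that bookkeeping for a plain Python function. The name `with_pointwise` is hypothetical. No taping or differentiation happens here, and expanding a single bool to all inputs at call time is an assumption.

```python
import functools


def with_pointwise(func, pointwise):
    """Hypothetical decorator: record one pointwise flag per input on each call."""

    @functools.wraps(func)
    def wrapper(*args):
        if isinstance(pointwise, bool):
            # A single bool applies to every input.
            flags = [pointwise] * len(args)
        else:
            flags = list(pointwise)
            if len(flags) != len(args):
                raise ValueError("need one pointwise flag per input")
        wrapper.pointwise = flags
        return func(*args)

    return wrapper


def scale(x, a):
    return [a * xi for xi in x]


f = with_pointwise(scale, [True, False])
print(f([1.0, 2.0], 3.0))  # [3.0, 6.0]
print(f.pointwise)         # [True, False]
```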