## Core classes¶

Bases: object

__call__(inputs)[source]
jac_action(inputs, options=None)[source]
jac_matrix(m_jac=None)[source]

This class implements the reduced function for a given function and controls, based on NumPy data structures.

This “NumPy version” of the ReducedFunction is created from an existing ReducedFunction object: rf_np = ReducedFunctionNumPy(rf=rf)
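The array-based call pattern can be sketched in plain NumPy (a hypothetical toy model, not the real pyadjoint API): the reduced function maps a flat array of control values to an output, and jac_action applies the Jacobian to a perturbation array.

```python
import numpy as np

# Toy stand-ins for the array-based interface described above; the real
# ReducedFunctionNumPy wraps an existing ReducedFunction, while these
# functions just model the shapes involved.
def reduced_function(m_array):
    # toy forward model: J(m) = sum(m**2)
    return float(np.sum(m_array ** 2))

def jac_action(m_array, m_dot):
    # action of the Jacobian dJ/dm = 2*m on a perturbation m_dot
    return float(np.dot(2.0 * m_array, m_dot))

m = np.array([1.0, 2.0, 3.0])
print(reduced_function(m))        # 14.0
print(jac_action(m, np.ones(3)))  # 12.0
```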

__call__(m_array)[source]

An implementation of the reduced function evaluation that accepts the control values as an array of scalars.

jac_action(m_array)[source]

An implementation of the reduced function jac_action evaluation that accepts the controls as an array of scalars.

An implementation of the reduced functional adjoint evaluation that returns the derivatives as an array of scalars.

jac_matrix(m_jac=None)

Bases: pyadjoint.tape.Tape

Bases: pyadjoint.block.Block

evaluate_tlm_matrix(markings=False)[source]

Computes the tangent linear action and stores the result in the tlm_matrix attribute of the outputs.

This method will by default call the evaluate_tlm_matrix_component() method for each output.

Parameters

markings (bool) – If True, each block_variable will have its marked_in_path attribute set, indicating whether its tlm components are relevant for computing the final target tlm values. Default is False.

evaluate_tlm_matrix_component(inputs, tlm_inputs, block_variable, idx, prepared=None)[source]

This method should be overridden.

The method should implement a routine for computing the Jacobian of the block that corresponds to one output. Consider evaluating the Jacobian of a tape with n inputs and m outputs. A block on the tape has n’ inputs and m’ outputs. That block should take an n’ x n Jacobian as input, and it should output an m’ x n Jacobian. This function should return the Jacobian for just one output applied to the input Jacobian. The return value should be a list with n entries.

Parameters
Returns

The resulting product.

Return type

An object of the same type as block_variable.saved_output
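The shape contract described above can be illustrated with a plain-NumPy chain rule (dimensions and names here are illustrative only, not the pyadjoint API):

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_in, m_out = 4, 3, 2                 # tape inputs; block inputs/outputs
input_jac = rng.random((n_in, n))        # d(block inputs)/d(tape inputs)
local_jac = rng.random((m_out, n_in))    # d(block outputs)/d(block inputs)

# The block turns an n_in x n input Jacobian into an m_out x n output Jacobian:
output_jac = local_jac @ input_jac
assert output_jac.shape == (m_out, n)

# For a single output `idx`, the component is that output's row of the local
# Jacobian applied to the input Jacobian -- an n-entry result, matching the
# "list with n entries" return value described above.
idx = 0
row = local_jac[idx] @ input_jac
assert np.allclose(row, output_jac[idx])
```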

prepare_evaluate_tlm_matrix(inputs, tlm_inputs, relevant_outputs)[source]

Runs preparations before evaluate_tlm_matrix_component() is run.

The return value is supplied to each of the subsequent evaluate_tlm_matrix_component() calls. This method is intended to be overridden by blocks that require such preparations; by default there are none.

Parameters
Returns

Anything. The returned value is supplied to evaluate_tlm_matrix_component()
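The prepare/compute split can be sketched in pure Python (hypothetical names, not the pyadjoint API): shared work is done once in a prepare step, and its result is handed to each per-output component call.

```python
# Mirrors how a prepare step feeds each per-output component evaluation:
# expensive shared work happens once, not once per output.
def prepare(inputs):
    # e.g. a quantity every output needs; computed once up front
    return {"total": sum(inputs)}

def compute_component(inputs, idx, prepared):
    return inputs[idx] * prepared["total"]

inputs = [1.0, 2.0, 3.0]
prepared = prepare(inputs)
results = [compute_component(inputs, i, prepared) for i in range(len(inputs))]
assert results == [6.0, 12.0, 18.0]
```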

References a block output variable.

Bases: pyadjoint.control.Control

Bases: object

The purpose of each OverloadedType is to extend a type so that it can be referenced by blocks and can overload basic mathematical operations, such as __mul__ and __add__, where needed.

Bases: pyadjoint.adjfloat.AdjFloat

Bases: object

A class that encapsulates all the information required to formulate a reduced equation solve problem.

## Core functions¶

Creates a new tape in its scope that is a sub-tape of the current working tape.

Compute the gradient of J with respect to the initialisation value of m, that is the value of m at its creation.

Parameters

• m (Union[list[Control], Control]) – The (list of) controls.

• options (dict) – A dictionary of options. To find a list of available options have a look at the specific control type.

• tape (Tape) – The tape to use. Default is the current tape.

Returns

The derivative with respect to the control. Should be an instance of the same type as the control.

Return type
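Whatever the control type, the returned derivative can be sanity-checked against a finite-difference approximation of J at the control's initial value. A plain-NumPy sketch (toy functional, not the pyadjoint API):

```python
import numpy as np

def J(m):
    # toy functional standing in for a taped computation
    return float(np.sum(np.sin(m) * m))

def grad_J(m):
    # analytic dJ/dm, standing in for the gradient the function returns
    return np.cos(m) * m + np.sin(m)

m0 = np.array([0.1, 0.5, 1.0])
eps = 1e-6
# central finite differences along each coordinate direction
fd = np.array([(J(m0 + eps * e) - J(m0 - eps * e)) / (2 * eps)
               for e in np.eye(3)])
assert np.allclose(fd, grad_J(m0), atol=1e-5)
```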

pyadjoint_utils.compute_jacobian_matrix(J: , m: , m_jac: Optional[Any] = None, tape: Optional[pyadjoint_utils.tape.Tape] = None) → Any[source]

Compute dJdm matrix.

Parameters
• J (OverloadedType) – The outputs of the function.

• m_jac – An input Jacobian to multiply with. By default, this will be an identity Jacobian. If m is a list, this should be a list of lists with len(m_jac) == len(m) and len(m_jac[i]) == len(m) for each i-th entry in m_jac.

• tape – The tape to use. Default is the current tape.

Returns

The Jacobian with respect to the control. Should be an instance of the same type as the control.

Return type
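The role of m_jac can be illustrated in plain NumPy (hypothetical toy Jacobian, not the real API): the result is dJ/dm multiplied by the supplied input Jacobian, and the default identity seed yields dJ/dm itself.

```python
import numpy as np

def dJdm(m):
    # Jacobian of the toy map J(m) = (m[0]*m[1], m[0] + m[1])
    return np.array([[m[1], m[0]],
                     [1.0,  1.0]])

m = np.array([2.0, 3.0])

# Default m_jac: identity, so the result is just dJ/dm.
assert np.allclose(dJdm(m) @ np.eye(2), dJdm(m))

# Custom m_jac: the computed matrix is dJ/dm composed with the seed.
seed = np.array([[1.0, 0.0],
                 [1.0, 1.0]])
result = dJdm(m) @ seed
assert np.allclose(result, [[5.0, 2.0], [2.0, 1.0]])
```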

pyadjoint_utils.compute_jacobian_action(J: , m: , m_dot: , options: Optional[dict] = None, tape: Optional[pyadjoint_utils.tape.Tape] = None) [source]

Compute the action of the Jacobian of J on m_dot with respect to the initialisation value of m, that is the value of m at its creation.

Parameters
• J (OverloadedType) – The outputs of the function.

• options (dict) – A dictionary of options. To find a list of available options have a look at the specific control type.

• tape – The tape to use. Default is the current tape.

Returns

The action on m_dot of the Jacobian of J with respect to the control. Should be an instance of the same type as the output of J.

Return type

Uses the assemble syntax of Firedrake, where the ‘tensor’ kwarg can take a Function. If that is the case (or if Firedrake is the backend and the returned tensor is a Function), the result is converted to an overloaded function so that an AssembleBlock can be created, even if the output is not a scalar.

property T

Returns the transpose of this ndarray.

__init__(obj: Any, *args, **kwargs)[source]

Note: you should not typically use this constructor directly in your code. Instead, you should call array() or asarray(), which will call the constructor of this class when appropriate.

Parameters

obj (jax.interpreters.xla.DeviceArray) – the object to wrap; should be either a JAX ndarray or something that can be converted to one (such as a list or tuple of floats or ints, or a float or int, or a numpy ndarray)

Returns

a class that wraps a JAX array (such that it can be added to the JAX Pytree and thus used as an argument to a differentiable function) to be passed to a function wrapped with overload_jax(), while also inheriting from pyadjoint.OverloadedType (since you can’t inherit from a JAX array; see https://github.com/google/jax/issues/4269).

Return type

ndarray

flatten() jax._src.numpy.lax_numpy.ndarray[source]

Returns a flattened 1-d array (NOT an ndarray, but rather the array type it contains)

property shape

The shape of the array

property size

The number of elements in the array.

tree_flatten() Tuple[Tuple[jax._src.numpy.lax_numpy.ndarray, ...], None][source]

Flattens an ndarray in the JAX Pytree structure

Returns

tuple containing any arrays (or other children) this ndarray holds, and an empty metadata field

Return type

tuple

classmethod tree_unflatten(aux_data, children)[source]

Constructs an ndarray from its flattened components

Parameters
• cls (type) – ndarray

• aux_data (None) – ignore this parameter

• children (Union[jax.interpreters.xla.DeviceArray,Iterable[jax.interpreters.xla.DeviceArray]]) – any children that belonged to this ndarray before it was flattened

Returns

an ndarray holding the children

Return type

ndarray
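The flatten/unflatten pair follows JAX's pytree contract: tree_flatten returns (children, aux_data) and tree_unflatten rebuilds the object from those pieces. A minimal pure-Python mock of that contract (the class name is illustrative):

```python
class Box:
    """Minimal pytree-style container: children plus (empty) metadata."""

    def __init__(self, values):
        self.values = tuple(values)

    def tree_flatten(self):
        # children first, then auxiliary static metadata (none here)
        return self.values, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # invert tree_flatten: rebuild the container from its children
        return cls(children)

b = Box([1.0, 2.0])
children, aux = b.tree_flatten()
assert Box.tree_unflatten(aux, children).values == (1.0, 2.0)
```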

unwrap(to_jax: bool = True) jax._src.numpy.lax_numpy.ndarray[source]

If this ndarray holds recursively nested ndarrays (e.g. its __repr__() is ndarray(ndarray(…))), unwrap until it holds the array data contained in the deepest-nested ndarray.

This is mostly a utility for use in jacfwd and jacrev in pyadjoint_utils/numpy_adjoint/jax.py

Parameters

to_jax (bool, optional) – if True, go one level further and return the raw JAX array rather than ndarray (the OverloadedType wrapper); defaults to True

Returns

unwrapped version of self

Return type

jax.interpreters.xla.DeviceArray
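The unwrapping recursion can be sketched with a generic wrapper (pure Python; names here are hypothetical stand-ins, not the real classes):

```python
class Wrapper:
    """Stand-in for the OverloadedType ndarray wrapper."""

    def __init__(self, data):
        self.data = data

def unwrap(x):
    # peel nested wrappers until the innermost payload is reached
    while isinstance(x, Wrapper):
        x = x.data
    return x

nested = Wrapper(Wrapper(Wrapper([1, 2, 3])))
assert unwrap(nested) == [1, 2, 3]
assert unwrap(42) == 42  # non-wrapped values pass through unchanged
```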

Converts the input to an ndarray. This function is NOT overloaded (it does not add any Block to the Tape). If you want to convert an AdjFloat to an ndarray or vice versa, use the functions to_jax() or to_adjfloat() respectively.

Parameters

obj (Union[jax.interpreters.xla.DeviceArray,Iterable[Union[int,float,jax.interpreters.xla.DeviceArray]]]) – the object to wrap; should be either a JAX ndarray or something that can be converted to one (such as a list or tuple of floats or ints, or a float or int, or a numpy ndarray)

Returns

a class that wraps a JAX array (such that it can be added to the JAX Pytree and thus used as an argument to a differentiable function) to be passed to a function wrapped with overload_jax(), while also inheriting from pyadjoint.OverloadedType (since you can’t inherit from a JAX array; see https://github.com/google/jax/issues/4269).

Return type

ndarray

pyadjoint_utils.jax_adjoint.overload_jax(func: pyadjoint_utils.jax_adjoint.jax.Function, static_argnums: Optional[Union[int, Iterable[int]]] = None, argnums: Optional[Union[int, Iterable[int]]] = None, nojit: Optional[bool] = False, function_name: Optional[str] = None, checkpoint: bool = False, concrete: bool = False, backend: Optional[str] = None, donate_argnums: Optional[Union[int, Iterable[int]]] = None, pointwise: Union[bool, Sequence[bool]] = False, out_pointwise: Optional[Union[bool, Sequence[bool]]] = None, compile_jacobian_stack: bool = True, **jax_kwargs) pyadjoint_utils.jax_adjoint.jax.OverloadedFunction[source]

Parameters
• func (Function) – The function to JIT compile and make differentiable

• static_argnums (Union[int,Iterable[int]], optional) – The static_argnums parameter of jax.jit (e.g. numbers of arguments that, if changed, should trigger recompilation)

• argnums (Union[int,Iterable[int]], optional) – The numbers of the arguments you want to differentiate with respect to. For example, if you have a function f(x,p,w) and want the derivative with respect to p and w, pass argnums=(1,2).

• nojit (bool, optional) – If True, do NOT JIT compile the function, defaults to False

• function_name (str, optional) – if you want the function’s name on the JAXBlock recorded as something other than func.__name__, use this parameter

• checkpoint (bool, optional) – if True, make func recompute internal linearization points when differentiated (as opposed to computing them in the forward pass and storing the results). This increases total FLOPs in exchange for lower memory usage and fewer accesses; defaults to False

• concrete (bool, optional) – if True, indicates that the function requires value-dependent Python control flow, defaults to False

• backend (str, optional) – String representing the XLA backend to use (e.g. ‘cpu’, ‘gpu’, ‘tpu’). (Note that this is an experimental JAX feature, and its API is likely to change), defaults to None

• donate_argnums (Union[int,Iterable[int]], optional) – Which arguments are ‘donated’ to the computation; that is, you cannot reuse these arguments after calling the function. This lets XLA more aggressively reuse donated buffers. Defaults to None

• pointwise (Union[bool,Sequence[bool]], optional) – If True, the function performs operations on a batch of points, which allows optimizing the Jacobian calculations by computing only the diagonal component. If a list, there should be a bool for each argnum. Defaults to False

• out_pointwise (Union[bool,Sequence[bool]], optional) – If any inputs are defined pointwise, this specifies which outputs are defined pointwise. If a list, there should be a bool for each output. By default, all outputs are assumed pointwise if any inputs are pointwise.

• compile_jacobian_stack (bool, optional) – If a Jacobian matrix corresponding to this function is needed and that computation is done pointwise, the pointwise Jacobians must be combined with jnp.stack into one full Jacobian. This parameter controls whether that stack operation is compiled with jax.jit; compiling can take several minutes the first time a Jacobian is requested (since it requires unrolling a Python-mode loop over quadrature points), but often yields a factor-of-five or more improvement in the total time required to compute the Jacobian on subsequent evaluations. Defaults to True

Returns

A function with the same signature as func that performs the same computation, but JIT-compiled with JAX and differentiable by both JAX and pyadjoint.

Return type
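The pointwise optimization mentioned above rests on a simple structural fact, sketched here in plain NumPy (not the overload_jax machinery itself): a function applied independently at each of N points has a diagonal N x N Jacobian, so only the N-entry diagonal needs computing, and the full matrix is assembled only when a Jacobian matrix is actually requested.

```python
import numpy as np

def f(x):
    return x ** 2          # acts independently at each point of the batch

x = np.array([1.0, 2.0, 3.0])
diag = 2.0 * x             # d f_i / d x_j is zero for i != j, so keep i == j only
full = np.diag(diag)       # the stacked full Jacobian, built on demand
assert np.allclose(full, [[2.0, 0.0, 0.0],
                          [0.0, 4.0, 0.0],
                          [0.0, 0.0, 6.0]])
```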

This method makes several assumptions:

1. The function is explicit, i.e. y = func(x), where y is the output of the operation.

2. All of y can be converted to an OverloadedType.

3. Unless annotation is turned off, the operation is always annotated when calling the overloaded function.

After the overloaded function is called, the pointwise bool for each input is recorded in the function’s pointwise attribute.

Parameters
• func (function) – The target function for which to create an overloaded version.

• pointwise (bool or list[bool]) – True means the function performs operations on a batch of points. This allows optimizing the Jacobian calculations by only computing the diagonal component. If a list, then there should be a bool for each input.

Returns

An overloaded version of func

Return type

function

pointwise (bool or list[bool]) – See overload_autograd().