Source code for crikit.cr.jax_utils

import jax.numpy as np
import numpy as onp
import jax

from crikit.covering import Covering, register_covering
from crikit.covering.ufl import get_numpy_shape
from crikit.cr.types import PointMap, Space
from crikit.cr.ufl import UFLExprSpace, UFLFunctionSpace
from crikit.cr.quadrature import get_quadrature_params, make_quadrature_spaces
from crikit.fe import *
from crikit.fe_adjoint import *
from crikit.projection import project
from pyadjoint_utils.jax_adjoint import array, ndarray
from pyadjoint_utils.fenics_adjoint import function_get_local, function_set_local
from pyadjoint_utils import ReducedFunctionNumPy, ReducedFunction
from pyadjoint.tape import (
    no_annotations,
    annotate_tape,
    get_working_tape,
    stop_annotating,
)
from pyadjoint.enlisting import Enlist
from pyadjoint_utils.convert import make_convert_block
from pyadjoint.overloaded_function import overload_function


class JAXArrays(Space):
    """Space of JAX arrays with a given shape and, optionally, a given dtype.

    See Ndarrays in crikit/cr/numpy.py

    A negative entry in ``shape`` marks an axis of indefinite size: such an
    axis is ignored when :meth:`is_point` compares shapes, and is given a
    concrete size when :meth:`point` builds an example array.

    Args:
        shape (Iterable): the shape of the arrays
        dtype (jax.numpy.dtype): the data type (default None)
    """

    def __init__(self, shape, dtype=None):
        self._shape = tuple(shape)
        self._dtype = dtype
        # Work from the stored tuple rather than the argument, so a one-shot
        # iterable (already consumed by ``tuple`` above) is handled correctly.
        self._indefinite_axes = np.array(self._shape) < 0
        self._definite_axes = np.logical_not(self._indefinite_axes)
        self._definite_shape = np.asarray(self._shape)[self._definite_axes]

    def shape(self):
        """Return the shape tuple this space was constructed with.

        NOTE: this is a plain method, not a property, so callers write
        ``space.shape()``.
        """
        return self._shape

    @staticmethod
    def _array_types():
        """Return the tuple of types that :meth:`is_point` accepts as arrays."""
        kinds = [np.ndarray, ndarray]
        # ``jax.interpreters.xla.DeviceArray`` does not exist in every JAX
        # release, so look it up defensively instead of failing with
        # AttributeError on each ``is_point`` call.
        legacy_device_array = getattr(
            getattr(getattr(jax, "interpreters", None), "xla", None),
            "DeviceArray",
            None,
        )
        if legacy_device_array is not None:
            kinds.append(legacy_device_array)
        return tuple(kinds)

    def is_point(self, point):
        """Return whether ``point`` is an array belonging to this space.

        ``point`` must be an instance of one of the accepted array types,
        have the same number of axes as this space, match the dtype when one
        was given, and agree with this space on every definite axis.
        """
        if not isinstance(point, self._array_types()):
            return False
        if len(point.shape) != len(self._shape):
            return False
        if self._dtype is not None and self._dtype != point.dtype:
            return False
        return np.array_equal(
            np.asarray(point.shape)[self._definite_axes], self._definite_shape
        )

    def point(self, **kwargs):
        """Return an array of ones that belongs to this space.

        Keyword Args:
            indefinite_ax_size (int): the size used for every indefinite
                (negative) axis (default 10).
        """
        indefinite_ax_size = kwargs.get("indefinite_ax_size", 10)
        # Treat every negative entry as indefinite, matching the constructor,
        # and honour the dtype so that ``is_point(point())`` holds.
        concrete_shape = tuple(
            s if s >= 0 else indefinite_ax_size for s in self._shape
        )
        return np.ones(concrete_shape, dtype=self._dtype)

    def __eq__(self, other):
        return (
            isinstance(other, JAXArrays)
            and self._shape == other._shape
            and self._dtype == other._dtype
        )

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable. Hash on the
        # shape only: equal spaces always share a shape, while dtypes that
        # compare equal are not guaranteed to hash equally.
        return hash(("JAXArrays", self._shape))

    def __repr__(self):
        if self._dtype is None:
            return f"JAXArrays({self._shape})"
        return f"JAXArrays({self._shape}, dtype={self._dtype})"
class JAX_To_UFLFunctionSpace(PointMap):
    """Point map that loads an array into a function on a quadrature space.

    When the target function space differs from the quadrature space, the
    quadrature function is additionally projected onto the target space.

    Args:
        source: the source space. If None, a :class:`JAXArrays` space is
            built from the numpy shape of the quadrature space.
        target: the space to map into. If it is a UFLFunctionSpace, its
            underlying function space is the projection target; otherwise
            the quadrature space itself is used.
        quad_space (optional): the quadrature space. If None, it is created
            (together with ``quad_params``) by ``make_quadrature_spaces``.
        quad_params (dict, optional): parameters passed as the ``metadata``
            of the integration measure.
        make_block (bool): accepted, but not used by this constructor.
    """

    def __init__(
        self, source, target, quad_space=None, quad_params=None, make_block=True
    ):
        if quad_space is None:
            quad_space, quad_params = make_quadrature_spaces(
                target, quad_params=quad_params
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)

        self._target_space = (
            target._functionspace
            if isinstance(target, UFLFunctionSpace)
            else quad_space
        )

        if source is None:
            source = JAXArrays(get_numpy_shape(quad_space))
        super().__init__(source, target)

    def __call__(self, arr, **kwargs):
        """Return a function holding ``arr``, projected to the target if needed."""
        quad_func = function_set_local(Function(self._quad_space), arr)
        needs_projection = self._target_space != self._quad_space
        if needs_projection:
            return project(quad_func, self._target_space, dx=self._dx)
        return quad_func
class UFLExprSpace_To_JAX(PointMap):
    """Much like UFLExprSpace_To_Numpy, this class maps a UFL expression
    into a JAX array. The constructor inputs are the same here, replacing
    Ndarrays with JAXArrays.

    Args:
        source (UFLExprSpace or UFLFunctionSpace): the UFL space to use as
            input.
        target (JAXArrays, optional): the target space to map to. If None, a
            JAXArrays space is built from the numpy shape of the quadrature
            space.
        quad_space (optional): the finite-element quadrature space to
            interpolate to.
        quad_params (dict, optional): parameters for the quadrature space.
        domain (optional): the UFL domain for the quadrature space.
    """

    def __init__(
        self, source, target=None, quad_space=None, quad_params=None, domain=None
    ):
        if quad_space is None:
            quad_space, quad_params = make_quadrature_spaces(
                source, quad_params=quad_params, domain=domain
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)

        if target is None:
            target = JAXArrays(get_numpy_shape(quad_space))
        super().__init__(source, target)

    def __call__(self, expr, **kwargs):
        """Return the local values of ``expr`` on the quadrature space.

        Raises:
            ValueError: if the UFL shape of ``expr`` differs from the value
                shape of the quadrature space's element.
        """
        expected_shape = self._quad_space.ufl_element().value_shape()
        if expr.ufl_shape != expected_shape:
            raise ValueError(
                f"Expression shape {expr.ufl_shape} does not match the target shape of {expected_shape}"
            )

        # A Function that already lives on the quadrature space is used
        # directly; anything else is projected onto it first.
        already_on_quad_space = (
            isinstance(expr, Function) and self._quad_space == expr.function_space()
        )
        if already_on_quad_space:
            func = expr
        else:
            func = project(expr, self._quad_space, dx=self._dx)

        return convert_numpy_to_jax(function_get_local(func))
def backend_convert_numpy_to_jax(x):
    """Return ``x`` as a ``jax.numpy`` array."""
    return np.array(x)


def backend_convert_jax_to_numpy(x):
    """Return ``x`` as a plain NumPy array."""
    return onp.array(x)


# Each block pairs one conversion with the opposite conversion.
# NOTE(review): the argument order is presumably (forward, inverse, name) --
# confirm against ``make_convert_block``.
ConvertNumpyToJAX = make_convert_block(
    backend_convert_numpy_to_jax,
    backend_convert_jax_to_numpy,
    "ConvertNumpyToJAX",
)
ConvertJAXToNumpy = make_convert_block(
    backend_convert_jax_to_numpy,
    backend_convert_numpy_to_jax,
    "ConvertJAXToNumpy",
)

# Overloaded versions of the two conversions, each tied to its block.
convert_numpy_to_jax = overload_function(
    backend_convert_numpy_to_jax, ConvertNumpyToJAX
)
convert_jax_to_numpy = overload_function(
    backend_convert_jax_to_numpy, ConvertJAXToNumpy
)
@register_covering(UFLExprSpace, JAXArrays)
@register_covering(UFLFunctionSpace, JAXArrays)
class JAX_UFLFunctionSpace_Covering(Covering):
    """Covering between a UFL base space and a :class:`JAXArrays` space.

    Args:
        base_space: the base space; the quadrature space is built from it.
        covering_space (optional): the covering space. If None, a JAXArrays
            space is built from the numpy shape of the quadrature space.
        domain (optional): forwarded to ``make_quadrature_spaces``.
        quad_params (dict, optional): forwarded to
            ``make_quadrature_spaces``.
        **kwargs: forwarded to the ``Covering`` constructor.
    """

    def __init__(
        self, base_space, covering_space=None, domain=None, quad_params=None, **kwargs
    ):
        self._quad_space, self._quad_params = make_quadrature_spaces(
            base_space, quad_params=quad_params, domain=domain
        )
        if covering_space is None:
            covering_space = JAXArrays(get_numpy_shape(self._quad_space))
        super().__init__(base_space, covering_space, **kwargs)

    def covering_map(self):
        """Return the map from the covering space to the base space."""
        return JAX_To_UFLFunctionSpace(
            self._covering_space,
            self._base_space,
            quad_space=self._quad_space,
            quad_params=self._quad_params,
        )

    def section_map(self):
        """Return the map from the base space to the covering space."""
        return UFLExprSpace_To_JAX(
            self._base_space,
            self._covering_space,
            quad_space=self._quad_space,
            quad_params=self._quad_params,
        )
class ReducedFunctionJAX(ReducedFunctionNumPy):
    """This class implements the ReducedFunction for a given ReducedFunction
    and controls with JAX data structures.

    Like a ReducedFunctionNumPy, these are created from ReducedFunction
    instances like

    >>> from pyadjoint_utils import overload_jax
    >>> f = overload_jax(lambda x: np.sum(x ** 2))
    >>> x = array(np.array([1.0,2.0]))
    >>> rf = ReducedFunction(f(x),Control(x))
    >>> rf_jax = ReducedFunctionJAX(rf)
    >>> float(rf_jax(x))
    5.0
    """

    def __init__(self, rf):
        if isinstance(rf, ReducedFunction):
            self.rf = rf
            return
        raise TypeError(
            f"Must pass a ReducedFunction to the ReducedFunctionJAX constructor, not a {type(rf)}!"
        )

    def __call__(self, val):
        """Evaluate the wrapped reduced function at ``val``."""
        rf_input = self.get_rf_input(array(val))
        return self.rf(rf_input)

    def get_global(self, controls):
        """Flatten the given control values into a single ``ndarray``."""
        flat = []
        for idx, val in enumerate(controls):
            if isinstance(val, Control):
                entries = val.fetch_numpy(val.control)
            elif hasattr(val, "_ad_to_list"):
                entries = val._ad_to_list(val)
            else:
                # Fall back to the list conversion of the matching control.
                entries = self.controls[idx].control._ad_to_list(val)
            flat.extend(entries)
        return ndarray(np.array(flat))

    @no_annotations
    def jac_action(self, val):
        """Return the flattened Jacobian action of the wrapped function at ``val``."""
        rf_input = self.get_rf_input(val)
        return self.get_outputs_array(self.rf.jac_action(rf_input))

    def get_outputs_array(self, vals):
        """Flatten per-output values into a single ``ndarray``.

        An output whose value is None contributes zeros, one per entry of
        that output's ``_ad_dim()``.
        """
        vals = Enlist(vals)
        flat = []
        for idx, out in enumerate(self.outputs):
            val = vals[idx]
            if val is None:
                flat.extend([0] * out.output._ad_dim())
            else:
                flat.extend(out.output._ad_to_list(val))
        return ndarray(np.array(flat))
class JAXFunctionJITTracer:
    """Holds a function and offers JIT-compiled or overloaded versions of it.

    Args:
        f: the function to wrap.
        strict (optional): stored on the instance; not read by any method of
            this class.
    """

    def __init__(self, f, strict=None):
        self.f, self.strict = f, strict

    def trace(self, f, *args, **kwargs):
        """Return ``jax.jit`` applied to ``f`` with the given extra arguments."""
        jitted = jax.jit(f, *args, **kwargs)
        return jitted

    def trace_and_overload(self, *args, jit=True, strict=False, check=False, **kwargs):
        """Return the stored function wrapped by ``overload_jax``.

        Only ``jit`` and the extra keyword arguments are forwarded; the
        positional arguments, ``strict`` and ``check`` are accepted but not
        used here.
        """
        # Technically the tracing happens when the returned function is later
        # called, but given how this class is used (as internals in the CR
        # class) that is fine: it is called immediately after this returns.
        # NOTE(review): ``overload_jax`` is not among this module's explicit
        # imports; presumably it arrives through a star import -- confirm.
        overload_options = {"jit": jit, **kwargs}
        return overload_jax(self.f, **overload_options)