Source code for pyadjoint_utils.jax_adjoint.array

import jax.numpy as jnp
import numpy as np
import jax
from jax.ops import index_update
from jax.tree_util import (
    register_pytree_node,
    register_pytree_node_class,
    tree_flatten,
    tree_unflatten,
)
from jax.tree_util import Partial as partial
from pyadjoint.overloaded_function import overload_function
from pyadjoint.overloaded_type import (
    OverloadedType,
    register_overloaded_type,
    create_overloaded_object,
)
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape
from pyadjoint_utils.block import Block
from pyadjoint_utils.convert import make_convert_block
from pyadjoint_utils.adjfloat import AdjFloat
from pyadjoint import AdjFloat, Control
from typing import Any, Iterable, Sequence, Union, Optional, Tuple
from .blocks import AddBlock, SubBlock, PowBlock, MulBlock, DivBlock, NegBlock

# Loose alias used in annotations for "anything array-like"; it is plain
# ``Any`` and so carries no static guarantees.
Array = Any
# Alias for an array shape expressed as a sequence of ints.
Shape = Sequence[int]

# Module-wide default dtype. Read through get_default_dtype(), replaced through
# set_default_dtype(), and used by _backend_array() when the caller passes no
# ``dtype`` keyword.
_default_dtype = jnp.float64


def set_default_dtype(dtype: jnp.dtype) -> None:
    """Replace the module-wide default dtype used when building JAX arrays.

    The argument is normalised with ``jnp.dtype`` before it is stored, so the
    stored value is always a dtype object.

    :param dtype: the new default data type
    :type dtype: anything accepted by ``jnp.dtype`` (e.g. a dtype or a string)
    :returns: None

    """
    global _default_dtype
    _default_dtype = jnp.dtype(dtype)


def get_default_dtype() -> jnp.dtype:
    """Return the dtype currently stored as the module-wide default.

    :returns: the current default JAX dtype
    :rtype: jnp.dtype

    """
    current = _default_dtype
    return current


def _reverse(block):
    def _reversed_block(operator, args, output):
        return block(operator, (args[1], args[0]), output)

    return _reversed_block


# Lookup table consumed by annotate_operator(): maps an operator method name
# to a pair ``(implementation, block constructor)``.
#
# Every implementation is called with ``self`` as its first argument, so the
# reflected variants (``__r*__``) compute ``y <op> x`` and wrap their block
# constructor in _reverse() so that the block sees the operands swapped.
_default_op_map = {
    "__neg__": (lambda x: -x, NegBlock),
    "__add__": (lambda x, y: x + y, AddBlock),
    "__mul__": (lambda x, y: x * y, MulBlock),
    "__truediv__": (lambda x, y: x / y, DivBlock),
    "__rtruediv__": (lambda x, y: y / x, _reverse(DivBlock)),
    "__sub__": (lambda x, y: x - y, SubBlock),
    "__pow__": (lambda x, y: x ** y, PowBlock),
    "__radd__": (lambda x, y: y + x, _reverse(AddBlock)),
    "__rmul__": (lambda x, y: y * x, _reverse(MulBlock)),
    "__rsub__": (lambda x, y: y - x, _reverse(SubBlock)),
    "__rpow__": (lambda x, y: y ** x, _reverse(PowBlock)),
}
# similar to what pyadjoint defines in adjfloat.py
def annotate_operator(orig_operator, nojit=False, op_map=_default_op_map):
    """Replace an operator method (``__add__``, ``__sub__``, ...) with an
    annotated implementation.

    The body of ``orig_operator`` is never executed; only its ``__name__`` is
    used, as the key into ``op_map``.

    :param orig_operator: the placeholder method being decorated
    :type orig_operator: Callable
    :param nojit: if True, do not wrap the implementation in ``jax.jit``,
        defaults to False
    :type nojit: bool, optional
    :param op_map: maps an operator name to a pair
        ``(implementation, block constructor)``
    :type op_map: dict, optional
    :return: a function that evaluates the operator, wraps the result in the
        class of its first argument and, while the tape is annotating, adds a
        block for the operation to the working tape
    :rtype: Callable
    :raises KeyError: if ``orig_operator.__name__`` is not a key of ``op_map``
    """
    op, block_ctor = op_map[orig_operator.__name__]
    if not nojit:
        op = jax.jit(op)

    def annotated_operator(*args):
        # args[0] is always self, so the result is wrapped in self's class.
        raw_args = (convert_arg(x) for x in args)
        output = args[0].__class__(op(*raw_args))
        # The block needs overloaded objects as dependencies, so plain values
        # (floats, raw arrays, ...) are promoted first.
        args = [
            arg if isinstance(arg, OverloadedType) else create_overloaded_object(arg)
            for arg in args
        ]
        if annotate_tape():
            block = block_ctor(op, args, output)
            tape = get_working_tape()
            tape.add_block(block)
            block.add_output(output.block_variable)

        return output

    annotated_operator.__name__ = orig_operator.__name__
    return annotated_operator


@register_pytree_node_class
class ndarray(OverloadedType):
    """Wrapper around a JAX array (or scalar, or list/tuple of arrays) that is
    both a pyadjoint ``OverloadedType`` and a registered JAX pytree node.

    The wrapped object is stored in ``self.arr``; any extra positional
    constructor arguments are stored in ``self.extras`` and travel with the
    array as additional pytree children.
    """

    def __repr__(self):
        return f"ndarray({self.arr.__repr__()})"

    def __array__(self, dtype=None):
        # Use the wrapped object's own __array__ when it has one, otherwise go
        # through numpy first.
        source = self.arr if hasattr(self.arr, "__array__") else np.array(self.arr)
        if dtype:
            return source.__array__(dtype=dtype)
        return source.__array__()

    def __float__(self):
        return float(self.unwrap(True))

    def __iter__(self):
        return self.arr.__iter__()

    def __init__(self, obj: Array, *args, **kwargs):
        """Note: you should not typically use this constructor directly in your code.
        Instead, you should call :func:`array` or :func:`asarray`, which will call
        the constructor of this class when appropriate.

        :param obj: the object to wrap; should be either a JAX ndarray or
            something that can be converted to one (such as a list or tuple of
            floats or ints, or a float or int, or a numpy ndarray)
        :type obj: jax.interpreters.xla.DeviceArray
        :param args: extra objects stored in ``self.extras`` and flattened as
            additional pytree children
        :param kwargs: accepted for call compatibility; not used by the
            constructor (``obj`` is stored as-is, without dtype conversion)
        :return: a class that wraps a JAX array (such that it can be added to
            the JAX Pytree and thus used as an argument to a differentiable
            function) to be passed to a function wrapped with `overload_jax()`,
            while also inheriting from pyadjoint.OverloadedType (since you can't
            inherit from a JAX array; see https://github.com/google/jax/issues/4269).
        :rtype: ndarray
        """
        self.arr = obj
        self.extras = args
        super().__init__()

    @property
    def value(self):
        """The wrapped object (same as ``self.arr``)."""
        return self.arr

    @value.setter
    def value(self, new_jax_array):
        self.arr = new_jax_array

    @property
    def dtype(self):
        """dtype of the wrapped array."""
        return self.arr.dtype

    @property
    def ndim(self):
        """Number of dimensions of the wrapped array."""
        return self.arr.ndim

    def unwrap(self, to_jax: bool = True) -> jnp.ndarray:
        """
        If this ndarray holds recursively nested ndarrays (e.g. its __repr__()
        is ndarray(ndarray(...))), unwrap until it holds the array data contained
        in the deepest-nested ndarray. This is mostly a utility for use in jacfwd
        and jacrev in pyadjoint_utils/numpy_adjoint/jax.py

        If the innermost object is a list or tuple, a container of the same
        type is returned with every ndarray element unwrapped as well.

        :param to_jax: go one level further and return the raw JAX array
            (instead of ndarray, the OverloadedType wrapper)?, defaults to True
        :type to_jax: bool, optional
        :return: unwrapped version of self
        :rtype: jax.interpreters.xla.DeviceArray
        """
        newarr = self.arr
        while isinstance(newarr, ndarray):
            newarr = newarr.arr

        if isinstance(newarr, (list, tuple)):
            tna = type(newarr)
            arr = list(newarr)
            for i, val in enumerate(newarr):
                if isinstance(val, ndarray):
                    arr[i] = val.unwrap(to_jax=to_jax)
            return tna(arr)

        if to_jax:
            return newarr
        # we went one level too far!
        # whatever, it's relatively cheap to construct a new ndarray
        return ndarray(newarr)

    def tree_flatten(self) -> Tuple[Tuple[jnp.ndarray, ...], None]:
        """
        Flattens an ndarray in the JAX Pytree structure

        :return: tuple containing any arrays (or other children) this ndarray
            holds, and an empty metadata field
        :rtype: tuple
        """
        if self.extras:
            return ((self.arr, *self.extras), None)
        return ((self.arr,), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Constructs an ndarray from its flattened components

        :param cls: ndarray
        :type cls: type
        :param aux_data: ignore this parameter
        :type aux_data: None
        :param children: any children that belonged to this ndarray before it
            was flattened
        :type children: Union[jax.interpreters.xla.DeviceArray,Iterable[jax.interpreters.xla.DeviceArray]]
        :return: an ndarray holding the children
        :rtype: ndarray
        """
        return cls(*children)

    def flatten(self) -> jnp.ndarray:
        """Returns a flattened 1-d array (NOT an ndarray, but rather the array
        type it contains). If the wrapped object cannot be flattened it is
        returned unchanged."""
        try:
            return self.arr.flatten()
        except Exception:
            return self.arr

    @classmethod
    def _ad_init_object(cls, obj: Array):
        # Python scalars are stored as-is; everything else becomes a JAX array.
        obj = jnp.array(obj) if not isinstance(obj, (int, float)) else obj
        return cls(obj)

    @property
    def size(self):
        """How many elements does the array contain? Returns 0 when the
        wrapped object has no usable ``size`` attribute."""
        try:
            return self.arr.size
        except Exception:
            return 0

    @property
    def shape(self):
        """The shape of the array (``()`` when the wrapped object has no
        usable ``shape`` attribute)"""
        try:
            return self.arr.shape
        except Exception:
            return ()

    @property
    def T(self):
        """Returns the transpose of this ndarray"""
        return ndarray(self.arr.T)

    def __len__(self):
        return self.arr.__len__()

    def __eq__(self, other):
        return self.arr.__eq__(other)

    # annotated operators are implemented in the table
    # mapping their name to a tuple of the implementation and
    # the corresponding Block contained in annotate_operator()
    # The method bodies below are placeholders: the decorator only uses the
    # method name and never calls the body.
    @annotate_operator
    def __add__(self, other):
        pass

    @annotate_operator
    def __neg__(self):
        pass

    @annotate_operator
    def __truediv__(self, other):
        pass

    @annotate_operator
    def __rtruediv__(self, other):
        pass

    @annotate_operator
    def __radd__(self, other):
        pass

    @annotate_operator
    def __rmul__(self, other):
        pass

    # In-place operators are not implemented; returning NotImplemented lets
    # Python fall back to the corresponding binary operator.
    def __iadd__(self, other):
        return NotImplemented

    def __imul__(self, other):
        return NotImplemented

    def __isub__(self, other):
        return NotImplemented

    @annotate_operator
    def __sub__(self, other):
        pass

    @annotate_operator
    def __rsub__(self, other):
        pass

    @annotate_operator
    def __mul__(self, other):
        pass

    @annotate_operator
    def __pow__(self, other):
        pass

    @annotate_operator
    def __rpow__(self, other):
        pass

    def __abs__(self):
        return ndarray(jnp.abs(self.arr))

    def _ad_create_checkpoint(self) -> jnp.ndarray:
        return self.arr

    def _ad_restore_at_checkpoint(self, checkpoint):
        return ndarray(checkpoint)

    def _ad_dim(self) -> int:
        return self.arr.size

    def _ad_dot(self, other) -> float:
        sflat = self.flat()
        oflat = flatten(other)
        # When either side has a single element, multiply elementwise and sum
        # instead of calling jnp.dot.
        if sflat.size == 1 or oflat.size == 1:
            return float(jnp.sum(sflat * oflat))
        return float(jnp.dot(sflat, oflat))

    def _ad_mul(self, other):
        return ndarray(self.arr * other)

    def _ad_add(self, other):
        return ndarray(self.arr + other)

    def _ad_copy(self):
        if isinstance(self.arr, tuple):
            try:
                return ndarray(tuple(map(lambda x: x.copy(), self.arr)))
            except Exception:
                # NOTE(review): this fallback returns the raw tuple, not an
                # ndarray -- confirm that callers expect that.
                return self.arr
        return (
            ndarray(self.arr.copy())
            if not isinstance(self.arr, (int, float))
            else ndarray(self.arr)
        )

    def _ad_convert_type(self, value, options=None):
        # ``None`` replaces a shared mutable ``{}`` default and is also
        # accepted when passed explicitly.
        return array(value, **(options or {}))

    def copy(self, *args, **kwargs):
        return ndarray(self.arr.copy(*args, **kwargs))

    def copy_data(self):
        return ndarray(self.arr.copy())

    def flat(self) -> jnp.ndarray:
        """Return the fully unwrapped data as a 1-d JAX array."""
        return jnp.ravel(self.unwrap(to_jax=True))

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """Build a new ndarray shaped like ``dst`` from ``src``, starting at
        ``offset``.

        :return: the new ndarray and the advanced offset
        :rtype: tuple
        """
        if isinstance(src, ndarray):
            src = src.unwrap(to_jax=True)

        if isinstance(src, (list, tuple)):
            # List-like source: one list entry per assignment.
            src = list(src)
            # NOTE(review): this is a ``while ... else`` without ``break``, so
            # the ``else`` suite always runs after the loop. The nesting is
            # reconstructed from collapsed source -- confirm against upstream.
            while isinstance(src, list) and len(src) == 1:
                src = src[0]
            else:
                for i, s in enumerate(src):
                    if isinstance(s, ndarray):
                        src[i] = s.unwrap(to_jax=True)
            src_val = src[offset]
            if isinstance(src_val, (int, float)):
                dst = src_val
            else:
                dst = jnp.reshape(jnp.array(src_val), dst.shape)
            offset += 1
            return array(dst), offset

        if hasattr(src, "__len__") and len(src) == 1:
            src = src[0]
        if isinstance(src, (int, float)):
            return array(src), offset + 1

        # Flat array source: consume ``dst.size`` entries.
        dst = jnp.reshape(jnp.array(src[offset : offset + dst.size]), dst.shape)
        offset += dst.size
        return array(dst), offset

    @staticmethod
    def _ad_to_list(m) -> list:
        if isinstance(m, ndarray):
            return np.array(m.arr).flatten().tolist()
        try:
            return list(np.array(m).tolist())
        except TypeError:
            # ``tolist()`` on a 0-d array yields a scalar, which ``list``
            # cannot iterate.
            return [m]

    def __getitem__(self, item):
        annotate = annotate_tape()
        if annotate:
            block = JAXArraySliceBlock(self, item)
            tape = get_working_tape()
            tape.add_block(block)

        # The raw indexing itself must not be recorded on the tape.
        with stop_annotating():
            out = self.arr.__getitem__(item)

        if annotate:
            # Only when annotating is the result wrapped, so that it can carry
            # a block variable; otherwise the raw indexed value is returned.
            out = ndarray(out) if not isinstance(out, ndarray) else out
            block.add_output(out.create_block_variable())
        return out
def array(obj: Array, **kwargs) -> ndarray:
    """Wrap ``obj`` in an :class:`ndarray`.

    This function is NOT overloaded (it does not add any :class:`Block` to the
    :class:`Tape`). To convert an :class:`AdjFloat` to an :class:`ndarray` or
    vice-versa on the tape, use :func:`to_jax` or :func:`to_adjfloat`
    respectively.

    :param obj: the object to wrap; should be either a JAX ndarray or something
        that can be converted to one (such as a list or tuple of floats or
        ints, or a float or int, or a numpy ndarray)
    :type obj: Union[jax.interpreters.xla.DeviceArray,Iterable[Union[int,float,jax.interpreters.xla.DeviceArray]]]
    :param kwargs: forwarded unchanged to the backend conversion routine
    :returns: a class that wraps a JAX array (such that it can be added to the
        JAX Pytree and thus used as an argument to a differentiable function)
        to be passed to a function wrapped with `overload_jax()`, while also
        inheriting from pyadjoint.OverloadedType (since you can't inherit from
        a JAX array; see https://github.com/google/jax/issues/4269).
    :rtype: ndarray
    """
    wrapped = _backend_array(obj, **kwargs)
    return wrapped
def _backend_array(obj, **kwargs):
    """Convert ``obj`` to an :class:`ndarray` without touching the tape.

    Dispatch order:

    * an ``ndarray`` is returned unchanged;
    * floats, ints, JAX arrays and numpy arrays are wrapped as-is;
    * lists/tuples have single-element nesting stripped; a sequence whose
      first element is (or converts to) an ``ndarray`` is wrapped as a
      sequence of ``ndarray``;
    * ``AdjFloat`` is converted with ``float`` and ``Control`` is replaced by
      its ``data()``;
    * anything left is converted through ``numpy`` and then ``jax.numpy``
      using the requested dtype; if that fails the object is wrapped as-is.

    :param obj: the object to convert
    :param kwargs: ``dtype`` selects the dtype of the final conversion
        (defaults to the module default); all keyword arguments are also
        forwarded to the ``ndarray`` constructor
    :return: the wrapped object
    :rtype: ndarray
    """
    directly_convertible_types = (
        float,
        int,
        jax.interpreters.xla.DeviceArray,
        jnp.ndarray,
        np.ndarray,
        ndarray,
    )
    if isinstance(obj, ndarray):
        return obj
    if isinstance(obj, directly_convertible_types):
        return ndarray(obj)

    dtype = kwargs.get("dtype", _default_dtype)

    if isinstance(obj, (list, tuple)):
        if len(obj) == 0:
            return ndarray(obj)
        # Strip single-element nesting and convert element-wise until the
        # first element is directly convertible. The length check stops the
        # loop on an empty inner sequence instead of indexing into it.
        while (
            isinstance(obj, (list, tuple))
            and len(obj) > 0
            and not isinstance(obj[0], directly_convertible_types)
        ):
            if len(obj) == 1:
                obj = obj[0]
            else:
                obj = [array(val) for val in obj]
        # After unwrapping, ``obj`` may no longer be a sequence, so only index
        # it when it still is one. Indexing an ndarray here would go through
        # its annotated ``__getitem__``.
        if (
            isinstance(obj, (list, tuple))
            and len(obj) > 0
            and isinstance(obj[0], ndarray)
        ):
            return ndarray(obj)

    # Checked after the sequence handling so that values unwrapped from a
    # single-element list are converted too.
    if isinstance(obj, AdjFloat):
        obj = float(obj)
    elif isinstance(obj, Control):
        obj = obj.data()

    # jax.numpy.array(numpy.array(x)) is typically faster than
    # jax.numpy.array(x), especially if x is a list or tuple, because
    # numpy.array() is written in C, whereas jax.numpy.array() is written in
    # Python. This is recommended by the JAX developers.
    try:
        return ndarray(jnp.array(np.array(obj), dtype=dtype), **kwargs)
    except Exception:
        # Best effort: wrap the object unchanged when it cannot be converted.
        return ndarray(obj)
def asarray(obj: Array, **kwargs) -> ndarray:
    """Alias of :func:`array`: wrap ``obj`` in an :class:`ndarray`.

    :param obj: the object to wrap
    :param kwargs: forwarded unchanged to :func:`array`
    :returns: the wrapped object
    :rtype: ndarray
    """
    result = array(obj, **kwargs)
    return result
# Tape blocks and overloaded functions for converting between AdjFloat and
# ndarray.
ConvertJAXToAdjFloat = make_convert_block(
    AdjFloat,
    _backend_array,
    "ConvertJAXToFloat",
)
ConvertAdjFloatToJAX = make_convert_block(
    _backend_array,
    AdjFloat,
    "ConvertAdjFloatToJAX",
)

to_adjfloat = overload_function(AdjFloat, ConvertJAXToAdjFloat)
to_jax = overload_function(_backend_array, ConvertAdjFloatToJAX)


def flatten(x):
    """Return a 1-d view of ``x`` when it is an array type.

    An ``ndarray`` is flattened through its ``flat()`` method, numpy and JAX
    arrays through ``jnp.ravel``; any other object is returned unchanged.
    """
    if isinstance(x, ndarray):
        return x.flat()
    elif isinstance(x, (np.ndarray, jax.interpreters.xla.DeviceArray)):
        return jnp.ravel(x)
    else:
        return x


class JAXArraySliceBlock(Block):
    """Tape block recording ``arr[item]`` for an :class:`ndarray`."""

    def __init__(self, arr, item):
        super().__init__()
        self.add_dependency(arr)
        self.item = item

    def evaluate_adj_component(
        self, inputs, adj_inputs, block_variable, idx, prepared=None
    ):
        """Scatter the adjoint of the slice back into an array of zeros with
        the shape of the sliced input."""
        adj_output = jnp.zeros(inputs[0].shape)
        # The adjoint of indexing is a scatter-ADD: when ``item`` selects the
        # same element more than once (e.g. ``x[[0, 0]]``), every selection
        # contributes. Overwriting would keep only one contribution.
        if hasattr(adj_output, "at"):
            return adj_output.at[self.item].add(adj_inputs[0])
        # Fallback for JAX releases that predate the ``.at`` indexing helper.
        return jax.ops.index_add(adj_output, self.item, adj_inputs[0])

    def recompute_component(self, inputs, block_variable, idx, prepared):
        return inputs[0][self.item]

    def evaluate_tlm_component(
        self, inputs, tlm_inputs, block_variable, idx, prepared=None
    ):
        # Indexing is linear, so the tangent is the same slice of the input
        # tangent.
        return tlm_inputs[0][self.item]


def convert_arg(
    x: Union[ndarray, list, tuple, int, float, jnp.ndarray]
) -> Union[list, tuple, int, float, jnp.ndarray]:
    """Strip the overloaded wrappers from ``x`` so it can be passed to a raw
    JAX function.

    JAX arrays and Python scalars pass through, an ``ndarray`` is unwrapped,
    lists/tuples are converted element-wise (keeping the container type) and a
    ``Control`` is replaced by its converted ``data()``. Anything else is
    passed to ``jnp.array``; if that fails the object is returned unchanged.
    """
    if isinstance(x, (jnp.ndarray, jax.interpreters.xla.DeviceArray, int, float)):
        return x
    elif isinstance(x, ndarray):
        return x.unwrap(to_jax=True)
    elif isinstance(x, (list, tuple)):
        tx = type(x)
        return tx([convert_arg(v) for v in x])
    elif isinstance(x, Control):
        return convert_arg(x.data())
    try:
        return jnp.array(x)
    except Exception:
        return x


# Tell pyadjoint to wrap raw JAX array types in ndarray.
register_overloaded_type(ndarray, jnp.ndarray)
register_overloaded_type(ndarray, jax.interpreters.xla.DeviceArray)
# The remaining array types are registered on a best-effort basis: a failure
# is ignored. ``except Exception`` (rather than a bare ``except``) lets
# KeyboardInterrupt and SystemExit propagate.
try:
    register_overloaded_type(ndarray, jax.interpreters.xla._DeviceArray)
except Exception:
    pass
try:
    import jaxlib

    register_overloaded_type(ndarray, jaxlib.xla_extension.Buffer)
except Exception:
    pass