# Source code for crikit.cr.autograd
from .types import PointMap
from .space_builders import DirectSum, enlist
from .numpy import Ndarrays
from .overloaded import overloaded_point_map
from autograd import vector_jacobian_product, jacobian, elementwise_grad
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.numpy.numpy_vspaces import ArrayVSpace
from pyadjoint.enlisting import Enlist
from pyadjoint import AdjFloat
from pyadjoint_utils.numpy_adjoint import ndarray
from pyadjoint_utils.numpy_adjoint.autograd import overload_autograd
from functools import wraps
from itertools import product
import numpy as np
import autograd.numpy as anp
# Register the pyadjoint-overloaded ``ndarray`` and ``AdjFloat`` types with
# autograd's ``ArrayBox`` and ``ArrayVSpace`` classes. These calls run once, at
# import time.
# NOTE(review): presumably this lets autograd treat the overloaded types like
# plain numpy arrays when tracing -- confirm against autograd's Box/VSpace docs.
ArrayBox.register(ndarray)
ArrayBox.register(AdjFloat)
ArrayVSpace.register(ndarray)
ArrayVSpace.register(AdjFloat)
from autograd.extend import defvjp
def vjpmaker_trace(ans, x, offset=0, axis1=0, axis2=1):
    """Build the vector-Jacobian product of ``anp.trace`` with respect to ``x``.

    The trace is linear in ``x``: every entry lying on the selected diagonal
    of the ``(axis1, axis2)`` plane contributes with weight one and every other
    entry with weight zero. The VJP therefore scatters the incoming cotangent
    onto that diagonal.

    Args:
        ans: Result of the forward ``trace`` call (not used here).
        x: Array whose trace was taken; must have at least two dimensions.
        offset: Diagonal offset, as passed to ``trace``.
        axis1: Axis acting as the row axis of the 2-D sub-arrays.
        axis2: Axis acting as the column axis of the 2-D sub-arrays.

    Returns:
        A function ``vjp(g)`` mapping a cotangent ``g`` (shaped like the trace
        output) to an array shaped like ``x``.
    """
    ndim = x.ndim
    # Normalise negative axes so they can be compared and used as positions.
    row_axis = axis1 % ndim
    col_axis = axis2 % ndim
    n_rows = x.shape[row_axis]
    n_cols = x.shape[col_axis]

    # Boolean mask of the selected diagonal, indexed as [row, col].
    diagonal = anp.eye(n_rows, n_cols, k=offset, dtype=bool)
    if row_axis > col_axis:
        # reshape() below lays dimensions out in axis order, so the mask must
        # be indexed as [col, row] when the column axis comes first in ``x``.
        diagonal = anp.transpose(diagonal)

    # Put the mask on the (axis1, axis2) plane of ``x``; all remaining axes get
    # length one so the mask broadcasts over them.
    mask_shape = [1] * ndim
    mask_shape[row_axis] = n_rows
    mask_shape[col_axis] = n_cols
    diagonal = anp.reshape(diagonal, mask_shape)
    mask = anp.where(diagonal, anp.ones_like(x), anp.zeros_like(x))

    # trace() removes both axes from its output; re-insert them lowest index
    # first so that the surviving axes line up with those of ``x`` again.
    low_axis, high_axis = sorted((row_axis, col_axis))

    def vjp(g):
        g = anp.expand_dims(g, axis=low_axis)
        g = anp.expand_dims(g, axis=high_axis)
        return mask * g

    return vjp
# Attach ``vjpmaker_trace`` to ``anp.trace`` as its vector-Jacobian product.
defvjp(anp.trace, vjpmaker_trace)
class AutogradPointMap(PointMap):
    """A :class:`PointMap` evaluated through a function wrapped by ``overload_autograd``.

    Args:
        source: Source space, forwarded to ``PointMap.__init__``.
        target: Target space, forwarded to ``PointMap.__init__``.
        func: The user function to wrap.
        bare: When true, ``func`` is handed to ``overload_autograd`` unchanged.
            Otherwise it is wrapped so that all positional arguments are packed
            into one tuple, which ``func`` receives as its single positional
            argument.
        pointwise: Forwarded to ``overload_autograd``.
    """

    def __init__(self, source, target, func, bare=False, pointwise=True):
        self._orig_func = func
        self._bare = bare

        # Autograd wants exactly one simple value (ndarray / standard numeric)
        # per positional argument, so the function given to it takes ``*args``.
        if bare:
            self._ag_func = func
        else:

            @wraps(func)
            def _expanded_func(*args, **kwargs):
                # Looked up through ``self`` at call time, not captured here.
                return self._orig_func(args, **kwargs)

            self._ag_func = _expanded_func

        overloaded = overload_autograd(self._ag_func, pointwise)
        self._overloaded_func = overloaded
        self.pointwise = overloaded.pointwise
        super().__init__(source, target)

    def __repr__(self):
        return f"AutogradPointMap({self._orig_func})"

    def __call__(self, arg, **kwargs):
        """Evaluate the map at ``arg``.

        ``arg`` is unpacked into separate positional arguments only when the
        source space is a :class:`DirectSum`; otherwise it is passed as is.
        """
        if not isinstance(self.source, DirectSum):
            return self._overloaded_func(arg, **kwargs)
        return self._overloaded_func(*arg, **kwargs)
def point_map(source_tuple, target_tuple, **kwargs):
    """This decorator turns the decorated function into an AutogradPointMap.

    The given tuples are used to create :class:`~crikit.cr.numpy.Ndarrays` spaces to set as the
    source and target spaces.

    Args:
        source_tuple: ``None`` (the source space is then ``None``), a tuple
            passed directly to ``Ndarrays``, or a sequence containing at least
            one tuple. In the last case every entry gets its own ``Ndarrays``
            and they are combined in a ``DirectSum``.
        target_tuple: Same as ``source_tuple``, except that ``None`` is not
            handled.
        **kwargs: ``dtype`` (optional, default ``None``) is passed to every
            ``Ndarrays``. All other keyword arguments are forwarded to
            :class:`AutogradPointMap`.

    Returns:
        A decorator that takes a function and returns an
        :class:`AutogradPointMap` wrapping it.
    """
    # Pop (not get): ``dtype`` only configures the spaces. AutogradPointMap's
    # constructor has no such parameter, so forwarding it raises TypeError.
    dtype = kwargs.pop("dtype", None)

    def _build_space(spec):
        # A spec containing tuples describes several component spaces.
        if any(isinstance(entry, tuple) for entry in spec):
            return DirectSum(*(Ndarrays(entry, dtype) for entry in spec))
        return Ndarrays(spec, dtype)

    source = None if source_tuple is None else _build_space(source_tuple)
    target = _build_space(target_tuple)

    def point_map_decorator(func):
        return AutogradPointMap(source, target, func, **kwargs)

    return point_map_decorator