Source code for crikit.cr.numpy
from .types import Space, PointMap
from .space_builders import DirectSum
import numpy as np
from jax import numpy as jnp
from jax.interpreters.xla import DeviceArray
[docs]class Ndarrays(Space):
"""This class represents a Space of NumPy arrays of a given shape and
optionally a specific data type.
Negative numbers can be used in the shape to indicate that the length of
that dimension doesn't matter as long as that dimension exists. For example,
if the given shape is (-1, 2, 5), then arrays with shapes (4, 2, 5) and (1, 2, 5)
are both points in the space, but (2, 5) is not.
>>> import numpy as np
>>> from crikit.cr.numpy import Ndarrays
>>> space = Ndarrays((-1, 2, 5))
>>> space.is_point(np.zeros((4, 2, 5)))
True
>>> space.is_point(np.zeros((1, 2, 5)))
True
>>> space.is_point(np.zeros((2, 5)))
False
Args:
shape (tuple or list): The shape of the arrays.
dtype (numpy.dtype): A NumPy datatype that further constrains the Space.
If None, Ndarrays space will not have a specific type.
"""
def __init__(self, shape, dtype=None):
self._shape = shape
self._dtype = dtype
self._indefinite_axes = np.array(shape) < 0
self._definite_axes = np.logical_not(self._indefinite_axes)
self._definite_shape = np.asarray(shape)[self._definite_axes]
def shape(self):
return self._shape
[docs] def is_point(self, point):
"""Returns true if the given point is an ndarray and its shape and dtype
match those of the space.
"""
if not isinstance(point, (np.ndarray, jnp.ndarray, DeviceArray)):
return False
if len(point.shape) != len(self._shape):
return False
# Make sure that the shapes are the same along each definite axis
point_definite_shape = np.asarray(point.shape)[self._definite_axes]
return np.array_equal(point_definite_shape, self._definite_shape) and (
self._dtype is None or point.dtype == self._dtype
)
def point(self, **kwargs):
shape = self._shape
if "near" in kwargs:
near = kwargs["near"]
assert self.is_point(near)
shape = near.shape
return np.zeros(shape)
def __eq__(self, other):
return (
isinstance(other, Ndarrays)
and self._shape == other._shape
and self._dtype == other._dtype
)
def __repr__(self):
if self._dtype is None:
return f"Ndarrays({self._shape})"
return f"Ndarrays({self._shape}, dtype={self._dtype})"
[docs]class CR_P_LaplacianNumpy(PointMap):
def __init__(self, p=2, dim=2, input_u=True):
self._p = p
self._input_u = input_u
np_vec_space = Ndarrays((-1, dim))
if self._input_u:
np_scalar_space = Ndarrays((-1,))
source = DirectSum(np_scalar_space, np_vec_space)
else:
source = np_vec_space
super(CR_P_LaplacianNumpy, self).__init__(source, np_vec_space)
def __call__(self, args, **kwargs):
if self._input_u:
gradu = args[1]
else:
gradu = args
mu = (np.sum(gradu * gradu, axis=1) + 1e-12) ** ((self._p - 2) / 2)
mu = mu[:, None]
out = mu * gradu
return out
def setParams(self, p):
self._p = p