crikit API reference

Core classes

class crikit.cr.types.Space

Arguments / output spaces of CRs.

A class of spaces is a representation.

Spaces in a representation may look something like expression trees of operations that build new spaces recursively from base spaces.

Abstract representations will have spaces that can’t instantiate points as real data. They have to be lowered into concrete representations.

is_point(x: Any) → bool

test if a point is in the space

point(*args, **kwargs)

create a point in the space

shape() → Tuple[int, ...]

Get the shape of the space (as in the shape of the arguments that you would make from a point in the space)

class crikit.cr.types.PointMap(source_space, target_space)

Takes points from one space and maps them into another

__call__(point, **kwargs)

Apply the point map to the given point. I.e., map the given point in the source space to a point in the target space.

property source

the input Space of the PointMap

property target

the output Space of the PointMap
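The Space/PointMap contract above can be illustrated with a minimal, self-contained sketch. The classes `ToySpace`, `FloatVectors`, and `ToyPointMap` are hypothetical stand-ins written for this example; they mimic the documented methods (`is_point`, `point`, `shape`, `source`, `target`, `__call__`) but are not crikit's classes.

```python
from typing import Any, Callable, Tuple


class ToySpace:
    """Hypothetical stand-in for crikit.cr.types.Space (not crikit code)."""

    def is_point(self, x: Any) -> bool:
        raise NotImplementedError

    def point(self, *args, **kwargs):
        raise NotImplementedError

    def shape(self) -> Tuple[int, ...]:
        raise NotImplementedError


class FloatVectors(ToySpace):
    """A concrete toy space: tuples of exactly n floats."""

    def __init__(self, n: int):
        self.n = n

    def is_point(self, x: Any) -> bool:
        return (isinstance(x, tuple) and len(x) == self.n
                and all(isinstance(v, float) for v in x))

    def point(self, *args, **kwargs):
        # With no arguments, build the zero vector; otherwise convert the args.
        values = args if args else (0.0,) * self.n
        p = tuple(float(v) for v in values)
        if not self.is_point(p):
            raise ValueError(f"expected {self.n} values, got {len(p)}")
        return p

    def shape(self) -> Tuple[int, ...]:
        return (self.n,)


class ToyPointMap:
    """Hypothetical stand-in for crikit.cr.types.PointMap (not crikit code)."""

    def __init__(self, source_space: ToySpace, target_space: ToySpace,
                 func: Callable[..., Any]):
        self._source = source_space
        self._target = target_space
        self._func = func

    @property
    def source(self) -> ToySpace:
        return self._source

    @property
    def target(self) -> ToySpace:
        return self._target

    def __call__(self, point, **kwargs):
        # Map a point of the source space to a point of the target space.
        if not self._source.is_point(point):
            raise TypeError("point is not in the source space")
        return self._func(point, **kwargs)


# A toy map from 2-vectors to 1-vectors that sums the components.
sum_map = ToyPointMap(FloatVectors(2), FloatVectors(1), lambda p: (p[0] + p[1],))
```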

Space builders

class crikit.cr.space_builders.DirectSum(*spaces)

Bases: crikit.cr.types.Space

This space represents a concatenation of separate spaces.

__getitem__(idx)

Returns the component space at index idx, which also makes it possible to iterate through the spaces.

__len__()

Returns the number of component spaces

class crikit.cr.space_builders.Multiset(space, n)

Bases: crikit.cr.types.Space

This space represents a single space repeated multiple times.

__getitem__(idx)

Returns the component space at index idx, which also makes it possible to iterate through the spaces.

__len__()

Returns the number of component spaces

property space

the base space for the Multiset

crikit.cr.space_builders.enlist(spaces)

Given a space or a tuple of spaces, this returns a DirectSum representing those spaces. If the given space is already a DirectSum, it returns that space. Otherwise, it creates a DirectSum of the space(s) and marks it as enlisted so that it can be delisted later.

Parameters

spaces (Space, tuple[Space], or list[Space]) – the space(s) to enlist

Returns

a single space containing the given space(s)

Return type

DirectSum
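The enlist behaviour described above can be sketched with plain strings standing in for spaces. `SketchDirectSum` and `delist` are hypothetical names used only for this illustration, and storing the "enlisted" mark as a boolean attribute is an assumption of the sketch, not a description of crikit's internals.

```python
class SketchDirectSum:
    """Hypothetical stand-in for DirectSum (not crikit code)."""

    def __init__(self, *spaces):
        self.spaces = tuple(spaces)
        self.enlisted = False  # set to True only by enlist()

    def __getitem__(self, idx):
        return self.spaces[idx]

    def __len__(self):
        return len(self.spaces)


def enlist(spaces):
    # An existing direct sum is returned unchanged.
    if isinstance(spaces, SketchDirectSum):
        return spaces
    # A single space, or a tuple/list of spaces, is wrapped and marked.
    if isinstance(spaces, (tuple, list)):
        result = SketchDirectSum(*spaces)
    else:
        result = SketchDirectSum(spaces)
    result.enlisted = True
    return result


def delist(space):
    # Hypothetical inverse: unwrap only sums that enlist() created.
    if isinstance(space, SketchDirectSum) and space.enlisted:
        return space[0] if len(space) == 1 else space.spaces
    return space
```

Because only sums created by `enlist` carry the mark, a DirectSum that the user built explicitly survives a round trip untouched.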

Point map builders

class crikit.cr.map_builders.Callable(source_space, target_space, callble, bare=False)

Bases: crikit.cr.types.PointMap

This class wraps a Python callable into a PointMap.

I.e., if you (1) have a function or a class/object with a __call__ method, and you (2) know the input space and output space of the callable, then this will put the callable into the Space/PointMap structure.

Set bare to true if the callable requires that the point be unpacked.

A point map that accepts multiple inputs will receive those inputs as a single tuple, so if your callable expects to receive them as separate arguments, you must use bare=True.

>>> from crikit.cr.map_builders import Callable
>>> source = ... # The source and target space don't really matter here.
>>> target = ...
>>> standard_callable = lambda x: x[0] + x[1] + x[2]
>>> c = Callable(source, target, standard_callable)
>>> c((1, 2, 3))
6
>>> bare_callable = lambda a, b, c: a + b + c
>>> c_bare = Callable(source, target, bare_callable, bare=True)
>>> c_bare((1, 2, 3))
6
>>> c_bad = Callable(source, target, bare_callable, bare=False)
>>> c_bad((1, 2, 3))
Traceback (most recent call last):
...
TypeError: <lambda>() missing 2 required positional arguments: 'b' and 'c'
>>> c_bad(1, 2, 3)
Traceback (most recent call last):
...
TypeError: __call__() takes 2 positional arguments but 4 were given

__call__(point, **kwargs)

Calls the callable with the given point, unpacking the point first if it was initialized with bare=True.

property callable

the callable that this point map was initialized with

class crikit.cr.map_builders.Parametric(orig_map: crikit.cr.types.PointMap, param_indices: Union[int, Iterable[int]], param_point: crikit.cr.map_builders.InputPoint, bare: Optional[bool] = False, bare_map: Optional[bool] = False)

Bases: crikit.cr.types.PointMap

This class wraps a PointMap so that some parameters do not need to be specified when the point map is called.

In mathematical notation, this wraps a function \(f(u, p)\) so that it can be called as \(f(u)\) or as \(f(u; p)\) with some default value of p if p is not specified.

Parameters
  • orig_map (PointMap) – The point map to wrap.

  • param_indices (int, tuple[int], or list[int]) – The position of the parameters in the input space of the point map.

  • param_point – The default values to use for the parameters if they are not specified in the __call__() method.

  • bare (bool) – If True, the number of remaining args to the point map (after removing the param_indices) must be 1, and the resulting Parametric map can then be called as pmap(val) instead of pmap((val,)).

  • bare_map (bool) – If true, then the args to the point map will be unpacked before calling the map. Note: We should probably get rid of this since we’re now requiring that PointMaps not accept bare arguments (I.e., all PointMap __call__ functions must accept a single arg, which is either a single value or a tuple of values if necessary).

__call__(point: crikit.cr.map_builders.InputPoint, params: Optional[crikit.cr.map_builders.ParameterPoint] = None, **kwargs) → crikit.cr.map_builders.OutputPoint

Parameters
  • point – point at which to evaluate the point map

  • params – parameter values to use for the point map. If None, then the default param_point will be used.

  • **kwargs – keyword arguments are passed through to the base point map.

Returns

The output of the base point map.

set_param_point(param_point: crikit.cr.map_builders.ParameterPoint) → None

Sets the default value to use for the parameters.
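To make the role of param_indices concrete, here is a minimal sketch of the idea: the wrapper stores default parameter values and splices them back into the full input tuple at the given positions before calling the wrapped map. `ParametricSketch` is a hypothetical helper operating on plain callables; it ignores spaces and bare_map, so treat it as an illustration of the calling convention rather than crikit's implementation.

```python
from typing import Any, Callable, Iterable, Tuple, Union


class ParametricSketch:
    """Hypothetical illustration of Parametric's calling convention."""

    def __init__(self, orig_map: Callable[[Tuple[Any, ...]], Any],
                 param_indices: Union[int, Iterable[int]],
                 param_point: Any, bare: bool = False):
        self.orig_map = orig_map
        self.single_param = isinstance(param_indices, int)
        self.param_indices = ((param_indices,) if self.single_param
                              else tuple(param_indices))
        self.param_point = param_point
        self.bare = bare

    def set_param_point(self, param_point: Any) -> None:
        self.param_point = param_point

    def __call__(self, point, params=None, **kwargs):
        if params is None:
            params = self.param_point  # fall back to the stored default
        if self.single_param:
            params = (params,)
        args = (point,) if self.bare else tuple(point)
        # Rebuild the full argument tuple, inserting parameters at their slots.
        n_total = len(args) + len(self.param_indices)
        remaining = iter(args)
        param_at = dict(zip(self.param_indices, params))
        full = tuple(param_at[i] if i in param_at else next(remaining)
                     for i in range(n_total))
        return self.orig_map(full, **kwargs)


# f((u, p)) = u * p with p at index 1 and default value 2.0.
pmap = ParametricSketch(lambda x: x[0] * x[1], 1, 2.0, bare=True)
```

With bare=True, `pmap(3.0)` evaluates \(f(u)\) using the default p, while `pmap(3.0, params=10.0)` evaluates \(f(u; p)\) with an explicit p.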

class crikit.cr.map_builders.AugmentPointMap(point_map, param_names, param_space, bare=False)

Bases: crikit.cr.types.PointMap

This class wraps a PointMap so that keyword arguments of its __call__ method are moved to the explicit input space of the map.

In mathematical notation, this wraps a function \(f(u; p)\) so that it must be called as \(f(u, p)\).

If the original point map is called as point_map(point), then the augmented point map is called as aug_map((point, params)) (or aug_map((point, *params)) if bare=True).

Parameters
  • point_map (PointMap) – The point map to wrap.

  • param_names (str, tuple[str], or list[str]) – The keywords that will be added to the input space of the point map.

  • param_space (Space, tuple[Space], or list[Space]) – The spaces corresponding to each keyword in the param_names list.

  • bare (bool) – If you set this to true, then the __call__ method expects the params to be unpacked.

__call__(point_params, **kwargs)

Parameters

point_params – the point at which to evaluate the point map and the params to pass as keyword args. It must be in the form (point, params) (or (point, *params) if bare=True was passed in the constructor).

Returns

The output of the base point map.
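A minimal sketch of the augmentation: the wrapper splits the incoming tuple into the original point and the parameter values, then forwards the parameters as keyword arguments named by param_names. `AugmentSketch` and `scaled` are hypothetical; only the (point, params) and (point, *params) calling forms follow the description above. Accepting a single non-tuple value when there is exactly one parameter name is a convenience choice of this sketch.

```python
from typing import Any, Callable, Sequence, Union


class AugmentSketch:
    """Hypothetical illustration of AugmentPointMap (not crikit code)."""

    def __init__(self, point_map: Callable[..., Any],
                 param_names: Union[str, Sequence[str]], bare: bool = False):
        self.point_map = point_map
        self.param_names = ((param_names,) if isinstance(param_names, str)
                            else tuple(param_names))
        self.bare = bare

    def __call__(self, point_params, **kwargs):
        if self.bare:
            # Form (point, p1, p2, ...): params follow the point directly.
            point, *params = point_params
        else:
            # Form (point, params); in this sketch a lone param may be bare.
            point, params = point_params
            if len(self.param_names) == 1:
                params = (params,)
        kwargs.update(zip(self.param_names, params))
        return self.point_map(point, **kwargs)


def scaled(u, scale=1.0, shift=0.0):
    # Original map f(u; scale, shift): parameters arrive as keywords.
    return scale * u + shift


aug = AugmentSketch(scaled, ("scale", "shift"))
aug_bare = AugmentSketch(scaled, ("scale", "shift"), bare=True)
```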

class crikit.cr.map_builders.CompositePointMap(*point_maps)

Bases: crikit.cr.types.PointMap

This class is a point map that chains a group of point maps together.

Given a list of point maps, the source space is the source space of the first point map, and the target space is the target space of the last point map.

The __call__ function gives the inputs to the first map, and then successively feeds the outputs of one map to the next map.

point_maps() → List[crikit.cr.types.PointMap]

Returns a list of the point maps used to create this composite map.

class crikit.cr.map_builders.ParallelPointMap(*point_maps)

Bases: crikit.cr.types.PointMap

This class is a point map that runs a group of point maps independently.

Given a list of point maps, the source space is a DirectSum of the source space from each point map, and similarly for the target space.

The __call__ function gives each argument to the corresponding point map, and concatenates the outputs into a single list.

class crikit.cr.map_builders.IdentityPointMap(source_space, target_space)

Bases: crikit.cr.types.PointMap

This class is a point map that simply returns its input.
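The three builders above differ only in how they wire their component maps together. The following sketch uses plain Python functions in place of PointMaps; `compose`, `parallel`, and `identity` are hypothetical helpers that mirror the described __call__ behaviour and ignore source and target spaces entirely.

```python
from functools import reduce
from typing import Any, Callable, List, Sequence

Map = Callable[[Any], Any]


def compose(*maps: Map) -> Map:
    # Like CompositePointMap: feed each map's output into the next map.
    return lambda point: reduce(lambda value, f: f(value), maps, point)


def parallel(*maps: Map) -> Callable[[Sequence[Any]], List[Any]]:
    # Like ParallelPointMap: the i-th input goes to the i-th map, and the
    # outputs are concatenated into a single list.
    def run(points: Sequence[Any]) -> List[Any]:
        if len(points) != len(maps):
            raise ValueError("need exactly one input per map")
        return [f(p) for f, p in zip(maps, points)]
    return run


def identity(point: Any) -> Any:
    # Like IdentityPointMap: return the input unchanged.
    return point
```

Note that composition order matters: `compose(double, inc)` computes inc(double(x)), matching "the source space is the source space of the first point map".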

CR Implementations

crikit.cr.cr

class crikit.cr.cr.CR(output_type: crikit.invariants.invariants.TensorType, input_types: Sequence[crikit.invariants.invariants.TensorType], cr_function: Optional[Callable] = None, params: Optional[Sequence[Any]] = None, cr_static_argnums: Optional[Sequence[int]] = None, vmap: bool = True, vmap_inner: Optional[bool] = None, nojit: bool = False, strain_energy: bool = False, compiled_jacobian: bool = True, **cr_jax_kwargs)

Bases: crikit.cr.types.PointMap

A Constitutive Relation that automatically generates scalar and form invariants with crikit.invariants. All you need to provide is a function of the scalar invariants whose scalar outputs are multiplied against the form invariants to form an equivariant tensor function. This follows the canonical representation of Wineman and Pipkin, who showed that any tensor function that is equivariant under a physical group can be represented as a linear combination of the form invariants with coefficients that are scalar functions of the scalar invariants. In other words, your function takes in the scalar invariants at a point as a one-dimensional JAX array (as well as any parameters you specify) and returns a one-dimensional JAX array with one element for each form invariant.

You can use cr_function_shape() to determine how many scalar invariants your function will take in and how many scalar values your function will need to output.

__call__(inputs, **kwargs) → Union[pyadjoint_utils.jax_adjoint.array.ndarray, Tuple[pyadjoint_utils.jax_adjoint.array.ndarray]]

Evaluates the CR

Parameters

inputs (Union[Iterable[pyadjoint_utils.jax_adjoint.ndarray,jnp.ndarray]]) – the inputs to the CR, as JAX arrays or pyadjoint_utils.jax_adjoint.ndarrays (if you’re differentiating with Pyadjoint)

Returns

The value of the invariant CR function (self.function) evaluated with the scalar and form-invariants generated by inputs

Return type

Union[ndarray, Tuple[ndarray]]

__init__(output_type: crikit.invariants.invariants.TensorType, input_types: Sequence[crikit.invariants.invariants.TensorType], cr_function: Optional[Callable] = None, params: Optional[Sequence[Any]] = None, cr_static_argnums: Optional[Sequence[int]] = None, vmap: bool = True, vmap_inner: Optional[bool] = None, nojit: bool = False, strain_energy: bool = False, compiled_jacobian: bool = True, **cr_jax_kwargs)

Constructor for CR

Parameters
  • output_type (TensorType) – a TensorType corresponding to the output. If you want a strain-energy CR (one which computes the stress as the partial derivative of a strain energy functional with respect to the first input), pass crikit.invariants.TensorType.make_scalar() as the output type (i.e., a scalar).

  • input_types (Sequence[TensorType]) – a sequence of TensorTypes corresponding to the inputs

  • cr_function (Callable, optional) – The function to evaluate.

  • params (Sequence[jnp.ndarray], optional) – the initial values of the parameters, default None

  • cr_static_argnums (Union[int,Iterable[int]], optional) – the static_argnums parameter for jax.jit() for your cr_function

  • vmap (bool, optional) – should we jax.vmap() the CR invariant functions over the inputs? Set this to True if your CR is going to be given input values at multiple points on a mesh (e.g., if the input is one second-order tensor in 3-d and you plan to evaluate the CR at multiple points at once by stacking the inputs), default True

  • vmap_inner (bool, optional) – should we jax.vmap() the inner function over the inputs? True if your inner CR is going to be given input values at multiple points on a mesh that are handled independently of each other. This defaults to the value of the vmap parameter.

  • nojit (bool, optional) – if True, do NOT jit-compile the CR function, defaults to False

  • strain_energy (bool, optional) – if True, implies that this CR has a strain energy function – that is, cr_function is a scalar function that gives the strain energy as a function of a symmetric second-order input (and possibly other inputs), and the CR computes the stress as the derivative of strain energy with respect to the symmetric second-order input. Defaults to False

  • compiled_jacobian (bool, optional) – if True, compiles the function that stacks the pointwise Jacobians in the backend Jacobian-computing function in JAXBlock. Identical to the compile_jacobian_stack parameter to pyadjoint_utils.jax_adjoint.overload_jax(). Defaults to True

Returns

a CR object

Return type

CR

property cr_input_shape

The shape of the array of scalar invariants that the CR function takes as its first parameter.

property form_invariant_shape

The shape of the array of form invariants

form_invariants(*inputs)

Computes form invariants given inputs

Parameters

*inputs (Iterable[jnp.ndarray]) – the inputs to the CR

Returns

A JAX DeviceArray containing the stacked form-invariants

Return type

jnp.ndarray

static from_arrays(example_output: jax._src.numpy.lax_numpy.ndarray, example_inputs: Iterable[jax._src.numpy.lax_numpy.ndarray], cr_function: Optional[Callable] = None, params: Optional[Iterable[Any]] = None, cr_static_argnums: Optional[Sequence[int]] = None, vmap: bool = True, **kwargs)

The preferred way to construct a crikit.cr.CR if you don’t want to manually construct the crikit.invariants.TensorTypes corresponding to your input and output tensor types. If your material has a structural tensor, make sure you include it in example_inputs. For example, a plank of wood is frequently modeled as being transversely isotropic, with the structural tensor being a vector field pointing in the direction of the grain.

If you want the symmetry group to exclude reflections (that is, a subgroup of the hemitropic group instead of the isotropic group), pass the Levi-Civita tensor (eps_ij or eps_ijk, depending on how many spatial dimensions you’re in) as an example_input, but DO NOT pass it into CR.__call__(). If you pass the Levi-Civita tensor as an example_input, the CR accounts for its presence in the inputs without you passing it in.

Parameters
  • example_output (jnp.ndarray) – an example of what the output of the CR might look like; if that’s a symmetric rank-two tensor, then example_output should also be that (e.g. jnp.eye(number_of_spatial_dimensions)), etc.

  • example_inputs (Sequence[Array]) – an iterable of JAX arrays of the same shape and symmetry as the inputs to the CR function

  • params (Iterable[jnp.ndarray], optional) – the initial values of the parameters, default None

  • cr_static_argnums (Union[int,Iterable[int]], optional) – the static_argnums parameter for jax.jit() for your cr_function

  • vmap (bool, optional) – should we jax.vmap() the CR function over the inputs? Set this to True if your CR is going to be given input values at multiple points on a mesh (e.g., if the input is one second-order tensor in 3-d and you plan to evaluate the CR at multiple points at once by stacking the inputs), default True

Returns

A crikit.cr.CR

Return type

CR

get_point_maps()

This method returns a PointMap for each of the four functions used to compute the CR output.

The CR.__call__() method takes inputs (and optionally params as a keyword arg) and uses four separate functions to compute the CR output:

  1. The scalar invariant function computes the scalar invariants as a function of inputs.

  2. The form invariant function computes the form-invariant basis as a function of inputs.

  3. The inner function computes the basis coefficients as a function of the scalar invariants and params.

  4. The coefficient form function computes the CR output using the basis coefficients and the form-invariant basis.
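As a hand-worked instance of the four stages, consider an isotropic function of a single symmetric 3x3 tensor \(A\). The classical representation theorem gives scalar invariants \(\mathrm{tr}\,A\), \(\mathrm{tr}\,A^2\), \(\mathrm{tr}\,A^3\) and form invariants \(I\), \(A\), \(A^2\), so the output has the form \(\phi_0 I + \phi_1 A + \phi_2 A^2\). The sketch below codes each stage as a separate pure-Python function on nested lists. The function names and the particular coefficients chosen in `inner` are illustrative assumptions; crikit generates its own invariant sets, which may be chosen or ordered differently.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def trace(a: Matrix) -> float:
    return a[0][0] + a[1][1] + a[2][2]


def identity3() -> Matrix:
    return [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]


def scalar_invariants(a: Matrix) -> List[float]:
    # Stage 1: tr A, tr A^2, tr A^3.
    a2 = matmul(a, a)
    return [trace(a), trace(a2), trace(matmul(a2, a))]


def form_invariants(a: Matrix) -> List[Matrix]:
    # Stage 2: the form-invariant basis I, A, A^2.
    return [identity3(), a, matmul(a, a)]


def inner(invariants: Sequence[float], params: Sequence[float]) -> List[float]:
    # Stage 3: one coefficient per form invariant. Arbitrary example choice:
    # phi_0 = lam * tr A, phi_1 = 2 * mu, phi_2 = 0.
    lam, mu = params
    return [lam * invariants[0], 2.0 * mu, 0.0]


def coefficient_form(coeffs: Sequence[float], basis: Sequence[Matrix]) -> Matrix:
    # Stage 4: the output is sum_k coeffs[k] * basis[k].
    return [[sum(c * b[i][j] for c, b in zip(coeffs, basis))
             for j in range(3)] for i in range(3)]


def evaluate_cr(a: Matrix, params: Sequence[float]) -> Matrix:
    coeffs = inner(scalar_invariants(a), params)
    return coefficient_form(coeffs, form_invariants(a))
```

With these particular coefficients the output reduces to \(\lambda\,\mathrm{tr}(A)\,I + 2\mu A\), which is the isotropic linear-elastic law when \(A\) is a small-strain tensor.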

invariant_descriptions(ipython: Optional[bool] = None, html: Optional[bool] = None) → str

A string describing both the scalar and form invariant functions, including their indices in the input/output of the CR.

Parameters
  • ipython (bool, optional) – Are you in IPython mode? (e.g. in a Jupyter notebook) By default, tries to guess whether or not you are in IPython mode; set this manually if the behavior is not as desired.

  • html – Return an HTML string instead of a plain-text string? Defaults to None, or to True if ipython is True.

Returns

A string describing the invariants

Return type

str

property num_scalar_functions

The number of scalar functions we need to make (each taking in the scalar invariants) so that the row vector of their values can be right-multiplied against the form invariants. For example, in 3-d, an O(3)-invariant function of a symmetric rank-two tensor and a vector that outputs a symmetric rank-two tensor has a _form_invt_shape of (6,3,3), so we need 6 scalar functions to build the row vector that yields a result of shape (3,3).

save_model(directory)

Save the internal function (the one you can pass as cr_function in CR.__init__()) of a JAX-based CR to a directory by converting it to a TensorFlow model and then saving that. You can recover the model by using CR.load_tensorflow_model().

Parameters

directory (str) – The directory name to save it to

Returns

None

scalar_invariants(*inputs) → jax._src.numpy.lax_numpy.ndarray

Computes scalar invariants given inputs

Parameters

*inputs (Iterable[jnp.ndarray]) – the inputs to the CR

Returns

A JAX DeviceArray containing the scalar invariants

Return type

jnp.ndarray

crikit.cr.cr.cr_function_shape(output: Union[Any, crikit.invariants.invariants.TensorType], inputs: Union[Sequence[crikit.invariants.invariants.TensorType], Sequence[Any]]) → Tuple[int, int]

Computes the number of scalar invariants that a CR function for given inputs and outputs must take, as well as the number of scalar values that function must output to generate an invariant CR, and returns a tuple of (num_scalar_invariants,num_output_scalar_values).

Parameters
  • output (Union[Array,TensorType]) – either an array (Numpy or JAX) or a TensorType representing the correct shape and symmetry of an output tensor from this CR.

  • inputs (Union[Sequence[TensorType], Sequence[Array]]) – an Iterable of either TensorType instances or arrays of the correct shape and symmetry as the input tensors of this CR; must contain the same type as output (i.e., if output is a TensorType, inputs must contain only TensorTypes, and likewise if output is an array, inputs must contain only arrays).

Returns

A tuple of (number of scalar invariants, number of output scalar values)

Return type

tuple

crikit.cr.cr.save_jax_cr(cr: crikit.cr.cr.CR, directory: str)

Save a JAX-based CR to a directory by converting it to a TensorFlow model and then saving that. You can recover the model by using CR.load_tensorflow_model().

Parameters
  • cr (CR) – The CR to save

  • directory (str) – The directory name to save it to

Returns

None

class crikit.cr.cr.RivlinModel(C, D=None, spatial_dims=3, vmap=True, optimize_d=False)

Bases: crikit.cr.cr.CR

A CR that represents a Rivlin model – that is, one of the form \(W = \sum\limits_{i=0}^n\sum\limits_{j=0}^n C_{ij} (I_1 - 3)^i (I_2 - 3)^j + \sum\limits_{k=1}^m D_k (J - 1)^{2k}\), where \(J = \mathrm{det}(B)\).

__init__(C, D=None, spatial_dims=3, vmap=True, optimize_d=False)

Parameters
  • C (ndarray) – The material constants \(C_{ij}\)

  • D (ndarray, optional) – The material constants \(D_k\), defaults to None

  • spatial_dims (int, optional) – how many spatial dimensions? defaults to 3

  • vmap (bool, optional) – the vmap parameter of CR.__init__() , defaults to True

  • optimize_d (bool, optional) – Controls which parameter we’re optimizing the CR with respect to. If True, optimize D, else optimize C. Defaults to False.

Returns

a RivlinModel

Return type

RivlinModel
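The Rivlin strain-energy formula can be evaluated directly once the invariants are known. This sketch takes \(I_1\), \(I_2\), and \(J\) as given numbers, so it sidesteps how RivlinModel computes them from its inputs. `rivlin_energy` is a hypothetical helper: `C[i][j]` holds \(C_{ij}\) for \(i, j = 0, \dots, n\) (an \((n+1) \times (n+1)\) array), and `D[k-1]` holds \(D_k\) for \(k = 1, \dots, m\).

```python
from typing import Optional, Sequence


def rivlin_energy(C: Sequence[Sequence[float]], I1: float, I2: float, J: float,
                  D: Optional[Sequence[float]] = None) -> float:
    # Hypothetical helper, not part of crikit.
    # Double sum of C_ij (I1 - 3)^i (I2 - 3)^j for i, j = 0..n.
    W = sum(C[i][j] * (I1 - 3.0) ** i * (I2 - 3.0) ** j
            for i in range(len(C)) for j in range(len(C[i])))
    # Volumetric part: sum of D_k (J - 1)^(2k) for k = 1..m.
    if D is not None:
        W += sum(D_k * (J - 1.0) ** (2 * k) for k, D_k in enumerate(D, start=1))
    return W
```

For instance, C = [[0, C01], [C10, 0]] gives the Mooney-Rivlin form \(C_{10}(I_1 - 3) + C_{01}(I_2 - 3)\).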

class crikit.cr.jax_utils.JAXArrays(shape, dtype=None)

Bases: crikit.cr.types.Space

A Space of JAX arrays; the JAX counterpart of Ndarrays (see crikit/cr/numpy.py).

Parameters
  • shape (Iterable) – the shape of the arrays

  • dtype (jax.numpy.dtype) – the data type (default None)

class crikit.cr.jax_utils.JAX_To_UFLFunctionSpace(source, target, quad_space=None, quad_params=None, make_block=True)

Bases: crikit.cr.types.PointMap

class crikit.cr.jax_utils.UFLExprSpace_To_JAX(source, target=None, quad_space=None, quad_params=None, domain=None)

Bases: crikit.cr.types.PointMap

Much like UFLExprSpace_To_Numpy, this class maps a UFL expression into a JAX array. The constructor inputs are the same as for UFLExprSpace_To_Numpy, with Ndarrays replaced by JAXArrays.

Parameters
  • source (UFLExprSpace or UFLFunctionSpace) – the UFL space to use as input.

  • target (JAXArrays, optional) – the target space to map to.

  • quad_space (optional) – the finite-element quadrature space to interpolate to.

  • quad_params (dict, optional) – parameters for the quadrature space.

  • domain (optional) – the UFL domain for the quadrature space.

class crikit.cr.jax_utils.JAX_UFLFunctionSpace_Covering(base_space, covering_space=None, domain=None, quad_params=None, **kwargs)

Bases: crikit.covering.covering.Covering

class crikit.cr.jax_utils.ReducedFunctionJAX(rf)

Bases: pyadjoint_utils.reduced_function_numpy.ReducedFunctionNumPy

This class implements the ReducedFunction for a given ReducedFunction and controls with JAX data structures. Like a ReducedFunctionNumPy, a ReducedFunctionJAX is created from a ReducedFunction instance:

>>> from pyadjoint_utils import overload_jax
>>> f = overload_jax(lambda x: np.sum(x ** 2))
>>> x = array(np.array([1.0,2.0]))
>>> rf = ReducedFunction(f(x),Control(x))
>>> rf_jax = ReducedFunctionJAX(rf)
>>> float(rf_jax(x))
5.0

crikit.cr.numpy

class crikit.cr.numpy.Ndarrays(shape, dtype=None)

Bases: crikit.cr.types.Space

This class represents a Space of NumPy arrays of a given shape and optionally a specific data type.

Negative numbers can be used in the shape to indicate that the length of that dimension doesn’t matter as long as that dimension exists. For example, if the given shape is (-1, 2, 5), then arrays with shapes (4, 2, 5) and (1, 2, 5) are both points in the space, but (2, 5) is not.

>>> import numpy as np
>>> from crikit.cr.numpy import Ndarrays
>>> space = Ndarrays((-1, 2, 5))
>>> space.is_point(np.zeros((4, 2, 5)))
True
>>> space.is_point(np.zeros((1, 2, 5)))
True
>>> space.is_point(np.zeros((2, 5)))
False

Parameters
  • shape (tuple or list) – The shape of the arrays.

  • dtype (numpy.dtype) – A NumPy datatype that further constrains the Space. If None, the space does not restrict the dtype of its arrays.
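The shape rule for negative entries can be expressed as a small predicate. `shape_matches` is a hypothetical helper that illustrates the documented behaviour (same number of dimensions, and every non-negative entry must match exactly); it is not the code behind Ndarrays.is_point, which also checks that the point is an ndarray and compares dtypes.

```python
from typing import Sequence


def shape_matches(shape: Sequence[int], pattern: Sequence[int]) -> bool:
    # Hypothetical helper, not part of crikit.
    # The number of dimensions must agree exactly.
    if len(shape) != len(pattern):
        return False
    # Negative pattern entries accept any length; others must be equal.
    return all(p < 0 or s == p for s, p in zip(shape, pattern))
```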

is_point(point)

Returns True if the given point is an ndarray and its shape and dtype match those of the space.

class crikit.cr.numpy.CR_P_LaplacianNumpy(p=2, dim=2, input_u=True)

Bases: crikit.cr.types.PointMap

crikit.cr.autograd

class crikit.cr.autograd.AutogradPointMap(source, target, func, bare=False, pointwise=True)

Bases: crikit.cr.types.PointMap

crikit.cr.autograd.point_map(source_tuple, target_tuple, **kwargs)

This decorator turns the decorated function into an AutogradPointMap.

The given tuples are used to create Ndarrays spaces to set as the source and target spaces.

crikit.cr.ufl

class crikit.cr.ufl.UFLFunctionSpace(functionspace)

Bases: crikit.cr.types.Space

Represents a UFL FunctionSpace.

__init__(functionspace)

Parameters

functionspace (ufl.FunctionSpace) – the function space to wrap.

class crikit.cr.ufl.UFLExprSpace(expr, tlm_shape=None, ufl_domains=None)

Bases: crikit.cr.types.Space

Represents a space of UFL expressions of a certain shape, defined by an example expression. Any UFL expression of the same shape lies in this space.

__init__(expr, tlm_shape=None, ufl_domains=None)

Parameters
  • expr (ufl.core.expr.Expr) – a UFL expression defining the space

  • tlm_shape (tuple, optional) – used to explicitly specify the shape of the space instead of getting it from the expression.

point(domain_idx=None)

Returns the expression that this space was initialized with.

class crikit.cr.ufl.CR_UFL_Expr(arg_space, exprs, pos_map, domain=None)

Bases: crikit.cr.types.PointMap

__init__(arg_space, exprs, pos_map, domain=None)

Takes an expression for each component of the CR.

Also requires a map describing how the expression coefficients map to the CR arguments.

crikit.cr.ufl.point_map(source_tuple, bare=False, domain=None, **kwargs)

This decorator turns the decorated function into a CR_UFL_Expr.

The given tuples are used to create UFLExprSpace spaces to set as the source and target spaces.

The decorated function will be run once to create the UFL expression for the point map.

class crikit.cr.ufl.CR_P_Laplacian(p=2, dim=2, input_u=True, domain=None)

Bases: crikit.cr.ufl.CR_UFL_Expr

This CR is applied directly to UFL expressions.

crikit.cr.stdnumeric

crikit.cr.stdnumeric.ZZ = <crikit.cr.stdnumeric.Integers object>

Represents the space of integers

crikit.cr.stdnumeric.RR = <crikit.cr.stdnumeric.Reals object>

Represents the space of real numbers

crikit.cr.stdnumeric.CC = <crikit.cr.stdnumeric.Complexs object>

Represents the space of complex numbers

crikit.cr.stdnumeric.type_tuple_to_space(tt)

Converts the given numeric type or tuple of types to a Space using the stdnumeric spaces and DirectSum.

If tt is a tuple, type_tuple_to_space() is recursively called on each element of the tuple, and the DirectSum of the result is returned.

If tt is a numeric type, then the corresponding stdnumeric Space is returned. If tt is a subclass of a numeric type, the name of the returned space will contain the name of that class.

Supported numeric types are int, float, complex, or subclasses of those types.

Parameters

tt – a numeric type or arbitrarily nested tuples of numeric types.

Returns

the stdnumeric Space corresponding to the given types.

Return type

Space
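The recursion described above can be sketched with strings standing in for spaces: "ZZ", "RR", and "CC" for the stdnumeric spaces, and a Python tuple for a DirectSum. `type_tuple_to_names` is a hypothetical helper and omits the subclass-naming detail. Note that bool is a subclass of int in Python, so this sketch maps it to "ZZ".

```python
from typing import Any

# Hypothetical sketch, not crikit code: built-in numeric type -> space name.
_BASE_NAMES = ((complex, "CC"), (float, "RR"), (int, "ZZ"))


def type_tuple_to_names(tt: Any) -> Any:
    if isinstance(tt, tuple):
        # Tuples become (nested) direct sums, modelled here as tuples.
        return tuple(type_tuple_to_names(t) for t in tt)
    if isinstance(tt, type):
        # Subclasses of a numeric type map to that type's space.
        for base, name in _BASE_NAMES:
            if issubclass(tt, base):
                return name
    raise TypeError(f"unsupported type: {tt!r}")
```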

crikit.cr.stdnumeric.point_map(source_types, target_types, **kwargs)

Decorates a function to make it a point map (by constructing a Callable instance).

Parameters
  • source_types – a type tuple representing the source space.

  • target_types – a type tuple representing the target space.

  • **kwargs – passed through to the Callable constructor.

Here’s example usage for calculating the p-norm of a two-dimensional vector:

from crikit.cr.stdnumeric import point_map

@point_map(((float, float), float), float, bare=True)
def pnorm_2d(v, p):
    return (v[0]**p + v[1]**p) ** (1/p)
assert pnorm_2d(((1, 2), 1)) == 3

from crikit.cr.types import PointMap
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR

assert isinstance(pnorm_2d, PointMap)
assert pnorm_2d.source == DirectSum(DirectSum(RR, RR), RR)
assert pnorm_2d.target == RR

Covering

class crikit.covering.covering.Covering(base_space, covering_space, **covering_params)

Base class for Covering types.

To be compatible with the covering params structure, the constructor should accept all keyword arguments and ignore the ones that it doesn’t need.

covering_map(**params) → crikit.cr.types.PointMap

This method must be overridden.

Should return a map from the covering space to the base space.

Parameters

**params – any parameters that should be used for creating the map.

Returns

A map from the covering space to the base space.

Return type

PointMap

section_map(**params) → crikit.cr.types.PointMap

This method must be overridden.

Should return a map from the base space to the covering space.

Parameters

**params – any parameters that should be used for creating the map.

Returns

A map from the base space to the covering space.

Return type

PointMap
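A covering pairs two maps: the covering map goes from the covering space to the base space, and the section map goes back the other way, so (in the mathematical sense of a section) applying the section and then the covering map recovers the original point. The toy sketch below uses Python floats as the base space and (mantissa, exponent) pairs from math.frexp as the covering space. `FloatMantissaCovering` is hypothetical, and its maps are plain functions rather than PointMaps.

```python
import math
from typing import Callable, Tuple


class FloatMantissaCovering:
    """Hypothetical toy covering, not a crikit class."""

    def __init__(self, base_space=None, covering_space=None, **covering_params):
        # Accept and ignore unneeded covering params, as Covering requires.
        self.base_space = base_space
        self.covering_space = covering_space

    def covering_map(self, **params) -> Callable[[Tuple[float, int]], float]:
        # Covering space -> base space: rebuild the float exactly.
        return lambda pair: math.ldexp(pair[0], pair[1])

    def section_map(self, **params) -> Callable[[float], Tuple[float, int]]:
        # Base space -> covering space: split into mantissa and exponent.
        return math.frexp
```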

crikit.covering.covering.get_default_covering_params() → dict

Returns

a reference to the current covering params dictionary

Return type

dict

crikit.covering.covering.set_default_covering_params(*args, **kwargs) → None

Update the covering params dictionary.

Parameters
  • *args – dictionaries with key-value pairs to add to the covering params dictionary.

  • **kwargs – key-value pairs to add to the covering params dictionary.

crikit.covering.covering.reset_default_covering_params() → None

Resets the default covering params to an empty dictionary.

crikit.covering.covering.register_covering(base_space: crikit.cr.types.Space, covering_space: crikit.cr.types.Space, covering_class: Optional[crikit.covering.covering.Covering] = None) → crikit.covering.covering.Covering

Register a covering class for use in get_map()

The covering_class should be defined with a covering_map method that returns a point map mapping from covering_space to base_space, and a section_map method that returns a point map from base_space to covering_space.

This function can be used as a class decorator, in which case the covering_class doesn’t need to be specified.

Parameters
  • base_space (type) – a Space subclass.

  • covering_space (type) – a Space subclass.

  • covering_class (type) – the Covering subclass to register as handling (base_space, covering_space) mappings.

Returns

returns covering_class such that it can be used as a decorator.

Return type

type

crikit.covering.covering.get_map(source: crikit.cr.types.Space, target: crikit.cr.types.Space, **covering_params) → crikit.cr.types.PointMap

Creates a map from source to target using the Covering registry. If a mapping from source to target is not found in the registry, it raises an exception.

If source and target are DirectSums of multiple spaces, they must each have the same number of spaces. In that case, the i-th subspace in the source is mapped individually to the i-th subspace in the target space.

Any additional kwargs are passed to the covering constructor.

Parameters
  • source (Space) – The source space.

  • target (Space) – The target space.

  • **covering_params – Parameters to pass to the covering class constructor.

Returns

A map from the source space to the target space.

Return type

PointMap
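The registry lookup and the componentwise rule for DirectSums can be sketched with a dictionary keyed by (base type, covering type). Everything here (`_REGISTRY`, `register`, `Sum`, the toy spaces, `get_map_sketch`) is hypothetical. In particular, the sketch assumes that one registered pair serves both directions, using the section map when the source is the base space and the covering map when the source is the covering space; that is an assumption of this sketch, not a documented rule.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: (base space type, covering space type) -> class.
_REGISTRY: Dict[Tuple[type, type], type] = {}


def register(base_space: type, covering_space: type) -> Callable[[type], type]:
    # Returns a class decorator, similar in spirit to register_covering.
    def decorator(covering_class: type) -> type:
        _REGISTRY[(base_space, covering_space)] = covering_class
        return covering_class
    return decorator


class Sum:
    """Stand-in for DirectSum: an ordered collection of spaces."""

    def __init__(self, *spaces: Any):
        self.spaces = spaces


class Reals:
    pass


class Strings:
    pass


@register(Reals, Strings)
class RealsStringsCovering:
    def __init__(self, base_space, covering_space, **covering_params):
        self.base_space = base_space
        self.covering_space = covering_space

    def covering_map(self, **params) -> Callable[[str], float]:
        return float  # covering space (strings) -> base space (reals)

    def section_map(self, **params) -> Callable[[float], str]:
        return repr   # base space (reals) -> covering space (strings)


def get_map_sketch(source: Any, target: Any, **covering_params) -> Callable:
    # Direct sums are mapped component by component.
    if isinstance(source, Sum) and isinstance(target, Sum):
        if len(source.spaces) != len(target.spaces):
            raise ValueError("direct sums must have the same number of spaces")
        maps = [get_map_sketch(s, t, **covering_params)
                for s, t in zip(source.spaces, target.spaces)]
        return lambda points: [m(p) for m, p in zip(maps, points)]
    forward = (type(source), type(target))
    backward = (type(target), type(source))
    if forward in _REGISTRY:  # source is the base space
        return _REGISTRY[forward](source, target, **covering_params).section_map()
    if backward in _REGISTRY:  # source is the covering space
        return _REGISTRY[backward](target, source, **covering_params).covering_map()
    raise KeyError(f"no covering registered between {forward[0].__name__} "
                   f"and {forward[1].__name__}")
```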

crikit.covering.covering.get_composite_cr(*args, **covering_params) → crikit.cr.types.PointMap

Creates a map that composes the given PointMaps and/or Spaces by using covering maps to convert between Spaces.

Any kwargs are passed to get_map().

For example, get_composite_cr(space1, cr1, cr2, space2) returns a CompositePointMap that does the following:

  • takes input in space1,

  • converts it to the input space of cr1,

  • applies cr1,

  • converts the output to the input space of cr2,

  • applies cr2,

  • and then converts the output to space2.

The conversion point maps are created using the get_map() function.

Parameters
  • *args (Space or PointMap) – the PointMaps that should be applied and any desired spaces to convert to.

  • **covering_params – Parameters to pass to the covering class constructor.

Returns

A map from the source space to the target space.

Return type

PointMap

Covering Implementations

crikit.covering.ufl

class crikit.covering.ufl.Numpy_UFLFunctionSpace_Covering(base_space, covering_space=None, domain=None, quad_params=None, **covering_params)

Bases: crikit.covering.covering.Covering

class crikit.covering.ufl.UFLFunctionSpace_UFLExpr_Covering(base_space, covering_space, **covering_params)

Bases: crikit.covering.covering.Covering

class crikit.covering.ufl.UFLFunctionSpace_UFLFunctionSpace_Covering(base_space, covering_space, **covering_params)

Bases: crikit.covering.covering.Covering

class crikit.covering.ufl.Numpy_To_UFLFunctionSpace(source, target, quad_space=None, quad_params=None)

Bases: crikit.cr.types.PointMap

This class is a point map that maps a NumPy array to a UFL function space.

The constructor creates a quadrature space for the target space, using the given quadrature parameters. The __call__ method sticks the given array into the quadrature space and projects it to the target function space.

class crikit.covering.ufl.UFLExprSpace_To_Numpy(source, target=None, quad_space=None, quad_params=None, domain=None)[source]

Bases: crikit.cr.types.PointMap

This class is a point map that maps an expression to a NumPy array.

The constructor creates a quadrature space for the source space, using the given quadrature parameters. The __call__ method projects the input into the quadrature space and extracts the values as a NumPy array.

Parameters
  • source (UFLExprSpace or UFLFunctionSpace) – the UFL space to use as input.

  • target (Ndarrays, optional) – the target space to map to.

  • quad_space (optional) – the finite-element quadrature space to interpolate to.

  • quad_params (dict, optional) – parameters for the quadrature space.

  • domain (optional) – the UFL domain for the quadrature space.

Invariants

class crikit.invariants.TensorType(order, shape, symmetric, antisymmetric, name)[source]

Bases: tuple

property antisymmetric

Alias for field number 3

static from_array(X, symmetric=False, antisymmetric=False, name: str = '')[source]

Creates a TensorType representing a particular array

Parameters
  • X – The array

  • symmetric (bool, optional) – whether the array is symmetric, defaults to False

  • antisymmetric (bool, optional) – whether the array is antisymmetric, defaults to False

  • name (str, optional) – The name of the tensor, defaults to ‘’

Returns

A TensorType representing X

Return type

TensorType

get_array_like() jax._src.numpy.lax_numpy.ndarray[source]

Constructs an example array with the right shape and symmetry

Returns

a tensor with the right shape and symmetry

Return type

jnp.ndarray

get_symmetrizer()[source]

Returns a function that takes in a tensor of order self.order and makes it symmetric.

Returns

The symmetrizer for a tensor of order self.order

Return type

function
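One common way to build such a symmetrizer is to average the tensor over all permutations of its indices. The NumPy sketch below uses a hypothetical symmetrize function; it is not necessarily how get_symmetrizer() is implemented, and CRIKit's normalization may differ (compare symm(), which returns a doubled result).

```python
import itertools

import numpy as np


def symmetrize(tensor: np.ndarray) -> np.ndarray:
    """Average a tensor over all permutations of its axes."""
    perms = list(itertools.permutations(range(tensor.ndim)))
    total = sum(np.transpose(tensor, perm) for perm in perms)
    return total / len(perms)


A = np.arange(9.0).reshape(3, 3)
S = symmetrize(A)
print(np.allclose(S, S.T))            # True
print(np.allclose(S, (A + A.T) / 2))  # True: the familiar order-2 case
```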

static make_antisymmetric(order: int, spatial_dims: int, name: str = '')[source]

Returns a TensorType representing an antisymmetric tensor

Parameters
  • order (int) – The order of the tensor

  • spatial_dims (int) – The number of spatial dimensions

  • name (str, optional) – The name of the tensor, defaults to ‘’

Returns

A TensorType representing an antisymmetric tensor of the given order in spatial_dims spatial dimensions.

Return type

TensorType

static make_scalar(name: str = '')[source]

Returns a TensorType representing a scalar

Parameters

name (str, optional) – The name of the scalar, defaults to ‘’

Returns

A TensorType representing a scalar

Return type

TensorType

static make_symmetric(order: int, spatial_dims: int, name: str = '')[source]

Returns a TensorType representing a symmetric tensor

Parameters
  • order (int) – The order of the tensor

  • spatial_dims (int) – The number of spatial dimensions

  • name (str, optional) – The name of the tensor, defaults to ‘’

Returns

A TensorType representing a symmetric tensor of the given order in spatial_dims spatial dimensions.

Return type

TensorType

static make_vector(spatial_dims: int, name: str = '')[source]

Returns a TensorType representing a vector

Parameters
  • spatial_dims (int) – The number of spatial dimensions

  • name (str, optional) – The name of the vector, defaults to ‘’

Returns

A TensorType representing a vector

Return type

TensorType

property name

Alias for field number 4

property order

Alias for field number 0

property shape

Alias for field number 1

property symmetric

Alias for field number 2

tensor_space_dimension() int[source]

Returns the dimension (as a vector space) of the tensor space containing tensors of this shape.

Returns

dimension of the tensor space containing this TensorType

Return type

int
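For reference, the standard dimension counts for order-\(k\) tensors in \(d\) dimensions are \(d^k\) in general, \(\binom{d+k-1}{k}\) for fully symmetric tensors, and \(\binom{d}{k}\) for fully antisymmetric ones. The sketch below assumes that "symmetric" and "antisymmetric" mean full (anti)symmetry in all indices; tensor_space_dimension_sketch is a hypothetical helper, not part of CRIKit.

```python
from math import comb


def tensor_space_dimension_sketch(order: int, spatial_dims: int,
                                  symmetric: bool = False,
                                  antisymmetric: bool = False) -> int:
    """Dimension of the space of order-`order` tensors in `spatial_dims` dims."""
    if symmetric and antisymmetric and order >= 2:
        return 0  # only the zero tensor is both symmetric and antisymmetric
    if symmetric:
        return comb(spatial_dims + order - 1, order)
    if antisymmetric:
        return comb(spatial_dims, order)
    return spatial_dims ** order


print(tensor_space_dimension_sketch(2, 3))                      # 9
print(tensor_space_dimension_sketch(2, 3, symmetric=True))      # 6
print(tensor_space_dimension_sketch(2, 3, antisymmetric=True))  # 3
```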

zeros_like() jax._src.numpy.lax_numpy.ndarray[source]

Returns an array of zeros with shape self.shape.

Returns

jnp.zeros(self.shape)

Return type

jnp.ndarray

class crikit.invariants.LeviCivitaType(order)[source]

Bases: crikit.invariants.invariants.TensorType

A class that represents the Levi-Civita tensor

crikit.invariants.levi_civita(n)[source]

Returns the Levi-Civita pseudotensor in n dimensions.

Parameters

n (int) – the number of dimensions

Returns

The Levi-Civita pseudotensor in n spatial dimensions

Return type

np.ndarray
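The Levi-Civita symbol assigns +1 to even permutations of (0, ..., n-1), -1 to odd permutations, and 0 whenever an index repeats. A minimal NumPy sketch of that definition follows; levi_civita_sketch is a hypothetical name and the code is not CRIKit's implementation.

```python
import itertools

import numpy as np


def levi_civita_sketch(n: int) -> np.ndarray:
    """Build the n-dimensional Levi-Civita symbol from permutation parities."""
    eps = np.zeros((n,) * n)
    for perm in itertools.permutations(range(n)):
        # The sign of a permutation is (-1)**(number of inversions).
        inversions = sum(1 for i in range(n) for j in range(i + 1, n)
                         if perm[i] > perm[j])
        eps[perm] = -1.0 if inversions % 2 else 1.0
    return eps


eps3 = levi_civita_sketch(3)
print(eps3[0, 1, 2], eps3[1, 0, 2], eps3[0, 0, 1])  # 1.0 -1.0 0.0
```

The loop visits all n! permutations, which is fine for the small n (2 or 3) used in continuum mechanics.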

crikit.invariants.type_from_array(X, rtol: float = 1e-05, name: str = '') crikit.invariants.invariants.TensorType[source]

Like TensorType.from_array(), but tries to detect whether the array is symmetric, antisymmetric, or both.

Parameters
  • X (Union[jnp.ndarray]) – An array (JAX or numpy)

  • rtol (float) – The relative tolerance to use when determining if X is symmetric, antisymmetric, or both, defaults to 1.0e-5

  • name (str, optional) – The name of the tensor, defaults to ‘’

Returns

An appropriate TensorType instance

Return type

TensorType
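As an illustration of the kind of check involved, the sketch below tests a square matrix for symmetry and antisymmetry with NumPy's allclose. The helper name detect_symmetry and the choice atol=0 (a purely relative test) are assumptions; CRIKit's exact tolerance handling, and its treatment of higher-order tensors, may differ.

```python
import numpy as np


def detect_symmetry(X, rtol: float = 1e-5):
    """Return (is_symmetric, is_antisymmetric) for a square 2-d array."""
    X = np.asarray(X, dtype=float)
    if X.ndim != 2 or X.shape[0] != X.shape[1]:
        return False, False
    # atol=0.0 makes both comparisons purely relative (an assumption here).
    symmetric = np.allclose(X, X.T, rtol=rtol, atol=0.0)
    antisymmetric = np.allclose(X, -X.T, rtol=rtol, atol=0.0)
    return bool(symmetric), bool(antisymmetric)


print(detect_symmetry([[1.0, 2.0], [2.0, 3.0]]))   # (True, False)
print(detect_symmetry([[0.0, 1.0], [-1.0, 0.0]]))  # (False, True)
print(detect_symmetry(np.zeros((2, 2))))           # (True, True)
```

The zero matrix is reported as both symmetric and antisymmetric, which is why the rtol parameter above refers to detecting "symmetric, antisymmetric, or both".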

class crikit.invariants.InvariantInfo(spatial_dims: int, input_types: Tuple[crikit.invariants.invariants.TensorType, ...], output_type: crikit.invariants.invariants.TensorType)[source]

Bases: tuple

A class that contains relevant information for computing invariants. For example, for a hemitropic CR in 3 spatial dimensions that takes a symmetric and an antisymmetric second-order tensor as inputs and outputs a symmetric second-order tensor:

from crikit.invariants import InvariantInfo, LeviCivitaType, TensorType

# inputs: a symmetric tensor, an antisymmetric tensor, and the Levi-Civita tensor
info = InvariantInfo(3,
                     (TensorType.make_symmetric(2, 3),
                      TensorType.make_antisymmetric(2, 3),
                      LeviCivitaType(3)
                     ),
                     TensorType.make_symmetric(2, 3)
                    )

static from_arrays(output_example, *args, **kwargs)[source]

Constructs an InvariantInfo from arrays representing the output and inputs

Parameters
  • output_example (Union[jnp.ndarray]) – an array of the correct shape and symmetry/antisymmetry of the desired output

  • args (Iterable[Union[jnp.ndarray]]) – an example of each of the input tensors

  • rtol (float) – the relative tolerance for detecting symmetry/antisymmetry and the Levi-Civita symbol, defaults to 1.0e-5

Returns

an InvariantInfo with the correct spatial dims (inferred from the first argument) and correct input_types for your inputs and output

Return type

InvariantInfo

get_group_symbol(sanitize_input_types: bool = False)[source]

Returns a symbol representing the group this instance represents.

Parameters

sanitize_input_types (bool, optional) – if True, this function will also return the input types without the Levi-Civita symbol, if it exists, default False

Returns

a string whose value is one of \(O(2)\), \(SO(2)\), \(O(3)\), or \(SO(3)\)

Return type

str

property input_types

Alias for field number 1

property output_type

Alias for field number 2

property spatial_dims

Alias for field number 0

crikit.invariants.get_invariant_functions(info: crikit.invariants.invariants.InvariantInfo, suppress_warning_print: Optional[bool] = False, fail_on_warning: Optional[bool] = False)[source]

This function builds two functions, one to compute the scalar invariants, and one to compute the form invariants.

Parameters
  • info (InvariantInfo) – an InvariantInfo instance

  • suppress_warning_print (bool, optional) – if True, don’t print out warnings (this typically would be used if you get a warning about scalar or form-invariants not being available for a specific subset of the input types, and you know that this isn’t a problem, e.g. because no such invariants exist for that subset), defaults to False

  • fail_on_warning (bool, optional) – if True, warnings become exceptions. This is useful if you know that you should not get a warning for your inputs and want to make sure that nothing changes in a way that breaks that assumption. Defaults to False

Returns

a tuple of two functions, the first of which generates the input scalar invariants (and places them into a jax.numpy.ndarray), and the second of which generates the output form-invariant basis.

Return type

tuple
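To make the two kinds of output concrete: for a single symmetric second-order input \(A\) under the full orthogonal group \(O(3)\), the classical representation theorem gives the scalar invariants \(\operatorname{tr} A\), \(\operatorname{tr} A^2\), \(\operatorname{tr} A^3\) and the form-invariant basis \(I\), \(A\), \(A^2\). The NumPy sketch below computes both and checks how they behave under a random orthogonal transformation. The function names are hypothetical, and the choice and ordering of invariants returned by get_invariant_functions() may differ.

```python
import numpy as np


def scalar_invariants(A: np.ndarray) -> np.ndarray:
    """Scalar invariants tr(A), tr(A^2), tr(A^3) of a symmetric 3x3 tensor."""
    A2 = A @ A
    return np.array([np.trace(A), np.trace(A2), np.trace(A2 @ A)])


def form_invariants(A: np.ndarray) -> list:
    """Form-invariant basis I, A, A^2 for a symmetric-tensor-valued output."""
    return [np.eye(A.shape[0]), A, A @ A]


rng = np.random.default_rng(0)
M = rng.standard_normal((3, 3))
A = M + M.T                                       # a symmetric input tensor
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))  # a random orthogonal matrix
A_rot = Q @ A @ Q.T

# Scalar invariants are unchanged by the transformation ...
print(np.allclose(scalar_invariants(A), scalar_invariants(A_rot)))  # True
# ... while each form invariant transforms like the output tensor.
print(all(np.allclose(Q @ B @ Q.T, B_rot)
          for B, B_rot in zip(form_invariants(A), form_invariants(A_rot))))  # True
```

An invariant CR can then be written as a sum of the form invariants weighted by arbitrary functions of the scalar invariants.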

crikit.invariants.get_invariant_descriptions(info: crikit.invariants.invariants.InvariantInfo, suppress_warning_print: Optional[bool] = False, fail_on_warning: Optional[bool] = False, html: Optional[bool] = None, ipython: Optional[bool] = None)[source]

This function builds a string description of the scalar and form invariants that you would get from get_invariant_functions() with the same arguments you pass in here.

Parameters
  • info (InvariantInfo) – an InvariantInfo instance

  • suppress_warning_print (bool, optional) – if True, don’t print out warnings (this typically would be used if you get a warning about scalar or form-invariants not being available for a specific subset of the input types, and you know that this isn’t a problem, e.g. because no such invariants exist for that subset), defaults to False

  • fail_on_warning (bool, optional) – if True, warnings become exceptions. This is useful if you know that you should not get a warning for your inputs and want to make sure that nothing changes in a way that breaks that assumption. Defaults to False

  • html (bool, optional) – whether to return HTML instead of a plain-text description; useful inside Jupyter notebooks. Defaults to False

  • ipython (bool, optional) – whether this is being used in IPython mode (e.g. in a Jupyter notebook). By default, the function tries to guess; if the guess is wrong, set this parameter manually.

Returns

a string describing the invariants

Return type

str

crikit.invariants.register_invariant_functions(info: crikit.invariants.invariants.InvariantInfo, scalar_invariant_func, form_invariant_func, overwrite_existing=False, nojit=False)[source]

Register a scalar and form-invariant computing function for a given InvariantInfo.

Parameters
  • info (InvariantInfo) – an InvariantInfo containing the relevant information about the inputs and outputs of functions with this symmetry.

  • scalar_invariant_func (Callable) – a function that returns a single jax.numpy.ndarray containing the scalar invariants for the inputs.

  • form_invariant_func (Callable) – a function that returns a Python list of jax.numpy.ndarray instances representing the form-invariants for the inputs.

  • overwrite_existing (Union[bool, Tuple[bool, bool]], optional) – if True and the InvariantInfo you pass describes an existing set of invariants, replace those with your functions. You should NEVER set this to True unless you really know what you’re doing. To overwrite one function but not the other (e.g. to insert a form-invariant function for a scenario where the scalar invariant function already exists), pass a pair of bools, one for the scalar invariant function and one for the form invariant function. Defaults to False

  • nojit (bool, optional) – if True, do NOT call jax.jit() on scalar_invariant_func() or form_invariant_func(), defaults to False

Returns

None; your functions are registered so that get_invariant_functions() can return them.

crikit.invariants.near(val, to, rtol=1e-05)[source]

Returns True if val and to are within relative tolerance rtol and False otherwise

Parameters
  • val (np.ndarray) – a value

  • to (np.ndarray) – is val close to this?

  • rtol (float, optional) – Relative tolerance, defaults to 1.0e-5

Returns

are val and to within rtol?

Return type

bool

crikit.invariants.symm(x)[source]

Symmetrizes the input

Parameters

x (Union[np.ndarray,onp.ndarray]) – a 2-d array to symmetrize

Returns

A symmetric (and doubled) version of x

Return type

Union[np.ndarray,onp.ndarray]

crikit.invariants.antisymm(x)[source]

Antisymmetrizes the input

Parameters

x (Union[np.ndarray,onp.ndarray]) – a 2-d array to antisymmetrize

Returns

An antisymmetric (and doubled) version of x

Return type

Union[np.ndarray,onp.ndarray]
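The word "doubled" suggests that these functions omit the usual factor of 1/2, i.e. symm(x) = x + xᵀ and antisymm(x) = x - xᵀ. Assuming that reading, the sketch below shows the consequence: the two parts sum to 2x rather than x. The names symm_sketch and antisymm_sketch are hypothetical stand-ins, not CRIKit functions.

```python
import numpy as np


def symm_sketch(x: np.ndarray) -> np.ndarray:
    """Doubled symmetric part of a 2-d array: x + x^T."""
    return x + x.T


def antisymm_sketch(x: np.ndarray) -> np.ndarray:
    """Doubled antisymmetric part of a 2-d array: x - x^T."""
    return x - x.T


x = np.array([[1.0, 2.0], [4.0, 3.0]])
print(np.allclose(symm_sketch(x) + antisymm_sketch(x), 2 * x))  # True
```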

FE Support

crikit.cr.fe.assemble_with_cr(form: ufl.form.Form, cr: crikit.cr.types.PointMap, arg: Union[ufl.core.expr.Expr, Tuple[ufl.core.expr.Expr]], out_terms: Union[ufl.coefficient.Coefficient, Tuple[ufl.coefficient.Coefficient]], quad_params: Optional[dict] = None, force_explicit: Optional[bool] = False, return_all: Optional[bool] = False, **kwargs)[source]

Substitute the output of cr applied to arg in the form and then assemble the form.

Given a UFL form that contains out_terms, this function calculates cr_out = cr(arg), replaces each term in out_terms with the corresponding term in cr_out, and assembles the resulting form.

If cr can’t take arg directly as input, if it doesn’t output Functions, or if force_explicit=True, then get_composite_cr() is used to map arg to the CR’s input space and to map the CR’s output to a UFL function space.

Parameters
  • form (Form) – the form

  • cr (PointMap) – a point map whose input is compatible with arg and whose output is compatible with out_terms.

  • arg (Expr or tuple[Expr]) – input expressions to the CR

  • out_terms (Coefficient or tuple[Coefficient]) – terms in the form that will be replaced by the output of the CR

  • quad_params (dict, optional) – parameters for the quadrature space

  • force_explicit (bool, optional) – pass True to force get_composite_cr to be called.

  • return_all (bool, optional) – Set to True to have the unassembled form and composite CR returned.

  • **kwargs – passed through to the assemble function of the backend

Returns

the output of pyadjoint_utils.fenics_adjoint.assemble

If return_all is True, the unassembled form and the composite CR are also returned.

Return type

float or Coefficient

Observers

class crikit.observer.AdditiveRandomFunction(V: dolfin.function.functionspace.FunctionSpace, distribution: Optional[str] = 'normal', seed: Optional[int] = 0, **kwargs)[source]

A class representing a random function on a mesh that adds a sample from its distribution to the inputs (i.e. an additive noise model). In other words, if X is your UFL input and Y is the random variable this class represents, then this class returns X + Y.

__call__(ufl_input: fenics_adjoint.types.function.Function) fenics_adjoint.types.function.Function[source]

Generate and add the noise.

Parameters

ufl_input (Function) – The input

Returns

The input plus some random noise

Return type

Function

__init__(V: dolfin.function.functionspace.FunctionSpace, distribution: Optional[str] = 'normal', seed: Optional[int] = 0, **kwargs)[source]
Parameters
  • V (FunctionSpace) – The FunctionSpace in which to generate noise

  • distribution (str, optional) – A string describing the distribution of this function. Use the static method AdditiveRandomFunction.available_distributions() to see available distributions, defaults to ‘normal’

  • seed (int, optional) – the seed for jax.random, defaults to 0

  • kwargs (dict, optional) – Keyword arguments (e.g. parameter values) to be passed on to the distribution. For example, if distribution is ‘gamma’, you should pass a = value_of_a, since jax.random.gamma takes a parameter named a.
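The additive noise model X + Y can be sketched on a plain array of degree-of-freedom values. The sketch below uses NumPy's random Generator instead of jax.random and supports only the 'normal' and 'uniform' cases, with parameter names taken from the dictionary returned by get_available_params(); add_noise and its default parameter values are hypothetical, not the AdditiveRandomFunction implementation.

```python
import numpy as np


def add_noise(values: np.ndarray, distribution: str = "normal",
              seed: int = 0, **params) -> np.ndarray:
    """Return values + Y, with Y sampled independently at each entry."""
    rng = np.random.default_rng(seed)
    if distribution == "normal":
        # Defaults mu=0, std=1 are assumptions made for this sketch.
        noise = rng.normal(params.get("mu", 0.0), params.get("std", 1.0),
                           size=values.shape)
    elif distribution == "uniform":
        noise = rng.uniform(params.get("minval", 0.0), params.get("maxval", 1.0),
                            size=values.shape)
    else:
        raise ValueError(f"unsupported distribution: {distribution!r}")
    return values + noise


u = np.linspace(0.0, 1.0, 5)  # stand-in for a Function's DOF values
noisy = add_noise(u, std=0.01, seed=42)
print(np.array_equal(noisy, add_noise(u, std=0.01, seed=42)))  # True: same seed
```

Fixing the seed makes the noisy observations reproducible, which mirrors the role of the seed parameter above.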

static get_available_params() dict[source]

Returns a dictionary mapping distribution names to the names of parameters for that distribution (pass as kwargs to the constructor for this class). For example, if you’re using a pareto distribution, you’ll also want to pass b=your_b_value to the constructor of this class.

Currently, that dictionary is

{'bernoulli': 'p',
 'beta': 'a, b',
 'dirichlet': 'alpha',
 'double_sided_maxwell': 'loc, scale',
 'gamma': 'a',
 'multivariate_normal': 'mean, cov',
 'normal': 'mu, std',
 'pareto': 'b',
 'poisson': 'lam',
 't': 'df',
 'truncated_normal': 'lower, upper',
 'uniform': 'minval, maxval',
 'weibull': 'concentration'}

class crikit.observer.SubdomainObserver(mesh: fenics_adjoint.types.mesh.Mesh, subdomain: dolfin.cpp.mesh.SubDomain)[source]

__call__(u: fenics_adjoint.types.function.Function) fenics_adjoint.types.function.Function[source]

Compute the observation by zeroing out values outside the SubDomain.

Parameters

u (Function) – The Function to observe

Returns

The observed u

Return type

Function

__init__(mesh: fenics_adjoint.types.mesh.Mesh, subdomain: dolfin.cpp.mesh.SubDomain)[source]

An Observer that zeros all values of a Function outside of a certain SubDomain.

Parameters
  • mesh (Mesh) – The mesh

  • subdomain (SubDomain) – The subdomain

Returns

The SubdomainObserver

Return type

SubdomainObserver
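Conceptually, the observation keeps the values at points inside the subdomain and sets all others to zero. A NumPy sketch on nodal values follows; the inside predicate here is a hypothetical one-argument function (a FEniCS SubDomain's inside() method also receives an on_boundary flag), and observe_on_subdomain is not part of CRIKit.

```python
import numpy as np


def observe_on_subdomain(values, coords, inside):
    """Zero every value whose coordinate lies outside the subdomain."""
    mask = np.array([inside(x) for x in coords], dtype=bool)
    return np.where(mask, values, 0.0)


coords = np.linspace(0.0, 1.0, 5)  # 1-D node coordinates 0, 0.25, ..., 1
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
observed = observe_on_subdomain(values, coords, lambda x: x <= 0.5)
print(observed)  # [1. 2. 3. 0. 0.]
```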

class crikit.observer.SurfaceObserver(de)[source]

Groups

class crikit.group.SpecialOrthoGroup(dim: int)[source]

This class represents the special orthogonal group, which is the group of arbitrary rotations without reflections.

Parameters

dim (int) – The dimension of the group, which is usually 2 or 3.

apply(sample: pyadjoint.overloaded_type.OverloadedType, tensors: List[numpy.ndarray], offset=1)[source]

This method applies a sample of the group to a group of tensors.

Each tensor is assumed to be a NumPy array. The first offset axes of each array (by default, just the first axis) are left unchanged, and the transformation is applied to the remaining axes.

Parameters
  • sample – a sample from the group.

  • tensors – a tensor or list of tensors to apply the transformation to.

  • offset (int, optional) – the number of leading axes of each array to leave unchanged, defaults to 1.

Returns

returns the transformed tensor(s).

Return type

type

sample() pyadjoint.overloaded_type.OverloadedType[source]

This method returns a sample of the group.
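A common way to draw a uniformly distributed rotation is to QR-factorize a Gaussian random matrix, fix the column signs so the result is Haar-distributed on \(O(n)\), and flip one column if the determinant is -1. The NumPy sketch below does that and then applies the rotation to a batch of rank-2 tensors, leaving the first (batch) axis alone as described for apply(). Both functions are hypothetical illustrations; CRIKit's sample() returns a pyadjoint OverloadedType and may use a different sampling method.

```python
import numpy as np


def sample_rotation(dim: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a random rotation matrix (determinant +1) in `dim` dimensions."""
    Q, R = np.linalg.qr(rng.standard_normal((dim, dim)))
    Q = Q * np.sign(np.diag(R))  # sign fix makes Q Haar-distributed on O(dim)
    if np.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]  # remove the reflection
    return Q


def rotate_rank2(Q: np.ndarray, tensors: np.ndarray) -> np.ndarray:
    """Apply Q T Q^T to each tensor in an array of shape (N, dim, dim)."""
    return np.einsum("ij,njk,lk->nil", Q, tensors, Q)


rng = np.random.default_rng(1)
Q = sample_rotation(3, rng)
T = rng.standard_normal((4, 3, 3))  # 4 rank-2 tensors; axis 0 is the batch axis
T_rot = rotate_rank2(Q, T)
print(np.isclose(np.linalg.det(Q), 1.0))  # True
```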

crikit.group.get_einsum_args(rank: int, contravariants: List[int], covariants: List[int], basis: numpy.ndarray, inv_basis: numpy.ndarray, offset: Optional[int] = 0)[source]

Generates arguments for NumPy’s einsum function such that calling einsum with these arguments and a tensor performs a tensor coordinate transformation.

The tensor has the given rank, with the given covariant and contravariant axes. Each specified axis should be in the range [0, rank-1].

If the coordinate transformation is being performed on a set of tensors stored together in the same NumPy array, the first axis (or first few axes) of the array is not part of the tensor definition. The offset argument specifies how many of the first axes to ignore. For example, if N tensors each with shape S are stored in an array of shape (N, S), the offset should be set to 1.

Parameters
  • rank (int) – the rank of the tensor to be transformed.

  • contravariants (list) – the axes that transform contravariantly.

  • covariants (list) – the axes that transform covariantly.

  • basis – the coordinate basis to use for the transform.

  • inv_basis – the inverse of the coordinate basis, used for the contravariant axes.

  • offset (int) – the number of leading axes to leave unchanged, defaults to 0.

Returns

returns a tuple of args to be used with NumPy’s einsum function.

Return type

tuple
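The sketch below shows one way such arguments can be assembled for np.einsum. It assumes the convention that a covariant axis transforms as \(T'_i = \sum_j B_{ji} T_j\) with the basis \(B\), and a contravariant axis as \(T'^i = \sum_j (B^{-1})_{ij} T^j\) with the inverse basis. The helper einsum_args_sketch, its return layout (subscripts first, tensor supplied last), and that convention are assumptions, not a description of get_einsum_args().

```python
import string
from typing import Sequence

import numpy as np


def einsum_args_sketch(rank: int, contravariants: Sequence[int],
                       covariants: Sequence[int], basis: np.ndarray,
                       inv_basis: np.ndarray, offset: int = 0) -> tuple:
    """Return (subscripts, *matrices); pass the tensor as the final operand."""
    letters = iter(string.ascii_letters)
    batch = [next(letters) for _ in range(offset)]  # untouched leading axes
    old = [next(letters) for _ in range(rank)]      # input tensor axes
    new = list(old)                                 # output tensor axes
    terms, operands = [], []
    for axis in range(rank):
        if axis in contravariants:
            new[axis] = next(letters)
            terms.append(new[axis] + old[axis])  # T'^i = inv_basis[i, j] T^j
            operands.append(inv_basis)
        elif axis in covariants:
            new[axis] = next(letters)
            terms.append(old[axis] + new[axis])  # T'_i = basis[j, i] T_j
            operands.append(basis)
    tensor_in = "".join(batch + old)
    tensor_out = "".join(batch + new)
    subscripts = ",".join(terms + [tensor_in]) + "->" + tensor_out
    return (subscripts, *operands)


rng = np.random.default_rng(2)
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))  # an orthonormal basis
T = rng.standard_normal((5, 3, 3))                # 5 rank-2 tensors, offset 1
args = einsum_args_sketch(2, [0], [1], R, R.T, offset=1)
print(args[0])  # db,ce,abc->ade
T_new = np.einsum(*args, T)
```

For a (1,1) tensor in an orthonormal basis this reduces to \(R^T T R\) for each tensor in the batch, so the trace is preserved.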

class crikit.group.GroupAction(group, space, sample=None)[source]

Bases: crikit.cr.types.PointMap

This point map applies the action of a group to the given inputs.

Parameters
  • group – The group that will be used to perform the action.

  • space (Space) – The space that the group transformations will be applied to.

  • sample – a sample of the group that will be used if no sample is specified in the __call__() method.

__call__(inputs, sample=None)[source]

Applies the group transformation to the given inputs using the given group sample. If no sample is given, the default sample is used.

Logging

crikit.logging.set_log_level(level)[source]

Set the log level for CRIKit and FEniCS. For example,

import crikit
crikit.logging.set_log_level(crikit.logging.WARNING)