CRIKit Core

Core Classes

CRIKit works with two main types: Space and PointMap. A constitutive relation (CR) is formulated as a PointMap that maps from a source Space to a target Space. This means a CR can be implemented easily as long as you know what its input and output spaces are.


A Space should be able to do three things:

  • generate a point in the space,

  • test whether a point is in the space,

  • and get the shape of a point in the space.

Here’s an example Space that could be used to represent real numbers.

from import Space

class RealSpace(Space):
    def point(self):
        return 0

    def is_point(self, point):
        return isinstance(point, (int, float))

    def shape(self):
        return ()
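
The same three-method pattern extends to spaces whose points have a non-empty shape. To keep the sketch self-contained, it defines a hypothetical stand-in `Space` base class instead of importing CRIKit's. `VectorSpace` is likewise a hypothetical name, and points are represented as plain tuples of numbers.

```python
from abc import ABC, abstractmethod
from numbers import Real


class Space(ABC):
    """Hypothetical stand-in for CRIKit's Space base class (illustration only)."""

    @abstractmethod
    def point(self):
        """Return some point in the space."""

    @abstractmethod
    def is_point(self, point):
        """Return True if `point` belongs to the space."""

    @abstractmethod
    def shape(self):
        """Return the shape of a point in the space."""


class VectorSpace(Space):
    """Hypothetical space of length-n vectors stored as tuples of reals."""

    def __init__(self, n):
        self._n = n

    def point(self):
        # The zero vector is a convenient representative point.
        return (0.0,) * self._n

    def is_point(self, point):
        # A point must be a tuple of exactly n real numbers.
        return (isinstance(point, tuple)
                and len(point) == self._n
                and all(isinstance(c, Real) for c in point))

    def shape(self):
        return (self._n,)


R3 = VectorSpace(3)
assert R3.is_point(R3.point())
assert R3.is_point((1, 2.5, -3))
assert not R3.is_point((1, 2))
assert R3.shape() == (3,)
```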


The only method that a point map must define is __call__().

Here’s an example point map that raises a number to a constant power. (I’m using the predefined RR to represent the space of real numbers.)

from import PointMap
from import RR

class ConstPow(PointMap):
    def __init__(self, p):
        self._p = p
        source_space = RR
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        return point**self._p

point_map = ConstPow(3)
assert point_map(4) == 64

Space Builders

These classes exist to make it easy to compose new Spaces by combining existing spaces.


If a point map takes multiple arguments or returns multiple outputs, then it can use a DirectSum as the input or output space, respectively.

Note that a point map’s __call__() method should take only one argument. If the point map acts on multiple arguments, they must be passed together in a tuple or a list.

from import PointMap
from import DirectSum
from import RR

class Pow(PointMap):
    def __init__(self):
        source_space = DirectSum(RR, RR)
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        x, p = point
        return x**p

point_map = Pow()
assert point_map((4, 3)) == 64
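
The sketch below makes the role of DirectSum concrete. It shows how a direct-sum space could test membership and report its shape by delegating to its component spaces. `ToyReal` and `ToyDirectSum` are hypothetical stand-ins, not CRIKit classes. CRIKit’s actual DirectSum may behave differently, for example in how it reports shape.

```python
from numbers import Real


class ToyReal:
    """Hypothetical stand-in for RR: the real numbers, with scalar shape ()."""

    def point(self):
        return 0.0

    def is_point(self, point):
        return isinstance(point, Real)

    def shape(self):
        return ()


class ToyDirectSum:
    """Hypothetical direct sum: a point has one entry per component space."""

    def __init__(self, *spaces):
        self.spaces = spaces

    def point(self):
        return tuple(s.point() for s in self.spaces)

    def is_point(self, point):
        # Membership needs the right number of entries, each in its own space.
        return (isinstance(point, (tuple, list))
                and len(point) == len(self.spaces)
                and all(s.is_point(p) for s, p in zip(self.spaces, point)))

    def shape(self):
        # One shape per component, since components may differ in shape.
        return tuple(s.shape() for s in self.spaces)


source = ToyDirectSum(ToyReal(), ToyReal())
assert source.is_point((4, 3))
assert not source.is_point(4)
assert source.shape() == ((), ())
```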


A Multiset represents a finite repetition of a base space, where each element of the Multiset can be treated equivalently.

For example, the Pow point map above cannot use a Multiset because exchanging x and p changes the result. In the example below, I use a Multiset because the inputs x, y, and z can be interchanged without changing the result of applying the point map.

from import PointMap
from import Multiset
from import RR

class SumThree(PointMap):
    def __init__(self):
        source_space = Multiset(RR, 3)
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        x, y, z = point
        return x + y + z

point_map = SumThree()
assert point_map((1, 2, 3)) == 6
assert point_map((3, 2, 1)) == 6
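
Before switching a point map’s source space to a Multiset, it can be worth checking numerically that the map really is symmetric in its inputs. The sketch below uses plain Python functions rather than PointMaps. `is_permutation_invariant` is a hypothetical helper. Passing the check on a few sample points is evidence of symmetry, not proof.

```python
from itertools import permutations


def is_permutation_invariant(func, sample_points):
    """Return True if func gives the same result for every reordering of each sample."""
    for point in sample_points:
        expected = func(tuple(point))
        if any(func(p) != expected for p in permutations(point)):
            return False
    return True


def sum_three(point):
    x, y, z = point
    return x + y + z


def pow_two(point):
    x, p = point
    return x ** p


samples = [(1, 2, 3), (4, 0, -2)]
assert is_permutation_invariant(sum_three, samples)      # suitable for Multiset(RR, 3)
assert not is_permutation_invariant(pow_two, [(4, 3)])   # 4**3 != 3**4, so use DirectSum
```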

Map Builders

These classes exist to make it easy to create point maps by building on top of existing point maps or functions.


Callable creates a point map from any Python object for which callable() returns True. This includes functions, methods, classes, and any objects that define a __call__ method.

To construct a Callable instance, you must give the constructor the input and output spaces, as well as the callable itself. The example below constructs a point map equivalent to the Pow example from above.

from import Callable
from import DirectSum
from import RR

source = DirectSum(RR, RR)
target = RR
def func(point): return point[0]**point[1]
point_map = Callable(source, target, func)
assert point_map((4, 3)) == 64

Since most functions aren’t written to take all their parameters in a single tuple, the Callable constructor has a bare parameter. If you set it to True, the point map unpacks the point before passing it to the underlying callable.

from import Callable
from import DirectSum
from import RR

source = DirectSum(RR, RR)
target = RR
def func(x, y): return x**y
point_map = Callable(source, target, func, bare=True)
assert point_map((4, 3)) == 64
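
Conceptually, bare=True changes only how the point is handed to the wrapped callable. The sketch below shows that difference using a hypothetical `wrap` helper. It is not CRIKit’s implementation, and it ignores the source and target spaces entirely.

```python
def wrap(func, bare=False):
    """Hypothetical sketch of how a Callable-style wrapper could pass the point along."""
    def point_map(point):
        # bare=True: unpack the tuple into positional arguments, func(*point).
        # bare=False: hand the whole tuple over as one argument, func(point).
        return func(*point) if bare else func(point)
    return point_map


def packed(point):
    return point[0] ** point[1]


def unpacked(x, y):
    return x ** y


assert wrap(packed)((4, 3)) == 64
assert wrap(unpacked, bare=True)((4, 3)) == 64
```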


The Parametric class takes an existing point map and makes some of its arguments optional by providing them with default values. The input space of the point map is reduced by removing inputs at specified indices.

In the example below, I create a point map equivalent to the ConstPow example from above.

from import Callable, Parametric
from import DirectSum
from import RR
from operator import pow

pow_point_map = Callable(DirectSum(RR, RR), RR, pow, bare=True)
assert pow_point_map((4, 3)) == 64

# This removes the argument at index 1 from the input space and gives it a
# default value of 3.
const_pow_point_map = Parametric(pow_point_map, 1, 3)

# Now the exponent isn't part of the input for the point map.
assert const_pow_point_map((4,)) == 64

# But it can still be passed in separately through the params kwarg.
assert const_pow_point_map((4,), params=2) == 16

The bare parameter can be used to get rid of the extra parentheses around the 4 if the resulting point map only has one argument.

const_pow_point_map = Parametric(pow_point_map, 1, 3, bare=True)
assert const_pow_point_map(4) == 64
assert const_pow_point_map(4, params=2) == 16
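
The bookkeeping that Parametric performs on the input tuple can be sketched in plain Python. The missing argument is reinserted at its original index before the wrapped function is called, and the params keyword overrides the stored default. The `parametric` helper below is hypothetical and handles only a single fixed index. It is not CRIKit’s implementation.

```python
from operator import pow


def parametric(func, index, default):
    """Hypothetical sketch: drop argument `index` from the input, supplying `default` instead."""
    def point_map(point, params=None):
        value = default if params is None else params
        # Rebuild the full argument list with the parameter back at its original position.
        args = list(point)
        args.insert(index, value)
        return func(*args)
    return point_map


const_pow = parametric(pow, 1, 3)
assert const_pow((4,)) == 64
assert const_pow((4,), params=2) == 16
```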


The AugmentPointMap class takes an existing point map and turns some of its keyword arguments into mandatory arguments. These arguments become part of the input space of the point map, so to update that input space correctly, you must specify the Spaces of the keyword arguments.

In the example below, I create a point map with two keyword arguments. I move one parameter to the input space and then do it again with both parameters.

from import PointMap
from import AugmentPointMap
from import DirectSum
from import RR

class PowAdd(PointMap):
    def __init__(self):
        source_space = RR
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, x, p=2, a=1):
        return x**p + a

point_map = PowAdd()
assert point_map(4) == 17
assert point_map(4, p=3, a=0) == 64

# First I'll make the `p` argument part of the input space.
aug_map_p = AugmentPointMap(point_map, 'p', RR)
assert aug_map_p((4, 3)) == 65
assert aug_map_p((4, 3), a=0) == 64

# Now I'll make both `p` and `a` part of the input space.
aug_map_both = AugmentPointMap(point_map, ['p', 'a'], DirectSum(RR, RR))
assert aug_map_both((4, (3, 0))) == 64

# If I set bare=True, then I don't have to pass all the params as one tuple.
aug_map_bare = AugmentPointMap(point_map, ['p', 'a'], DirectSum(RR, RR), bare=True)
assert aug_map_bare((4, 3, 0)) == 64
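
The sketch below shows what AugmentPointMap does with the enlarged input. The first entry of the point is the original input, and the remaining entries are routed to the named keyword arguments. The `augment` helper is hypothetical and ignores spaces. It assumes the layouts used in the example above: `(x, extra)` by default and `(x, extra_1, extra_2, ...)` with bare=True. CRIKit’s own handling may differ.

```python
def augment(func, names, bare=False):
    """Hypothetical sketch: turn keyword arguments `names` into part of the input point."""
    if isinstance(names, str):
        names = [names]

    def point_map(point, **kwargs):
        if bare:
            # Flat layout: (x, extra_1, extra_2, ...).
            x, *extras = point
        else:
            # Nested layout: (x, extra), where extra is a tuple when there are several names.
            x, extra = point
            extras = [extra] if len(names) == 1 else list(extra)
        kwargs.update(zip(names, extras))
        return func(x, **kwargs)
    return point_map


def pow_add(x, p=2, a=1):
    return x ** p + a


assert augment(pow_add, 'p')((4, 3)) == 65
assert augment(pow_add, 'p')((4, 3), a=0) == 64
assert augment(pow_add, ['p', 'a'])((4, (3, 0))) == 64
assert augment(pow_add, ['p', 'a'], bare=True)((4, 3, 0)) == 64
```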


The CompositePointMap class links a group of point maps sequentially, so that the output of one map is fed into the input of the next. This is useful for creating larger point maps from separate specialized maps.

In the example below, I create a point map that calculates \(x^2 + 1\) by using one point map to square the input and a separate point map to add one.

from import RR
from import Callable, CompositePointMap

def square(x): return x ** 2
def add_one(x): return x + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)

assert square_map(3) == 9
assert add_one_map(9) == 10

composite_map = CompositePointMap(square_map, add_one_map)
assert composite_map(3) == 10
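
Sequential composition is ordinary function composition applied left to right. The minimal sketch below uses functools.reduce with plain functions and a hypothetical `compose` helper rather than CRIKit point maps. It does not check that each map’s target space matches the next map’s source space.

```python
from functools import reduce


def compose(*maps):
    """Hypothetical sketch: feed each map's output into the next map, left to right."""
    return lambda x: reduce(lambda acc, f: f(acc), maps, x)


def square(x):
    return x ** 2


def add_one(x):
    return x + 1


assert compose(square, add_one)(3) == 10  # (3**2) + 1
assert compose(add_one, square)(3) == 16  # order matters: (3 + 1)**2
```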


The ParallelPointMap class combines a group of point maps to accept all of their inputs together and return all of their outputs together. This is useful for creating larger point maps from separate specialized maps.

In the example below, I create a point map that calculates \(f(x, y) = (x^2, y + 1)\) by using one point map to square the first value and a separate point map to add one to the second value.

from import RR
from import Callable, ParallelPointMap

def square(x): return x ** 2
def add_one(x): return x + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)

assert square_map(3) == 9
assert add_one_map(6) == 7

parallel_map = ParallelPointMap(square_map, add_one_map)
assert parallel_map((3, 6)) == (9, 7)


The IdentityPointMap class creates a point map that returns its input unchanged. It can be useful when constructing a ParallelPointMap in which one of the values should pass through unmodified.

In the code below, I create a ParallelPointMap that transforms the triple \((x, y, z)\) into \((x^2, y + 1, z)\). The first two values are computed using the maps from the previous example, and the last value is passed through with no change using the IdentityPointMap.

from import RR
from import Callable, ParallelPointMap, IdentityPointMap

def square(x): return x ** 2
def add_one(y): return y + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)
identity_map = IdentityPointMap(RR, RR)

assert square_map(3) == 9
assert add_one_map(6) == 7
assert identity_map(10) == 10

parallel_map = ParallelPointMap(square_map, add_one_map, identity_map)
assert parallel_map((3, 6, 10)) == (9, 7, 10)
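
A parallel map applies the i-th map to the i-th entry of its input and collects the results in order. An identity map simply returns its argument. The plain-Python sketch below uses a hypothetical `parallel` helper. It is not CRIKit’s implementation and does no space checking.

```python
def parallel(*maps):
    """Hypothetical sketch: apply maps[i] to point[i] and return the outputs as a tuple."""
    def point_map(point):
        if len(point) != len(maps):
            raise ValueError(f"expected {len(maps)} inputs, got {len(point)}")
        return tuple(f(x) for f, x in zip(maps, point))
    return point_map


def square(x):
    return x ** 2


def add_one(x):
    return x + 1


def identity(x):
    return x  # plays the role of IdentityPointMap


assert parallel(square, add_one)((3, 6)) == (9, 7)
assert parallel(square, add_one, identity)((3, 6, 10)) == (9, 7, 10)
```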


The covering module (crikit.covering) lets a point map defined on one set of spaces be used on a compatible set of spaces. Conversions between spaces are represented as Coverings. Each Covering implementation defines two main methods: one returns a PointMap from the covering space to the base space, and the other returns a PointMap from the base space to the covering space.

There is a registration system to keep track of what spaces can be converted. Each class that implements a Covering should call register_covering() to register what conversions that class can handle. By having this registry, coverings can be looked up automatically as they are needed.

The function get_composite_cr() is given a list of PointMaps and Spaces and automatically inserts Covering maps to convert from one space to the next, or from one PointMap’s target space to the next PointMap’s source space.

In the code below, I use FEniCS to create a mesh with a function defined on the mesh. Then I create two versions of the p-Laplacian point map: one implemented in NumPy and the other in UFL. The UFL map can be applied directly to my test input, but the NumPy map requires evaluating the function at quadrature points and putting the results into a NumPy array. The Covering interface is designed to handle this conversion automatically through the get_composite_cr() function, as long as the appropriate quadrature parameters have been set with the set_default_covering_params() function.

from crikit.fe import *
from crikit.fe_adjoint import *
from pyadjoint_utils import AdjFloat
from import CR_P_LaplacianNumpy
from import CR_P_Laplacian, UFLFunctionSpace
from crikit.covering import get_composite_cr, set_default_covering_params

# Set up the finite-element function spaces.
mesh = UnitSquareMesh(3, 3)
V = FunctionSpace(mesh, 'P', 1)
V_vec = FunctionSpace(mesh, 'RTE', 1)
out_function_space = UFLFunctionSpace(V_vec)

v = interpolate(Expression('x[0]*x[0] + x[1]*x[1]',degree=2), V)
g = grad(v)

# Create a CR implemented in NumPy.
p_np = AdjFloat(2.5)
cr_np = CR_P_LaplacianNumpy(p_np, input_u=False)

# Create a CR implemented in UFL.
p_ufl = Constant(p_np)
cr_ufl = CR_P_Laplacian(p_ufl, input_u=False)

# Give the Covering interface the parameters it needs to map from UFL to NumPy.
quad_params = {'quadrature_degree': 2}
domain = mesh.ufl_domain()
set_default_covering_params(domain=domain, quad_params=quad_params)

# Create point maps that have the same input and output spaces, but different
# implementations in the middle.
cr_np_composite = get_composite_cr(cr_ufl.source, cr_np, out_function_space)
cr_ufl_composite = get_composite_cr(cr_ufl.source, cr_ufl, out_function_space)

# The point maps give the same output even though they are implemented in
# different spaces.
sigma_ufl = cr_ufl_composite(g)
sigma_np = cr_np_composite(g)
assert errornorm(sigma_ufl, sigma_np) < 1e-7

Numpy-Like Libraries

CRIKit is written so that users can select at run time between JAX and PyTorch to perform NumPy-like operations, such as those found in crikit.invariants, as well as JIT compilation and vectorization. Note that the torch backend is currently experimental and not as well optimized as the jax backend, particularly in its Jacobian assembly. Typically, users control this choice through the backend parameter of the CR constructor. For example, these two CRs compute the same quantities, one using JAX and the other using PyTorch:

import crikit
import jax.numpy as jnp
import numpy as np
import torch

dims = 3
in_types = (crikit.TensorType.make_symmetric(2, dims),
out_type = crikit.TensorType.make_symmetric(2, dims)

# for this problem, it turns out these are both the same, equal to 7
n_scalar_invts, n_form_invts = crikit.cr_function_shape(out_type, in_types)
assert n_scalar_invts == n_form_invts
theta_jax = crikit.array(jnp.arange(0.0, n_scalar_invts))
# this isn't a realistic material model, and is only intended to
# show how to use the jax and torch backends
def cr_fun_jax(scalar_invts, theta):
    return (scalar_invts * jnp.sin(theta))

theta_torch = crikit.tensor(torch.arange(0.0, n_scalar_invts))
def cr_fun_torch(scalar_invts, theta):
    return (scalar_invts * torch.sin(theta))

cr_jax = crikit.CR(out_type, in_types, cr_fun_jax, params=(theta_jax,), backend='jax')
cr_torch = crikit.CR(out_type, in_types, cr_fun_torch, params=(theta_torch,), backend='torch')

n_qpoint = 100
# generate some random tensors as input
jax_inputs = tuple(crikit.array(np.random.randn(n_qpoint, *x.shape)) for x in in_types[:-1])

# the same random tensors converted to torch
torch_inputs = tuple(map(lambda x: crikit.tensor(x.unwrap()), jax_inputs))

assert jnp.allclose(cr_torch(torch_inputs).detach().numpy(), cr_jax(jax_inputs).unwrap())


Many real-life materials possess certain symmetries. For example, water is typically modeled as an isotropic fluid, meaning that from the point of view of an observer, any rotation of the coordinate axes leads to no change at all in the observed behavior. A wooden dowel with the grain running in the longer direction might be transversely isotropic, meaning that rotations about the axis along which the grain runs should also lead to no change in observed behavior. Mathematically, symmetries are described by groups, and all physical symmetry groups are subgroups of the group of rotations and flips in three dimensions (i.e. subgroups of \(O(3)\)). Specifically, physical symmetry groups are so-called compact point groups, that is, subgroups of the orthogonal group whose underlying sets are compact [Zheng, 1994].

Given some finite number of finite-rank tensors \(A_1, A_2, \ldots , A_n\) and some point group \(G\) with a left action on \(\{A_1, A_2, \ldots , A_n\}\), \(G\) is said to be characterized by \(\{A_1, A_2, \ldots , A_n\}\) if, for every orthogonal tensor \(g\) (i.e. every element of \(O(3)\)), \(g\) is a member of \(G\) if and only if

\[\begin{aligned} g\cdot A_1 &= A_1\\ g\cdot A_2 &= A_2\\ &\vdots\\ g\cdot A_n &= A_n \end{aligned} \]

We call \(A_1, A_2, \ldots, A_n\) structural tensors of \(G\). A 1994 theorem of Zheng and Boehler [Zheng and Boehler, 1994] guarantees that if \(G\) is any physical symmetry group (i.e. compact point group) and \(A_1, A_2, \ldots, A_n\) characterize \(G\), then any \(G\)-equivariant function \(F: T\to V\) for tensor spaces \(T\) and \(V\) can be written as an isotropic function

\[F(t) = F_I(t, A_1, A_2, \ldots, A_n) \]

(where \(F_I\) is the isotropic function in question). As such, CRIKit represents equivariant tensor functions of inputs as isotropic tensor functions of inputs and structural tensors. So, if you’re modeling the behavior of a wooden dowel with grain running along the vector field \(\mathbf k\) and you want a function that is equivariant with respect to the group of rotations about \(\mathbf k\), just add \(\mathbf k\) to the list of your input tensors. Then pass it as an input to the returned functions, or, if you’re accessing crikit.invariants through a CR, pass it as an input when you call the CR.
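
The structural-tensor idea can be checked numerically. In the NumPy sketch below, which is an illustration and not CRIKit code, \(\mathbf k = \mathbf e_z\) plays the role of the grain direction. Rotations about \(\mathbf k\) leave it fixed and therefore belong to the transversely isotropic group, while a rotation about \(\mathbf e_x\) does not. The function `F_I` is an arbitrary, hypothetical isotropic function of a vector input and \(\mathbf k\). Fixing its second argument to \(\mathbf k\) gives a function that is equivariant under rotations about \(\mathbf k\).

```python
import numpy as np


def rot_z(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_x(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


k = np.array([0.0, 0.0, 1.0])  # structural tensor (grain direction)


def F_I(v, k):
    # Isotropic in (v, k): built only from dot products and the vectors themselves,
    # so F_I(Q v, Q k) = Q F_I(v, k) for every orthogonal Q.
    return np.dot(v, k) * k + np.dot(v, v) * v


def F(v):
    # Fixing the structural tensor gives a function equivariant under rotations about k.
    return F_I(v, k)


v = np.array([0.3, -1.2, 0.7])
g_in = rot_z(0.4)   # rotation about k: in the symmetry group
g_out = rot_x(0.4)  # rotation about e_x: not in the group

assert np.allclose(g_in @ k, k) and not np.allclose(g_out @ k, k)
assert np.allclose(F(g_in @ v), g_in @ F(v))                       # equivariant for g in G
assert not np.allclose(F(g_out @ v), g_out @ F(v))                 # not for this g outside G
assert np.allclose(F_I(g_out @ v, g_out @ k), g_out @ F_I(v, k))   # F_I itself is isotropic
```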

The Invariants module is used to store and retrieve functions that compute scalar and form-invariants for given combinations of inputs and outputs. Users typically do not need to know about the Invariants module, since the CR class uses it internally to create a CR whose output is equivariant with respect to the symmetry group.

Given a group \(G\) and tensor space \(T\), a function \(\varphi : T\to \mathbb{R}\) is an invariant (a.k.a. scalar invariant) of \(G\) if and only if for every \(t\in T\) and \(g\in G\), \(\varphi(g\cdot t) = \varphi(t)\). Given another tensor space \(V\), a function \(\Psi : T\to V\) is a form-invariant of \(G\) if and only if for every \(t\in T\) and \(g\in G\), \(\Psi(g\cdot t) = g\cdot \Psi(t)\). Another 1994 theorem of Zheng and Boehler, which extends the well-known 1964 Wineman-Pipkin theorem and which we will refer to as the Zheng-Boehler theorem, guarantees that we can automatically generate complete function bases of scalar invariants \(\{\varphi_j\}_{j=1}^m\) and form-invariants \(\{\Psi_i\}_{i=1}^k\) such that any \(O(d)\)-equivariant tensor-valued function \(F : T\to V\) (in \(d\) spatial dimensions) can be represented as [Zheng and Boehler, 1994]

\[F(t) = \sum\limits_{i=1}^k f_i(\varphi_1(t), \ldots , \varphi_m(t)) \Psi_i(t) \]

where the \(f_i\) are scalar-valued functions.

In the function get_invariant_functions(), we generate these scalar invariants \(\{\varphi_j\}_{j=1}^m\) and form-invariants \(\{\Psi_i\}_{i=1}^k\) from the tables in [Zheng, 1994] and pack them into two functions, one that computes the scalar invariants and places them into a one-dimensional array and another that stacks the form-invariants into a three-dimensional jax.numpy.ndarray.

In other words, in the code the user interfaces with, the \(\Psi_i\) are all computed by a single function (and returned stacked in a single array), the \(\varphi_j\) are all computed by another function (and returned in a single one-dimensional array), and the \(f_i\) are all computed by a third function (the one the user supplies), and returned in another one-dimensional array. We call the function that computes \(f_i\) the inner function for a CR.

The representation shown above is called the Wineman-Pipkin representation of functions that are equivariant under a point group. As stated above, it has two powerful consequences. First, any function that is equivariant under a point group can be represented in this manner. Second, the “outer function” \(F\) is guaranteed to be equivariant regardless of what the inner function actually is and regardless of what properties it may or may not have. Those familiar with deep learning might call the scalar invariants features. However, instead of learning an approximation of the optimal features, as is often done in deep learning, we can compute them exactly (to within machine precision) and efficiently.
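
The three-function structure described above can be checked numerically for the simplest case, along with the guarantee that the outer function is equivariant whatever the inner function is. Take a single symmetric rank-two input \(A\) and a symmetric rank-two output in three dimensions. The classical scalar invariants are then \(\operatorname{tr} A\), \(\operatorname{tr} A^2\), and \(\operatorname{tr} A^3\), and the form-invariants are \(I\), \(A\), and \(A^2\). This NumPy sketch is an illustration, not CRIKit’s generated code, and the inner function is an arbitrary choice.

```python
import numpy as np


def scalar_invariants(A):
    # phi_j: returned together in a one-dimensional array.
    A2 = A @ A
    return np.array([np.trace(A), np.trace(A2), np.trace(A2 @ A)])


def form_invariants(A):
    # Psi_i: stacked into a (3, 3, 3) array, one 3x3 tensor per index i.
    return np.stack([np.eye(3), A, A @ A])


def inner(phi):
    # f_i: any scalar functions of the invariants (deliberately arbitrary here).
    return np.array([np.sin(phi[0]), phi[1] * phi[2], np.exp(-phi[1])])


def F(A):
    # F(A) = sum_i f_i(phi_1, ..., phi_m) Psi_i(A)
    return np.einsum('i,ijk->jk', inner(scalar_invariants(A)), form_invariants(A))


rng = np.random.default_rng(0)
B = rng.standard_normal((3, 3))
A = 0.5 * (B + B.T)                                # random symmetric input
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))   # random orthogonal matrix

# Equivariance: F(Q A Q^T) == Q F(A) Q^T for any orthogonal Q.
assert np.allclose(F(Q @ A @ Q.T), Q @ F(A) @ Q.T)
```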

Note that \(T\) might be a graded tensor algebra, since the Zheng-Boehler theorem places no restriction here. In other words, it might contain multiple independent subspaces, such as a scalar space, a vector space, and a symmetric rank-two tensor space, all direct-summed together into one space. In general, \(T\) is the direct sum of the spaces represented by the TensorTypes that make up the input_types parameter of the CR constructor.

To see exactly what the invariants and form-invariants CRIKit generates actually are for any given combination of input and output TensorTypes, you can use the function get_invariant_descriptions().


The class TensorType is a NamedTuple that contains the information about a tensor needed to determine the invariants, namely its order, its shape, and whether it is symmetric, antisymmetric, or neither. If you have an example of a tensor that you want a TensorType descriptor for, the easiest way to get one is with from_array() if you know the symmetry beforehand, or with type_from_array() if you do not. However, you typically do not need to construct TensorType instances yourself if you have appropriate example tensors, because they are primarily used as members of InvariantInfo instances. The Levi-Civita tensor is represented with LeviCivitaType, a subclass of TensorType.
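
A TensorType records three things: order, shape, and symmetry. All three can be read off an example array. The `describe_tensor` helper below is a hypothetical NumPy sketch of that kind of inference. It is not CRIKit’s type_from_array(), and it checks symmetry only for square rank-two tensors.

```python
from typing import NamedTuple, Tuple

import numpy as np


class ToyTensorType(NamedTuple):
    """Hypothetical stand-in for TensorType: order, shape, and a symmetry tag."""
    order: int
    shape: Tuple[int, ...]
    symmetry: str  # 'symmetric', 'antisymmetric', or 'neither'


def describe_tensor(arr, tol=1e-12):
    arr = np.asarray(arr)
    symmetry = 'neither'
    if arr.ndim == 2 and arr.shape[0] == arr.shape[1]:
        if np.allclose(arr, arr.T, atol=tol):
            symmetry = 'symmetric'
        elif np.allclose(arr, -arr.T, atol=tol):
            symmetry = 'antisymmetric'
    return ToyTensorType(arr.ndim, arr.shape, symmetry)


S = np.array([[1.0, 2.0], [2.0, 3.0]])
W = np.array([[0.0, 1.0], [-1.0, 0.0]])
assert describe_tensor(S) == ToyTensorType(2, (2, 2), 'symmetric')
assert describe_tensor(W).symmetry == 'antisymmetric'
assert describe_tensor(np.ones(3)) == ToyTensorType(1, (3,), 'neither')
```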


There are two ways to construct an InvariantInfo, which is itself another NamedTuple. The first is to construct it directly from its members: an integer giving the number of spatial dimensions, a tuple of TensorType instances representing the inputs to the CR (including any structural tensors), and a TensorType representing the output of the CR. The second is to use InvariantInfo.from_arrays(), which works much like type_from_array().

Once you have an InvariantInfo, you can call get_invariant_functions() on it, and get a pair of functions in return. The first function computes the scalar invariants for your specified inputs, and the second function computes the form invariants for the specified inputs. Note that you MUST pass the input arguments to those functions in the same order as the inputs were specified in the InvariantInfo. Note that if you pass the Levi-Civita pseudotensor to InvariantInfo.from_arrays() (or explicitly specify a LeviCivitaType in the TensorType list passed to the InvariantInfo constructor), you should NOT pass the Levi-Civita tensor to the returned functions. They will account for its presence without it being passed, and they will expect that you do not pass it as an argument.

Registering Your Own Invariants

We have implemented scalar and form-invariant functions for the most common subsets of input/output TensorTypes, but we cannot currently generate invariants algorithmically for subsets other than those we have implemented manually. Additionally, some bases for the scalar and form-invariants have better inference properties than others, at the cost of computational efficiency. For example, if we have a single symmetric rank-two input A and likewise a single symmetric rank-two output in two dimensions, the scalar invariants are trace(A) and trace(A @ A), and the form-invariants are I_2 and A. However, this basis can suffer from collinearity when A is small. For a slightly higher computational cost, it can be replaced with the eigenvalues of A, and the corresponding form-invariants can be replaced with the outer products of the eigenvectors of A with themselves. You can replace the default implementation of the invariants for any given set of input/output types, or provide an implementation for a subset that we have not implemented invariants for yet, using the register_invariant_functions() function in crikit.invariants. All invariant functions have the signature f(*args, backend), where *args are the input tensors and backend is a numpy backend object that implements the numpy API generically (i.e. so that we can switch between jax and torch at will). For example, the eigenvector/eigenvalue invariants discussed above could be implemented as


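import crikit

def scalar_invts(A, backend):
    # Scalar invariants: the eigenvalues of the symmetric input A.
    return backend.linalg.eigvalsh(A)

def form_invts(A, backend):
    # Form-invariants: the outer products v_k (x) v_k of each eigenvector with
    # itself, stacked along the first axis (assumed layout). eigh returns the
    # eigenvectors as the columns of its second output. This assumes the backend
    # exposes einsum, as numpy, jax.numpy and torch all do.
    eigvecs = backend.linalg.eigh(A)[1]
    return backend.einsum('ik,jk->kij', eigvecs, eigvecs)

info = crikit.InvariantInfo(2, (crikit.TensorType.make_symmetric(2, 2),),
                            crikit.TensorType.make_symmetric(2, 2))

crikit.register_invariant_functions(info, scalar_invts, form_invts, overwrite_existing=True)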
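
One way to sanity-check an eigenvector-based basis is to confirm that it can still represent the default one. Let \(A\) be symmetric with eigenvalues \(\lambda_k\) and orthonormal eigenvectors \(\mathbf v_k\). The spectral decomposition gives \(A = \sum_k \lambda_k \mathbf v_k \otimes \mathbf v_k\) and \(I = \sum_k \mathbf v_k \otimes \mathbf v_k\), so the default form-invariants \(I_2\) and \(A\) lie in the span of the outer products. The default scalar invariants are also functions of the eigenvalues. This NumPy check illustrates those identities on one example matrix. It is not part of CRIKit.

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, -1.0]])  # symmetric 2x2 input

eigvals, eigvecs = np.linalg.eigh(A)               # columns of eigvecs are the v_k
outer = np.einsum('ik,jk->kij', eigvecs, eigvecs)  # stacked v_k (x) v_k, shape (2, 2, 2)

# A = sum_k lambda_k v_k v_k^T  and  I = sum_k v_k v_k^T
assert np.allclose(np.einsum('k,kij->ij', eigvals, outer), A)
assert np.allclose(outer.sum(axis=0), np.eye(2))
# The default scalar invariants are recoverable from the eigenvalues.
assert np.isclose(eigvals.sum(), np.trace(A))
assert np.isclose((eigvals ** 2).sum(), np.trace(A @ A))
```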

If you only intend to use one numpy backend (e.g. only jax or only torch), you can ignore the backend parameter to the invariant functions, but you should use it if you intend to use the registered invariants with different numpy backends.

Putting It Together: Invariant CRs

The CR class represents a constitutive relation that is equivariant with respect to some compact point group. The scalar and form-invariants are generated automatically by crikit.invariants.get_invariant_functions(), so you need to supply only the following:

  • a sequence (e.g. a list or tuple) of crikit.invariants.TensorTypes representing the inputs to your CR (including any structural tensors),

  • a crikit.invariants.TensorType representing the output of your CR,

  • a JAX-compatible inner function (i.e. one that can be traced by JAX),

  • and any parameters the function takes.

To figure out what shape your function needs to have for a given set of inputs and outputs, use crikit.cr_function_shape(). It returns a tuple containing the size of the one-dimensional input array of scalar invariants and the size of the one-dimensional output array that the CR expects your function to return.

Typically, you will have a so-called pointwise CR, meaning that your CR function is written as a function of the inputs at any single point in a domain. This is how CRs are written mathematically, but when running a simulation, we typically want to evaluate the CR at many points (all the quadrature points in a domain or in a batch of elements). To achieve high performance and avoid slow Python loops, we use the vmap() and jit() functions from the chosen numpy backend. vmap() vectorizes the application of a CR over a whole mesh, and jit() compiles the CR function to native code. Only the jax and torch backends actually perform this compilation and vectorization. The plain numpy backend instead attempts to emulate the vectorization in pure Python. This behavior is controlled through the vmap and nojit parameters of the CR class’s __init__() method, which default to True and False, respectively. Set vmap=False for a CR that is evaluated one point at a time. Set nojit=True for a CR that cannot, for whatever reason (e.g. it uses Python-value-dependent control flow), be JIT-compiled. The behavior of the JIT compiler can be controlled further with other arguments to the class; see its API documentation for details.
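
The NumPy sketch below makes the vectorization step concrete by applying a pointwise function to every quadrature point in two ways. The first uses a Python loop, which is roughly what emulating vmap by hand amounts to. The second is a single batched array expression, the kind of code a vmap transformation produces. The pointwise CR and the `loop_vmap` helper are hypothetical examples, not CRIKit’s implementation.

```python
import numpy as np


def pointwise_cr(A):
    # One point at a time: A is a single 3x3 tensor.
    return np.trace(A) * np.eye(3) + 2.0 * A


def loop_vmap(func, batch):
    # Hand-rolled vectorization: call func on each point, then stack the results.
    return np.stack([func(x) for x in batch])


def batched_cr(As):
    # The same computation written once for a whole batch of shape (n, 3, 3).
    traces = np.trace(As, axis1=1, axis2=2)
    return traces[:, None, None] * np.eye(3) + 2.0 * As


rng = np.random.default_rng(1)
As = rng.standard_normal((100, 3, 3))  # e.g. one tensor per quadrature point
assert np.allclose(loop_vmap(pointwise_cr, As), batched_cr(As))
```

With the jax backend, `jax.vmap` can derive the batched version automatically from a pointwise function written with `jax.numpy`, so you do not have to write the second function by hand.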


Q.S. Zheng. Theory of representations for tensor functions—a unified invariant approach to constitutive equations. Appl. Mech. Rev., 47:545–587, 1994.


Q.S. Zheng and J.P. Boehler. The description, classification, and reality of material and physical symmetries. Acta Mechanica, 102:73–89, 1994.