CRIKit Core#
Core Classes#
CRIKit works with two main types: Space and PointMap.
A CR is formulated as a PointMap that maps from a source Space to a target Space.
That means a CR can be implemented easily as long as you know its input and output spaces.
Space#
A Space should be able to do three things: generate a point in the space, test whether a point is in the space, and get the shape of a point in the space.
Here’s an example Space that could be used to represent real numbers.
from crikit.cr.types import Space

class RealSpace(Space):
    def point(self):
        return 0

    def is_point(self, point):
        return isinstance(point, (int, float))

    def shape(self):
        return ()
PointMap#
The only method that a point map must define is __call__().
Here's an example point map that raises a number to a constant power. (I'm using the pre-defined RR to represent the real numbers space.)
from crikit.cr.types import PointMap
from crikit.cr.stdnumeric import RR

class ConstPow(PointMap):
    def __init__(self, p):
        self._p = p
        source_space = RR
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        return point**self._p

point_map = ConstPow(3)
assert point_map(4) == 64
Space Builders#
These classes exist to make it easy to compose new Spaces by combining existing spaces.
DirectSum#
If a point map takes multiple arguments or returns multiple outputs, then it can use a DirectSum as the input or output space, respectively.
Note that the point map __call__() function should only take one argument, so if the point map acts on multiple arguments, they must be passed together in a tuple or a list.
from crikit.cr.types import PointMap
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR

class Pow(PointMap):
    def __init__(self):
        source_space = DirectSum(RR, RR)
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        x, p = point
        return x**p

point_map = Pow()
assert point_map((4, 3)) == 64
Multiset#
A Multiset represents a finite repetition of a base space, where each element of the Multiset can be treated equivalently.
For example, the Pow point map above cannot use a Multiset because exchanging x and p changes the result. In the example below, I use a Multiset because the inputs x, y, z can all be interchanged without changing the result of applying the point map.
from crikit.cr.types import PointMap
from crikit.cr.space_builders import Multiset
from crikit.cr.stdnumeric import RR

class SumThree(PointMap):
    def __init__(self):
        source_space = Multiset(RR, 3)
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, point):
        x, y, z = point
        return x + y + z

point_map = SumThree()
assert point_map((1, 2, 3)) == 6
assert point_map((3, 2, 1)) == 6
Map Builders#
These classes exist to make it easy to create point maps by building on top of existing point maps or functions.
Callable#
Callable creates a point map from any Python object that is callable(). This includes functions, methods, classes, and any object with a __call__ method defined.
To construct a Callable instance, you must give the constructor the input and output spaces, as well as the callable itself. The example below constructs a point map equivalent to the Pow example from above.
from crikit.cr.map_builders import Callable
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR
source = DirectSum(RR, RR)
target = RR
def func(point): return point[0]**point[1]
point_map = Callable(source, target, func)
assert point_map((4, 3)) == 64
Since most functions aren't written to take all their parameters in a single tuple, the Callable constructor has a bare parameter. If you set it to True, the point map will unpack the point before giving it to the underlying callable.
from crikit.cr.map_builders import Callable
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR
source = DirectSum(RR, RR)
target = RR
def func(x, y): return x**y
point_map = Callable(source, target, func, bare=True)
assert point_map((4, 3)) == 64
Parametric#
The Parametric class takes an existing point map and makes some of its arguments optional by providing them with default values. The input space of the point map is reduced by removing inputs at the specified indices.
In the example below, I create a point map equivalent to the ConstPow example from above.
from crikit.cr.map_builders import Callable, Parametric
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR
from operator import pow
pow_point_map = Callable(DirectSum(RR, RR), RR, pow, bare=True)
assert pow_point_map((4, 3)) == 64
# This removes the argument at index 1 from the input space and gives it a
# default value of 3.
const_pow_point_map = Parametric(pow_point_map, 1, 3)
# Now the exponent isn't part of the input for the point map.
assert const_pow_point_map((4,)) == 64
# But it can still be passed in separately through the params kwarg.
assert const_pow_point_map((4,), params=2) == 16
The bare parameter can be used to get rid of the extra parentheses around the 4 if the resulting point map only has one argument.
const_pow_point_map = Parametric(pow_point_map, 1, 3, bare=True)
assert const_pow_point_map(4) == 64
assert const_pow_point_map(4, params=2) == 16
AugmentPointMap#
The AugmentPointMap class takes an existing point map and turns some of its keyword arguments into mandatory arguments. These arguments become part of the input space of the point map. To correctly update the input space of the point map, the Spaces of the keyword arguments must be specified.
In the example below, I create a point map with two keyword arguments. I move one parameter into the input space and then do it again with both parameters.
from crikit.cr.types import PointMap
from crikit.cr.map_builders import AugmentPointMap
from crikit.cr.space_builders import DirectSum
from crikit.cr.stdnumeric import RR

class PowAdd(PointMap):
    def __init__(self):
        source_space = RR
        target_space = RR
        super().__init__(source_space, target_space)

    def __call__(self, x, p=2, a=1):
        return x**p + a

point_map = PowAdd()
assert point_map(4) == 17
assert point_map(4, p=3, a=0) == 64
# First I'll make the `p` argument part of the input space.
aug_map_p = AugmentPointMap(point_map, 'p', RR)
assert aug_map_p((4, 3)) == 65
assert aug_map_p((4, 3), a=0) == 64
# Now I'll make both `p` and `a` part of the input space.
aug_map_both = AugmentPointMap(point_map, ['p', 'a'], DirectSum(RR, RR))
assert aug_map_both((4, (3, 0))) == 64
# If I set bare=True, then I don't have to pass all the params as one tuple.
aug_map_bare = AugmentPointMap(point_map, ['p', 'a'], DirectSum(RR, RR), bare=True)
assert aug_map_bare((4, 3, 0)) == 64
CompositePointMap#
The CompositePointMap class links a group of point maps sequentially, so that the output of one map is fed into the input of the next. This is useful for creating larger point maps from separate specialized maps.
In the example below, I create a point map that calculates \(x^2 + 1\) by using one point map to square the input and a separate point map to add one.
from crikit.cr.stdnumeric import RR
from crikit.cr.map_builders import Callable, CompositePointMap
def square(x): return x ** 2
def add_one(x): return x + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)
assert square_map(3) == 9
assert add_one_map(9) == 10
composite_map = CompositePointMap(square_map, add_one_map)
assert composite_map(3) == 10
ParallelPointMap#
The ParallelPointMap class combines a group of point maps so that it accepts all of their inputs together and returns all of their outputs together. This is useful for creating larger point maps from separate specialized maps.
In the example below, I create a point map that calculates \(f(x, y) = (x^2, y + 1)\) by using one point map to square the first value and a separate point map to add one to the second value.
from crikit.cr.stdnumeric import RR
from crikit.cr.map_builders import Callable, ParallelPointMap
def square(x): return x ** 2
def add_one(x): return x + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)
assert square_map(3) == 9
assert add_one_map(6) == 7
parallel_map = ParallelPointMap(square_map, add_one_map)
assert parallel_map((3, 6)) == (9, 7)
IdentityPointMap#
The IdentityPointMap class creates a point map that returns its input unchanged. It can be useful when constructing a ParallelPointMap where one of the values shouldn't be modified at all.
In the code below, I create a ParallelPointMap that transforms the triple \((x, y, z)\) into \((x^2, y + 1, z)\). The first two values are computed using the maps from the previous example, and the last value is passed through unchanged using the IdentityPointMap.
from crikit.cr.stdnumeric import RR
from crikit.cr.map_builders import Callable, ParallelPointMap, IdentityPointMap
def square(x): return x ** 2
def add_one(y): return y + 1
square_map = Callable(RR, RR, square)
add_one_map = Callable(RR, RR, add_one)
identity_map = IdentityPointMap(RR, RR)
assert square_map(3) == 9
assert add_one_map(6) == 7
assert identity_map(10) == 10
parallel_map = ParallelPointMap(square_map, add_one_map, identity_map)
assert parallel_map((3, 6, 10)) == (9, 7, 10)
Covering#
This module exists to allow a point map defined on one set of spaces to be used on a compatible set of spaces. Conversions between spaces are represented as Coverings. Each Covering implementation defines two main methods: one returns a PointMap from the covering space to the base space, and the other returns a PointMap from the base space to the covering space.
There is a registration system to keep track of which spaces can be converted. Each class that implements a Covering should call register_covering() to register which conversions that class can handle. With this registry, coverings can be looked up automatically as they are needed.
The function get_composite_cr() is given a list of PointMaps and Spaces and automatically inserts Covering maps to convert from one space to the next, or from one PointMap's target space to the next PointMap's source space.
In the code below, I use FEniCS to create a mesh with a function defined on the mesh. Then I create two versions of the p-Laplacian point map: one implemented in NumPy and the other implemented in UFL. The UFL map can be applied directly to my test input, but the NumPy map requires evaluating the function at quadrature points and putting the results into a NumPy array. The Covering interface is designed to handle this conversion automatically through the get_composite_cr() function, as long as the appropriate quadrature parameters were set with the set_default_covering_params() function.
from crikit.fe import *
from crikit.fe_adjoint import *
from pyadjoint_utils import AdjFloat
from crikit.cr.numpy import CR_P_LaplacianNumpy
from crikit.cr.ufl import CR_P_Laplacian, UFLFunctionSpace
from crikit.covering import get_composite_cr, set_default_covering_params
# Set up the finite-element function spaces.
mesh = UnitSquareMesh(3, 3)
V = FunctionSpace(mesh, 'P', 1)
V_vec = FunctionSpace(mesh, 'RTE', 1)
out_function_space = UFLFunctionSpace(V_vec)
v = interpolate(Expression('x[0]*x[0] + x[1]*x[1]', degree=2), V)
g = grad(v)
# Create a CR implemented in NumPy.
p_np = AdjFloat(2.5)
cr_np = CR_P_LaplacianNumpy(p_np, input_u=False)
# Create a CR implemented in UFL.
p_ufl = Constant(p_np)
cr_ufl = CR_P_Laplacian(p_ufl, input_u=False)
# Give the Covering interface the parameters it needs to map from UFL to NumPy.
quad_params = {'quadrature_degree': 2}
domain = mesh.ufl_domain()
set_default_covering_params(domain=domain, quad_params=quad_params)
# Create point maps that have the same input and output spaces, but different
# implementations in the middle.
cr_np_composite = get_composite_cr(cr_ufl.source, cr_np, out_function_space)
cr_ufl_composite = get_composite_cr(cr_ufl.source, cr_ufl, out_function_space)
# The point maps give the same output even though they are implemented in
# different spaces.
sigma_ufl = cr_ufl_composite(g)
sigma_np = cr_np_composite(g)
assert errornorm(sigma_ufl, sigma_np) < 1e-7
Numpy-Like Libraries#
CRIKit is written so that users can select at run time between JAX and PyTorch to perform numpy-like operations in CRIKit, such as those found in crikit.invariants and crikit.cr.cr.CR, in addition to JIT compilation and vectorization. Note that the torch backend is currently experimental and not as well optimized as the jax backend, in particular with regard to its Jacobian assembly. Typically, users control this through the backend parameter of crikit.cr.cr.CR.__init__(). For example, these two CRs compute the same quantities, one using JAX and the other using PyTorch:
import crikit
import jax.numpy as jnp
import numpy as np
import torch

dims = 3
in_types = (crikit.TensorType.make_symmetric(2, dims),
            crikit.TensorType.make_vector(dims),
            crikit.LeviCivitaType(dims))
out_type = crikit.TensorType.make_symmetric(2, dims)
# for this problem, it turns out these are both the same, equal to 7
n_scalar_invts, n_form_invts = crikit.cr_function_shape(out_type, in_types)
assert n_scalar_invts == n_form_invts
theta_jax = crikit.array(jnp.arange(0.0, n_scalar_invts))

# this isn't a realistic material model, and is only intended to
# show how to use the jax and torch backends
def cr_fun_jax(scalar_invts, theta):
    return scalar_invts * jnp.sin(theta)

theta_torch = crikit.tensor(torch.arange(0.0, n_scalar_invts))

def cr_fun_torch(scalar_invts, theta):
    return scalar_invts * torch.sin(theta)

cr_jax = crikit.CR(out_type, in_types, cr_fun_jax, params=(theta_jax,), backend='jax')
cr_torch = crikit.CR(out_type, in_types, cr_fun_torch, params=(theta_torch,), backend='torch')

n_qpoint = 100
# generate some random tensors as input
jax_inputs = tuple(crikit.array(np.random.randn(n_qpoint, *x.shape)) for x in in_types[:-1])
# the same random tensors converted to torch
torch_inputs = tuple(map(lambda x: crikit.tensor(x.unwrap()), jax_inputs))
assert jnp.allclose(cr_torch(torch_inputs).detach().numpy(), cr_jax(jax_inputs).unwrap())
Invariants#
Many real-life materials possess certain symmetries. For example, water is typically modeled as an isotropic fluid, meaning that from the point of view of an observer, any rotation of the coordinate axes leads to no change at all in the observed behavior. A wooden dowel with the grain running in the longer direction might be transversely isotropic, meaning that rotations about the axis along which the grain runs should also lead to no change in observed behavior. Mathematically, symmetries are described by groups, and all physical symmetry groups are subgroups of the group of rotations and flips in three dimensions (i.e. subgroups of \(O(3)\)). Specifically, physical symmetry groups are so-called compact point groups, that is, subgroups of the orthogonal group whose underlying sets are compact [Zheng, 1994].
Given some finite number of finite-rank tensors \(A_1, A_2, \ldots , A_n\) and some point group \(G\) with a left action on \(\{A_1, A_2, \ldots , A_n\}\), \(G\) is said to be characterized by \(\{A_1, A_2, \ldots , A_n\}\) if every orthogonal tensor \(g\) (i.e. element of some point group representation) is a member of \(G\) if and only if
\[ g \cdot A_i = A_i \quad \text{for all } i = 1, \ldots, n. \]
We call \(A_1, A_2, \ldots, A_n\) structural tensors of \(G\). A 1994 theorem of Zheng and Boehler [Zheng and Boehler, 1994] guarantees that if \(G\) is any physical symmetry group (i.e. compact point group) and \(A_1, A_2, \ldots, A_n\) characterize \(G\), then any \(G\)-equivariant function \(F: T\to V\) for tensor spaces \(T\) and \(V\) can be written as an isotropic function
\[ F(t) = F_I(t, A_1, A_2, \ldots, A_n) \]
(where \(F_I\) is said isotropic function). As such, CRIKit represents equivariant tensor functions of inputs as isotropic tensor functions of inputs and structural tensors. So, if you're modeling the behavior of a wooden dowel with grain running along the vector field \(\mathbf k\) and you want an equivariant function with respect to the group of rotations about \(\mathbf k\), just add \(\mathbf k\) to the list of your input tensors (and pass it as an input to the returned functions, or, if you're accessing crikit.invariants through crikit.cr.cr.CR, pass it as an input when you call the CR).
The Invariants module is used to store and retrieve functions that compute scalar and form-invariants for given combinations of inputs and outputs. Users typically do not need to know about the Invariants module, since crikit.cr.cr.CR uses it internally to create a CR whose output is equivariant with respect to the symmetry group.
Given a group \(G\) and tensor space \(T\), a function \(\varphi : T\to \mathbb{R}\) is an invariant (a.k.a. scalar invariant) of \(G\) if and only if for every \(t\in T\) and \(g\in G\), \(\varphi(g\cdot t) = \varphi(t)\). Given another tensor space \(V\), a function \(\Psi : T\to V\) is a form-invariant of \(G\) if and only if for every \(t\in T\) and \(g\in G\), \(\Psi(g\cdot t) = g\cdot \Psi(t)\). Another 1994 theorem of Zheng and Boehler (which we will refer to as the Zheng-Boehler theorem) that extends the well-known 1964 Wineman-Pipkin theorem guarantees that we can automatically generate complete function bases of scalar invariants \(\{\varphi_j\}_{j=1}^m\) and form-invariants \(\{\Psi_i\}_{i=1}^k\) such that any \(O(d)\)-equivariant tensor-valued function \(F : T\to V\) (in \(d\) spatial dimensions) can be represented as [Zheng and Boehler, 1994]
\[ F(t) = \sum_{i=1}^{k} f_i\big(\varphi_1(t), \ldots, \varphi_m(t)\big)\, \Psi_i(t), \]
where the \(f_i\) are scalar-valued functions.
In the function get_invariant_functions(), we generate these scalar invariants \(\{\varphi_j\}_{j=1}^m\) and form-invariants \(\{\Psi_i\}_{i=1}^k\) from the tables in [Zheng, 1994] and pack them into two functions: one that computes the scalar invariants and places them into a one-dimensional array, and another that stacks the form-invariants into a three-dimensional jax.numpy.ndarray.
In other words, in the code the user interfaces with, the \(\Psi_i\) are all computed by a single function (and returned stacked in a single array), the \(\varphi_j\) are all computed by another function (and returned in a single one-dimensional array), and the \(f_i\) are all computed by a third function (the one the user supplies) and returned in another one-dimensional array. We call the function that computes the \(f_i\) the inner function for a CR.
This representation shown above is called the Wineman-Pipkin representation of functions that are equivariant under a point group, and as is stated above, it has two very powerful conclusions. Not only is it possible to represent any function whatsoever that is equivariant under any point group in this manner, but regardless of what the inner function actually is, and regardless of what properties it may or may not have, the “outer function” \(F\) is guaranteed to be equivariant. Those familiar with deep learning might call the scalar invariants features, except instead of learning an approximation of those optimal features as is often done in the deep learning community, we can compute them exactly (to within machine precision) and efficiently.
Note that \(T\) might be a graded tensor algebra (there are no such restrictions in the Zheng-Boehler theorem). In other words, it might contain multiple independent subspaces, such as a scalar space, a vector space, and a symmetric rank-two tensor space, all direct-summed together into one space. In general, for the crikit.cr.cr.CR class, \(T\) is the direct sum of the spaces represented by the various TensorTypes that make up the input_types parameter of the crikit.cr.cr.CR constructor.
To see exactly what invariants and form-invariants CRIKit generates for any given combination of input and output TensorTypes, you can use the function get_invariant_descriptions().
TensorType#
The class TensorType is a NamedTuple that contains the information about a tensor that we need to determine the invariants: the order, the shape, and whether the tensor is symmetric, antisymmetric, or neither. If you have an example of a tensor that you want a TensorType descriptor for, the easiest way to get one is with from_array() if you know the symmetry beforehand, or with type_from_array() if you do not. However, you typically do not need to construct TensorType instances yourself if you have appropriate example tensors, because they are primarily used as members of InvariantInfo instances. The Levi-Civita tensor is represented with LeviCivitaType, a subclass of TensorType.
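For instance, here is a minimal sketch that builds a few TensorType descriptors using the same constructors that appear in the CR example above:
import crikit

dims = 3
# A symmetric rank-two tensor type and a vector type in three spatial dimensions.
sym_type = crikit.TensorType.make_symmetric(2, dims)
vec_type = crikit.TensorType.make_vector(dims)
# The Levi-Civita pseudotensor has its own TensorType subclass.
eps_type = crikit.LeviCivitaType(dims)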
InvariantInfo#
There are two ways to construct an InvariantInfo (itself another NamedTuple). The first is to construct it directly from its members, which are an integer giving the number of spatial dimensions, a tuple of TensorType instances representing the inputs to the CR (including any structural tensors), and a TensorType representing the output of the CR. The second way is to use InvariantInfo.from_arrays(), which functions much like type_from_array().
Once you have an InvariantInfo, you can call get_invariant_functions() on it and get a pair of functions in return. The first function computes the scalar invariants for your specified inputs, and the second computes the form-invariants for the specified inputs. Note that you MUST pass the input arguments to those functions in the same order as the inputs were specified in the InvariantInfo. Also note that if you pass the Levi-Civita pseudotensor to InvariantInfo.from_arrays() (or explicitly specify a LeviCivitaType in the TensorType list passed to the InvariantInfo constructor), you should NOT pass the Levi-Civita tensor to the returned functions. They will account for its presence without it being passed, and they will expect that you do not pass it as an argument.
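As a minimal sketch, the constructor call below mirrors the registration example in the next section; whether get_invariant_functions() is invoked as a method (as shown here) or as a module-level function taking the InvariantInfo is an assumption, so check the API documentation for the exact call pattern.
import crikit

dims = 3
# Inputs: a symmetric rank-two tensor, a vector, and the Levi-Civita pseudotensor.
in_types = (crikit.TensorType.make_symmetric(2, dims),
            crikit.TensorType.make_vector(dims),
            crikit.LeviCivitaType(dims))
out_type = crikit.TensorType.make_symmetric(2, dims)

# Direct construction from the members: spatial dimension, input types, output type.
info = crikit.InvariantInfo(dims, in_types, out_type)

# Assumed call pattern: the first returned function computes the scalar
# invariants, the second the form-invariants. Pass the inputs in the same
# order as in_types, and do NOT pass the Levi-Civita tensor.
scalar_fn, form_fn = info.get_invariant_functions()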
Registering Your Own Invariants#
While we have implemented scalar and form-invariant functions for the most common subsets of input/output TensorTypes, we cannot currently algorithmically generate invariants for subsets other than those we have manually implemented. Additionally, certain bases for the scalar and form-invariants sometimes have better inference properties than others at the cost of computational efficiency. For example, if we have a single symmetric rank-two input A and likewise a single symmetric rank-two output in two dimensions, the scalar invariants are trace(A) and trace(A @ A), and the form-invariants are I_2 and A. However, this basis can suffer from collinearity when A is small; for a slightly higher computational cost, this basis can be replaced with the eigenvalues of A, and the corresponding form-invariants can be replaced with the outer products of the eigenvectors of A with themselves. You can replace the default implementation of the invariants for any given set of input/output types (or provide an implementation for a subset that we have not implemented invariants for yet) using the register_invariant_functions() function in crikit.invariants.
All invariant functions have the signature f(*args, backend), where *args are the input tensors and backend is a numpy backend object that implements the numpy API in a generic manner (so that we can switch between jax and torch at will). For example, the eigenvector/eigenvalue invariants discussed above would be implemented as:
import crikit

def scalar_invts(A, backend):
    return backend.linalg.eigvalsh(A)

def form_invts(A, backend):
    return backend.linalg.eigh(A)[1]

info = crikit.InvariantInfo(2, (crikit.TensorType.make_symmetric(2, 2),),
                            crikit.TensorType.make_symmetric(2, 2))
crikit.register_invariant_functions(info, scalar_invts, form_invts, overwrite_existing=True)
If you only intend to use one numpy backend (e.g. only jax or only torch), you can ignore the backend parameter to the invariant functions, but you should use it if you intend to use the registered invariants with different numpy backends.
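For example, here is a hedged sketch of calling the scalar_invts and form_invts functions defined above directly, using jax.numpy as the generic backend object (inside CRIKit the backend is supplied for you; this direct call is only for illustration):
import jax.numpy as jnp
import numpy as np

# Build a random symmetric 2x2 matrix to exercise the invariant functions above.
A = np.random.randn(2, 2)
A = jnp.asarray(0.5 * (A + A.T))

# jax.numpy provides linalg.eigvalsh and linalg.eigh, so it can play the role
# of the generic numpy-like backend here.
print(scalar_invts(A, backend=jnp))  # eigenvalues of A
print(form_invts(A, backend=jnp))    # eigenvectors of A (as columns)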
Putting It Together: Invariant CRs#
The class crikit.cr.cr.CR represents a constitutive relation that is invariant under some compact point group. The scalar and form-invariants are generated automatically by crikit.invariants.get_invariant_functions(), so the only things you need to supply are a sequence (e.g. a list or tuple) of crikit.invariants.TensorTypes representing the inputs to your CR (including any structural tensors), a crikit.invariants.TensorType representing the output of your CR, a JAX-compatible (i.e. traceable by JAX) inner function, and any parameters the function takes. To figure out what shape your function needs to be for a given set of inputs and outputs, use crikit.cr.cr.cr_function_shape(), which returns a tuple containing the size of the (one-dimensional) input array of scalar invariants and the size of the (one-dimensional) array of outputs of your function that the CR expects.
Note that typically you will have a so-called pointwise CR, meaning that your CR function is written as a function of the inputs at any given point in a domain. This is how CRs are written down when we're doing math, but when we're running a simulation, we typically want to evaluate the CR at many points (all the quadrature points in a domain or in a batch of elements). In order to achieve high performance and avoid slow Python loops, we use the vmap() and jit() functions from the chosen numpy backend both to vectorize the application of a CR over a whole mesh and to JIT-compile the CR function to native code. Only the jax and torch backends actually perform this compilation and vectorization; the plain numpy backend attempts to emulate the vectorization in pure Python. Use of these features is controlled through the vmap and nojit parameters of the crikit.cr.cr.CR class's __init__() method, which default to True and False respectively. You can use them to indicate that you want a CR that is evaluated one point at a time (vmap=False), or one that cannot, for whatever reason (e.g. it uses Python-value-dependent control flow), be JIT-compiled (nojit=True). Note that the behavior of the JIT compiler can be controlled further with other arguments to the crikit.cr.cr.CR class; see its API documentation for details.
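For example, here is a hedged sketch of requesting a pointwise CR with JIT compilation disabled, reusing out_type, in_types, cr_fun_jax, and theta_jax from the backend example earlier in this document:
# vmap=False asks for one-point-at-a-time evaluation; nojit=True skips
# JIT compilation, e.g. for inner functions with Python-value-dependent
# control flow.
cr_pointwise = crikit.CR(out_type, in_types, cr_fun_jax,
                         params=(theta_jax,), backend='jax',
                         vmap=False, nojit=True)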
Zhe94
Q.S. Zheng. Theory of representations for tensor functions—a unified invariant approach to constitutive equations. Appl. Mech. Rev., 47:545–587, 1994.
ZB94
Q.S. Zheng and J.P. Boehler. The description, classification, and reality of material and physical symmetries. Acta Mechanica, 102:73–89, 1994.