import jax
from jax import jit
import sys
import importlib
from functools import partial
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple, List, Tuple, Optional, Union, Iterable, TypeVar
from crikit.logging import logger
from .utils import set_backend
from .utils import (
get_backend,
symm,
antisymm,
commutator_action,
anticommutator_action,
scalar_triple_prod,
powerset,
levi_civita,
_eps_vec_action,
axial_vector,
_tprod,
symm_q4,
symm_q3,
tA,
tv,
Tv,
TinnerS,
TcontrS,
TW,
tbrace,
_3d_rotation_matrix,
near,
is_symm,
is_antisymm,
)
from pyadjoint_utils.jax_adjoint import ndarray as jax_ndarray
try:
from pyadjoint_utils.torch_adjoint import Tensor as torch_tensor_t
except ImportError:
class torch_tensor_t:
pass
from pyadjoint_utils.numpy_backend import (
get_default_backend,
get_backend,
set_default_backend,
)
import inspect
from operator import itemgetter
import itertools
# Generic stand-in for "an array of the active numpy-like backend" in type hints.
# NOTE(review): not referenced in the visible part of this module; presumably used
# in annotations elsewhere -- confirm before removing.
ndarray = TypeVar("ndarray")
class TensorType(NamedTuple):
    """Describes the order, shape and (anti)symmetry of a tensor.

    Equality is ordinary tuple equality over all five fields (including
    ``name``), while :meth:`__hash__` leaves ``name`` out, so equal instances
    still hash equally.
    """

    order: int = 0
    shape: tuple = ()
    symmetric: bool = False
    antisymmetric: bool = False
    name: str = ""

    @staticmethod
    def make_scalar(name: str = ""):
        """Returns a TensorType representing a scalar

        :param name: The name of the scalar, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a scalar
        :rtype: TensorType
        """
        return TensorType(0, (), False, False, name)

    @staticmethod
    def make_vector(spatial_dims: int, name: str = ""):
        """Returns a TensorType representing a vector

        :param spatial_dims: The number of spatial dimensions
        :type spatial_dims: int
        :param name: The name of the vector, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a vector
        :rtype: TensorType
        """
        return TensorType(1, (spatial_dims,), False, False, name)

    @staticmethod
    def make_symmetric(order: int, spatial_dims: int, name: str = ""):
        """Returns a TensorType representing a symmetric tensor

        :param order: The order of the tensor
        :type order: int
        :param spatial_dims: How many spatial dimensions?
        :type spatial_dims: int
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a symmetric order-``order`` tensor in
            ``spatial_dims`` spatial dimensions.
        :rtype: TensorType
        """
        return TensorType(order, order * (spatial_dims,), True, False, name)

    @staticmethod
    def make_antisymmetric(order: int, spatial_dims: int, name: str = ""):
        """Returns a :class:`TensorType` representing an antisymmetric tensor

        :param order: The order of the tensor
        :type order: int
        :param spatial_dims: How many spatial dimensions?
        :type spatial_dims: int
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :return: A TensorType representing an antisymmetric order-``order`` tensor in
            ``spatial_dims`` spatial dimensions.
        :rtype: TensorType
        """
        return TensorType(order, order * (spatial_dims,), False, True, name)

    @staticmethod
    def from_array(X, symmetric=False, antisymmetric=False, name: str = ""):
        """Creates a :class:`TensorType` representing a particular array

        :param X: The array
        :type X: Union[jnp.ndarray,np.ndarray,pyadjoint_utils.jax_adjoint.ndarray]
        :param symmetric: Is the array symmetric? defaults to False
        :type symmetric: bool, optional
        :param antisymmetric: Is the array antisymmetric? defaults to False
        :type antisymmetric: bool, optional
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :returns: A :class:`TensorType` representing ``X``
        :rtype: TensorType
        :raises RuntimeError: if both ``symmetric`` and ``antisymmetric`` are True
        """
        shp = X.shape
        if symmetric and antisymmetric:
            raise RuntimeError(
                "Cannot create a TensorType from an array that is both symmetric and antisymmetric! (That's a contradiction; pass True to only ONE of symmetric or antisymmetric in TensorType.from_array)"
            )
        return TensorType(len(shp), shp, symmetric, antisymmetric, name)

    def get_symmetrizer(self):
        """Returns a function that takes in a tensor of order ``self.order`` and
        makes it symmetric.

        Orders 0 and 1 get the identity function; orders 2, 3 and 4 get
        ``symm``, ``symm_q3`` and ``symm_q4`` respectively.

        :return: The symmetrizer for a tensor of order ``self.order``
        :rtype: function
        :raises NotImplementedError: if ``self.order`` is greater than 4
        """
        if self.order <= 1:
            return lambda x: x
        if self.order == 2:
            return symm
        if self.order == 3:
            return symm_q3
        if self.order == 4:
            return symm_q4
        raise NotImplementedError(
            f"No symmetrizer is available for tensors of order {self.order}"
        )

    def zeros_like(self) -> jnp.ndarray:
        """Returns an array of zeros of the shape ``self.shape``.

        :returns: jnp.zeros(self.shape)
        :rtype: jnp.ndarray
        """
        return jnp.zeros(shape=self.shape)

    def tensor_space_dimension(self) -> int:
        """Returns the dimension (as a vector space) of the tensor space containing
        tensors of this shape.

        Scalars (empty shape) span a 1-dimensional space. The symmetry flags
        are only taken into account for order >= 2 (they are vacuous for
        orders 0 and 1), and are currently only implemented for order 2.

        :returns: dimension of the tensor space containing this TensorType
        :rtype: int
        :raises NotImplementedError: for symmetric or antisymmetric tensors of
            order greater than 2
        """
        if len(self.shape) == 0:
            # max(()) would raise; a scalar spans a 1-dimensional space.
            return 1
        N = max(self.shape)
        q = self.order
        if q >= 2 and self.symmetric:
            if q == 2:
                return (N * (N + 1)) // 2
            raise NotImplementedError
        if q >= 2 and self.antisymmetric:
            if q == 2:
                return (N * (N - 1)) // 2
            raise NotImplementedError
        # int() so the declared return type holds (np.prod yields a numpy scalar).
        return int(np.prod(self.shape))

    def __hash__(self):
        # ``name`` is not part of the hash.
        return hash(tuple((self.order, self.shape, self.symmetric, self.antisymmetric)))

    def __ge__(self, other):
        # Any TensorType compared with a plain tuple or int counts as ">=".
        if isinstance(other, TensorType):
            return self._selfge(other)
        elif isinstance(other, tuple):
            return True
        elif isinstance(other, int):
            return True
        else:
            return False

    def __lt__(self, other):
        # sorted() relies on __lt__; it is defined as the negation of __ge__.
        # NOTE(review): this is not a total order (e.g. two equal general
        # tensors are each "<" the other), so sort results can depend on the
        # input order -- confirm this is acceptable before relying on it.
        return not self.__ge__(other)

    def _selfge(self, other):
        # Higher order wins; at equal order symmetric >= antisymmetric, and a
        # general (neither flag) tensor is never ">=".
        if self.order > other.order:
            return True
        elif self.order == other.order:
            if self.symmetric:
                if other.antisymmetric:
                    return True
                elif other.symmetric:
                    return True
                else:
                    return False
            elif self.antisymmetric:
                if other.symmetric:
                    return False
                elif other.antisymmetric:
                    return True
                else:
                    return False
            else:
                return False
        else:
            return False

    def get_array_like(self, backend=None) -> jnp.ndarray:
        """
        Constructs an example array with the right shape and symmetry

        :param backend: numpy-like backend, defaults to the CRIKit default backend
        :return: a tensor with the right shape and symmetry
        :rtype: jnp.ndarray
        """
        if backend is None:
            backend = get_default_backend()
        if self.symmetric and self.order == 2:
            return backend.eye(max(self.shape))
        elif self.order == 1:
            return backend.zeros(self.shape)
        arr = backend.zeros(self.shape, dtype=jnp.float32)
        # a single nonzero entry in the last corner, then (anti)symmetrized below
        upper_right_corner = tuple(max(0, s - 1) for s in self.shape)
        arr = backend.index_update(arr, upper_right_corner, 1.0)
        if self.symmetric:
            return symm(arr, backend=backend)
        elif self.antisymmetric:
            return antisymm(arr, backend=backend)
        else:
            return arr
class LeviCivitaType(TensorType):
    """A :class:`TensorType` for the Levi-Civita (permutation) symbol.

    Only ``order`` 2 and 3 are supported. The shape is read from
    ``levi_civita(order, get_default_backend())``, the type is flagged
    antisymmetric (not symmetric), and the name is left empty.
    """

    def __new__(cls, order):
        if order == 2:
            canonical_order = 2
        elif order == 3:
            canonical_order = 3
        else:
            raise NotImplementedError
        eps_shape = levi_civita(canonical_order, get_default_backend()).shape
        return super().__new__(cls, canonical_order, eps_shape, False, True)
def type_from_array(
    X, rtol: float = 1.0e-5, name: str = "", backend=None
) -> TensorType:
    """
    Like :meth:`TensorType.from_array`, but tries to detect whether the array is
    symmetric or antisymmetric, and whether it is the Levi-Civita symbol.

    Arrays with a single element are treated as scalars. Rank-2 arrays are
    tested with ``is_symm``/``is_antisymm``. Rank-3 arrays are only tested for
    symmetry (against ``symm_q3``) and are never flagged antisymmetric. An array
    matching the 2-D or 3-D Levi-Civita symbol yields a :class:`LeviCivitaType`,
    which does not carry ``name``.

    :param X: An array (JAX or numpy); a ``pyadjoint_utils.jax_adjoint.ndarray``
        is unwrapped first
    :type X: Union[jnp.ndarray, np.ndarray]
    :param rtol: The relative tolerance to use when determining if X is symmetric,
        antisymmetric, or both, defaults to 1.0e-5
    :type rtol: float
    :param name: The name of the tensor, defaults to ''
    :type name: str, optional
    :param backend: numpy-like backend used for the checks, defaults to the
        CRIKit default backend
    :return: An appropriate TensorType instance
    :rtype: TensorType
    :raises NotImplementedError: for arrays of rank greater than 3
    """
    backend = get_default_backend() if backend is None else backend
    if isinstance(X, jax_ndarray):
        X = X.unwrap(True)
    if backend.isscalar(X) or X.size == 1:
        return TensorType.make_scalar(name)

    shape = X.shape
    rank = len(shape)
    if rank == 1:
        return TensorType(rank, shape, False, False, name)
    if rank == 2:
        eps = levi_civita(2, backend)
        if shape == eps.shape and near(X, levi_civita(2, backend), rtol):
            return LeviCivitaType(rank)
        return TensorType(rank, shape, is_symm(X, rtol), is_antisymm(X, rtol), name)
    if rank == 3:
        # NOTE: total antisymmetry of a rank-3 tensor only matters here for
        # recognising the Levi-Civita symbol, so it is not tested otherwise.
        eps = levi_civita(3, backend)
        if shape == eps.shape and near(X, levi_civita(3, backend), rtol):
            return LeviCivitaType(rank)
        return TensorType(rank, shape, near(X, symm_q3(X), rtol), False, name)
    raise NotImplementedError
def evaluate_invariant_function(
    phi_i, scalar_invariant_func, form_invariant_func, *input_tensors
):
    """Evaluate ``sum_k phi_i[k](scalar_invariants) * form_invariants[k]``.

    :param phi_i: sequence of callables, one per form invariant; each receives
        the full vector of scalar invariants and returns the coefficient of the
        matching form invariant
    :param scalar_invariant_func: callable mapping ``*input_tensors`` to the
        scalar invariants (e.g. the first function returned by
        :func:`get_invariant_functions`)
    :param form_invariant_func: callable mapping ``*input_tensors`` to the
        stacked form invariants (e.g. the second function returned by
        :func:`get_invariant_functions`)
    :param input_tensors: the input tensors
    :return: the linear combination of the form invariants
    :raises ValueError: if ``len(phi_i)`` differs from the number of form invariants
    """
    scalar_invariants = scalar_invariant_func(*input_tensors)
    form_invariants = form_invariant_func(*input_tensors)
    # Each coefficient multiplies one form invariant, so the counts must match
    # the leading dimension of the stacked form invariants.
    n_form = len(form_invariants)
    if len(phi_i) != n_form:
        raise ValueError(f"Must pass {n_form} scalar functions, not {len(phi_i)}")
    return _evaluate_invariant(form_invariants, scalar_invariants, phi_i)
def _evaluate_invariant(form_invariants, scalar_invariants, phi_i):
global numpy_backend
output_vec = numpy_backend.array([phi(scalar_invariants) for phi in phi_i])
return output_vec.T @ form_invariants
class InvariantInfo(NamedTuple):
    """A class that contains relevant information for computing invariants.

    For example, for a hemitropic CR in 3 spatial dimensions taking a symmetric
    and an antisymmetric tensor as inputs and outputs a symmetric second order
    tensor:

    ::

        info = InvariantInfo(3,
                             (TensorType.make_symmetric(2,3),
                              TensorType.make_antisymmetric(2,3),
                              LeviCivitaType(3)
                             ),
                             TensorType.make_symmetric(2,3)
                            )
    """

    # number of spatial dimensions
    spatial_dims: int
    # input tensor types; a LeviCivitaType entry selects the SO(n) group
    input_types: Tuple[TensorType, ...]
    # output tensor type
    output_type: TensorType

    def get_group_symbol(self, sanitize_input_types: bool = False):
        """
        Returns a symbol representing the group this instance represents.

        The group is special orthogonal (``SO``) when any input is a
        :class:`LeviCivitaType` and orthogonal (``O``) otherwise. For more
        than 3 spatial dimensions the symbol is ``None``.

        :param sanitize_input_types: if True, also return the input types with
            the first Levi-Civita entry (if any) removed, default False
        :type sanitize_input_types: bool, optional
        :return: the group symbol (one of ``'O(2)'``, ``'SO(2)'``, ``'O(3)'``,
            ``'SO(3)'``), or ``(symbol, input_types)`` when
            ``sanitize_input_types`` is True
        :rtype: str or tuple
        :raises ValueError: if ``spatial_dims <= 1``
        """
        if self.spatial_dims <= 1:
            raise ValueError(
                f"Cannot get an orthogonal group symbol for {self.spatial_dims} spatial dimensions!"
            )
        # index of the first Levi-Civita input, or None if there is none
        eps_id = next(
            (
                i
                for i, tp in enumerate(self.input_types)
                if isinstance(tp, LeviCivitaType)
            ),
            None,
        )
        contains_eps = eps_id is not None
        grp = None
        if self.spatial_dims == 2:
            grp = "SO(2)" if contains_eps else "O(2)"
        elif self.spatial_dims == 3:
            grp = "SO(3)" if contains_eps else "O(3)"
        if not sanitize_input_types:
            return grp
        # NOTE: only the first LeviCivitaType is removed. TODO: detect multiple instances.
        input_types = tuple(self.input_types)
        if contains_eps:
            input_types = input_types[:eps_id] + input_types[eps_id + 1 :]
        return grp, input_types

    @staticmethod
    def from_arrays(output_example, *args, **kwargs):
        """Constructs an :class:`InvariantInfo` from arrays representing the output and inputs

        :param output_example: an array of the correct shape and symmetry/antisymmetry of the desired output
        :type output_example: Union[jnp.ndarray]
        :param args: an example of each of the input tensors
        :type args: Iterable[Union[jnp.ndarray]]
        :param rtol: the relative tolerance for detecting symmetry/antisymmetry and the Levi-Civita symbol, defaults to 1.0e-5
        :type rtol: float
        :return: an InvariantInfo whose spatial dims are inferred from the first
            non-scalar input (or from the output when every input is a scalar)
            and whose input_types (a tuple) and output_type match the examples
        :rtype: InvariantInfo
        :raises ValueError: if the non-scalar inputs disagree on the number of
            spatial dimensions, or if every input and the output are scalars
        """
        rtol = kwargs.get("rtol", 1.0e-5)
        types = tuple(type_from_array(x, rtol=rtol) for x in args)
        # scalars have an empty shape and say nothing about the spatial dimension
        shaped = [t for t in types if len(t.shape) > 0]
        spatial_dims = max(shaped[0].shape) if shaped else None
        for t in shaped:
            if max(t.shape) != spatial_dims:
                raise ValueError(
                    f"Input type {t} has {max(t.shape)} spatial dimensions, but the first non-scalar input type has {spatial_dims} spatial dimensions"
                )
        output_type = type_from_array(output_example, rtol=rtol)
        if spatial_dims is None:
            if len(output_type.shape) == 0:
                raise ValueError(
                    "Cannot infer the number of spatial dimensions: every input and the output is a scalar"
                )
            spatial_dims = max(output_type.shape)
        return InvariantInfo(spatial_dims, types, output_type)
# needed for static_argnums parameter of jax.jit
class HashableDict(dict):
    """A read-only ``dict`` that can be hashed (e.g. for ``jax.jit`` static args).

    The hash is built from a frozenset of the items, so keys need not be
    mutually orderable; values must be hashable. Every in-place mutator raises
    TypeError so the hash cannot go stale after construction.
    """

    def _key(self):
        return frozenset(self.items())

    def __hash__(self):
        return hash(self._key())

    def __setitem__(self, key, value):
        raise TypeError(f"{type(self)} does not support assignment!")

    def _reject_mutation(self, *args, **kwargs):
        raise TypeError(f"{type(self)} does not support mutation!")

    __delitem__ = _reject_mutation
    __ior__ = _reject_mutation
    update = _reject_mutation
    pop = _reject_mutation
    popitem = _reject_mutation
    setdefault = _reject_mutation
    clear = _reject_mutation
def get_invariant_functions(
    info: InvariantInfo,
    suppress_warning_print: Optional[bool] = False,
    fail_on_warning: Optional[bool] = False,
    backend: Optional[str] = None,
):
    """
    This function builds two functions, one to compute the scalar invariants,
    and one to compute the form invariants.

    :param info: an InvariantInfo instance
    :type info: InvariantInfo
    :param suppress_warning_print: if True, don't print out warnings
        (this typically would be used if you get a warning about scalar or
        form-invariants not being available for a specific subset of the input types,
        and you know that this isn't a problem, e.g. because no such invariants
        exist for that subset), defaults to False
    :type suppress_warning_print: bool, optional
    :param fail_on_warning: if True, a missing set of scalar invariants re-raises
        the underlying ValueError instead of only warning. Useful if you know that
        you should not get a warning for your inputs, and want to make sure that
        nothing changes in a way that breaks that assumption, defaults to False
    :type fail_on_warning: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch',
        or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :return: a tuple of two functions, the first of which generates the
        input scalar invariants (and places them into a backend numpy array),
        and the second of which generates the output form-invariant basis.
        Both take the input tensors positionally, in the order of
        ``info.input_types`` (excluding any Levi-Civita entry).
    :rtype: tuple
    :raises TypeError: if ``info`` is not an :class:`InvariantInfo`
    """
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)
    if not isinstance(info, InvariantInfo):
        raise TypeError(
            f"First parameter of get_invariant_functions ({info}) is of type {type(info)}, not InvariantInfo!"
        )

    def _warn_missing(kind, sub_types):
        # Shared text for the "no scalar invariants" / "no form-invariants" warnings.
        if suppress_warning_print:
            return
        logger.warning(
            f"""we don't have functions to compute {kind} for all subsets of your input basis!
Specifically, we do not have invariants for the input TensorType subset
{_untagged_tuple(sub_types)}.
Proceed with caution, as your basis may be incomplete! This may not be a problem--invariants don't
always exist for any arbitrary subset of every set of input types--but if you expect invariants to
exist for this input subset (and you have checked that that assumption is justified), we may be missing those functions.
"""
        )

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    output_type = info.output_type
    # a symmetric rank-2 output gets the identity prepended as an extra form invariant
    include_identity = output_type.order == 2 and output_type.symmetric

    # sort inputs (descending), remembering each one's original argument position
    input_sort_idxs, input_types = _get_sorted_and_indices(input_types, reverse=True)
    # tag runs of equal types so every entry is a distinct dictionary key
    input_types = _tag_value_chains(tuple(input_types))
    input_sort_map = {itype: i for (itype, i) in zip(input_types, input_sort_idxs)}

    scalar_invts = []
    form_invts = []
    output_type = _nameless(output_type)
    for sub_types in powerset(input_types, exclude_empty_set=True):
        new_sub_types = tuple(sorted(_untagged_tuple(sub_types), reverse=True))
        try:
            si_func, fi_func = _get_invariant_functions(
                dims, group, _untagged_tuple(new_sub_types, nameless=True), output_type
            )
        except ValueError:
            _warn_missing("scalar invariants", sub_types)
            if fail_on_warning:
                raise
            continue
        scalar_invts.append((si_func, sub_types))
        if fi_func:
            form_invts.append((fi_func, sub_types))
        else:
            _warn_missing("form-invariants", sub_types)

    def unified_scalar_invariant_func(imap, sinvts, backend, *args):
        # Evaluate each sub-function on the arguments it needs and concatenate
        # the (at least 1-D) results into one flat vector.
        invts = []
        for si_func, sub_types in sinvts:
            idx = [imap[sub_t] for sub_t in sub_types]
            inputs = [args[j] for j in idx]
            invts.append(_1darr(si_func(*inputs, backend), backend))
        return backend.concatenate(invts)

    scalar_invts = tuple(scalar_invts)
    input_sort_map = HashableDict(input_sort_map)

    def scalar_invariant_func(*args):
        if len(args) != len(input_types):
            raise ValueError(
                f"Wrong number of arguments passed to scalar_invariant_func! Expected {len(input_types)}, but got {len(args)}!"
            )
        return unified_scalar_invariant_func(
            input_sort_map, scalar_invts, backend, *args
        )

    def unified_form_invariant_func(imap, frminvts, backend, *args):
        # NOTE: building the index list and then the input list works with
        # jax.jit, whereas passing a generator such as
        #     fi_func(args[imap[sub_t]] for sub_t in sub_types)
        # does not (JAX raises). The result is a plain Python list; stacking
        # happens in form_invariant_func so the identity can be prepended.
        finvts = []
        for fi_func, sub_types in frminvts:
            idx = [imap[sub_t] for sub_t in sub_types]
            inputs = [args[j] for j in idx]
            finvts += fi_func(*inputs, backend)
        return finvts

    form_invts = tuple(form_invts)

    def form_invariant_func(*args):
        if len(args) != len(input_types):
            raise ValueError(
                f"Wrong number of arguments passed to form_invariant_func! Expected {len(input_types)}, but got {len(args)}!"
            )
        extra_form_invts = []
        if include_identity:
            extra_form_invts.append(backend.eye(dims))
        # NOTE: adding the 2-D Levi-Civita symbol for antisymmetric rank-2
        # outputs was sketched previously but is not enabled.
        return backend.stack(
            extra_form_invts
            + unified_form_invariant_func(input_sort_map, form_invts, backend, *args)
        )

    return scalar_invariant_func, form_invariant_func
def get_invariant_descriptions(
    info: InvariantInfo,
    suppress_warning_print: Optional[bool] = False,
    fail_on_warning: Optional[bool] = False,
    html: Optional[bool] = None,
    ipython: Optional[bool] = None,
    backend: Optional[str] = None,
):
    """
    This function builds a string description of the scalar and form invariants that you would get from
    :func:`get_invariant_functions` with the same arguments you pass in here.

    :param info: an InvariantInfo instance
    :type info: InvariantInfo
    :param suppress_warning_print: if True, don't print out warnings
        (this typically would be used if you get a warning about scalar or
        form-invariants not being available for a specific subset of the input types,
        and you know that this isn't a problem, e.g. because no such invariants
        exist for that subset), defaults to False
    :type suppress_warning_print: bool, optional
    :param fail_on_warning: if True, a missing set of scalar invariants re-raises
        the underlying ValueError instead of only warning, defaults to False
    :type fail_on_warning: bool, optional
    :param html: Return HTML instead of a plain string description? Useful for use inside Jupyter notebooks.
        Defaults to False (or to ``ipython`` when running under IPython)
    :type html: bool, optional
    :param ipython: Is this being used in ipython mode? (e.g. in a Jupyter notebook) By default,
        tries to guess whether or not you are. If the default behavior is undesirable, set this parameter manually.
    :type ipython: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch',
        or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :return: a string describing the invariants (a complete HTML document when
        ``html`` is truthy). When both ``html`` and ``ipython`` are truthy the
        HTML is passed to ``IPython.display.display`` and that call's result is
        returned instead.
    :rtype: str
    :raises TypeError: if ``info`` is not an :class:`InvariantInfo`
    """
    if not isinstance(info, InvariantInfo):
        raise TypeError(
            f"First parameter of get_invariant_descriptions ({info}) is of type {type(info)}, not InvariantInfo!"
        )
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)

    class function_state:
        # per-category counters updated while display symbols are assigned
        n_rank_0 = 0
        n_rank_1 = 0
        n_rank_2_s = 0
        n_rank_2_a = 0
        n_rank_3 = 0
        # the *_n counters count how many inputs of each category already have a name
        n_rank_0_n = 0
        n_rank_1_n = 0
        n_rank_2_s_n = 0
        n_rank_2_a_n = 0
        n_rank_3_n = 0

    if _executing_in_ipython():
        ipython = True if ipython is None else ipython
        # if we're in IPython mode, we definitely need HTML unless the parameter says otherwise
        html = ipython if html is None else html

    def _warn_missing(kind, sub_types):
        # Shared text for the "no scalar invariants" / "no form-invariants" warnings.
        if suppress_warning_print:
            return
        logger.warning(
            f"""we don't have functions to compute {kind} for all subsets of your input basis!
Specifically, we do not have invariants for the input TensorType subset
{_untagged_tuple(sub_types)}.
Proceed with caution, as your basis may be incomplete! This may not be a problem--invariants don't
always exist for any arbitrary subset of every set of input types--but if you expect invariants to
exist for this input subset (and you have checked that that assumption is justified), we may be missing those functions.
"""
        )

    def _range_prefix(start, stride):
        # "k : " for a single value, "(k:m) : " for a range of values
        prefix = (
            f"{start} : " if stride == 1 else f"({start}:{start + stride - 1}) : "
        )
        return "<code>" + prefix + "</code>" if html else prefix

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    output_type = info.output_type
    # a symmetric rank-2 output gets the identity as form invariant 0
    include_identity = output_type.order == 2 and output_type.symmetric
    input_state = function_state()
    header = (
        _get_header_html(dims, input_types, input_state)
        if html
        else _get_header_plaintext(dims, input_types, input_state)
    )

    # sort inputs (descending) and tag runs of equal types
    _, input_types = _get_sorted_and_indices(input_types, reverse=True)
    input_types = _tag_value_chains(tuple(input_types))
    scalar_invts = []
    form_invts = []
    output_type = _nameless(output_type)
    for sub_types in powerset(input_types, exclude_empty_set=True):
        new_sub_types = tuple(sorted(sub_types, reverse=True))
        try:
            si_func, fi_func = _get_invariant_functions(
                dims, group, _untagged_tuple(new_sub_types, nameless=True), output_type
            )
        except ValueError:
            _warn_missing("scalar invariants", sub_types)
            if fail_on_warning:
                raise
            continue
        scalar_invts.append((si_func, new_sub_types))
        if fi_func:
            form_invts.append((fi_func, new_sub_types))
        else:
            _warn_missing("form-invariants", new_sub_types)

    # named inputs consume their category's default symbols first
    scalar_symbols = ["x", "y", "z"][input_state.n_rank_0_n :]
    vector_symbols = ["v", "u"][input_state.n_rank_1_n :]
    symm_symbols = ["A", "B", "C"][input_state.n_rank_2_s_n :]
    antisymm_symbols = ["W", "V"][input_state.n_rank_2_a_n :]
    r3_symbols = ["T", "S"][input_state.n_rank_3_n :]
    # if you change this, also change _get_symbol_id()
    symbol_map = [
        scalar_symbols,
        vector_symbols,
        symm_symbols,
        antisymm_symbols,
        r3_symbols,
    ]

    # build scalar invariant description
    scalar_invt_descrs = []
    range_start = 0
    for si_func, sub_types in scalar_invts:
        stride = _get_num_retvals(si_func, sub_types, backend)
        scalar_invt_descrs.append(
            _range_prefix(range_start, stride)
            + _format_invt_function(si_func, sub_types, symbol_map, html)
        )
        range_start += stride
    si_descr = _apply_final_formatting("\n\n".join(scalar_invt_descrs), html)

    # build form invariant description
    form_invt_descrs = []
    range_start = 0
    if include_identity:
        form_invt_descrs.append(
            f"<code>0 : [] -> I_{dims}</code>" if html else f"0 : [] -> I_{dims}"
        )
        range_start = 1
    for fi_func, sub_types in form_invts:
        stride = _get_num_retvals_form_invt(fi_func, sub_types, backend)
        form_invt_descrs.append(
            _range_prefix(range_start, stride)
            + _format_invt_function(fi_func, sub_types, symbol_map, html)
        )
        range_start += stride
    fi_line = (
        '<br><br><hr style="border: 1px dashed black"><br>'
        if html
        else "\n\n------------------------------------------------\n"
    )
    fi_line += "Form Invariants:\n\n"
    fi_descr = _apply_final_formatting(fi_line + "\n\n".join(form_invt_descrs), html)

    description = header + si_descr + fi_descr
    if not html:
        return description
    # close the <body> and <html> tags opened by the HTML header
    description += "</body></html>"
    if ipython:
        from IPython.display import HTML, display

        return display(HTML(description))
    return description
def _executing_in_ipython():
try:
sh = get_ipython().__class__.__name__
return True
except NameError:
# get_ipython() doesn't exist, so we proabably aren't in an IPython environment
return False
def _apply_final_formatting(string, html):
if html:
return string.replace("\n", "<br>")
return string
def _get_header_html(dims, input_types, input_state):
    """Build the HTML preamble of an invariant description.

    The result opens ``<html>`` and ``<body>`` without closing them. It
    contains a legend of the symbol conventions and lists the input tensors
    using the symbols assigned by ``_get_type_symbols``, which also updates the
    counters on ``input_state``.

    :param dims: number of spatial dimensions (used for the identity symbol ``I_dims``)
    :param input_types: the input TensorTypes
    :param input_state: counter object passed through to ``_get_type_symbols``
    :return: the HTML header string
    """
    style = """
<html>
<head>
<style>
table, th, td {
border: 1px solid black;
}
ul {
margin: 0;
}
ul.dashed {
list-style-type: none;
}
ul.dashed > li {
text-indent: -5px;
}
ul.dashed > li:before {
content: "-";
text-indent: -5px;
}
</style>
</head>
<body>
<p>Legend</p>
"""
    string = f"""
<table style="width:100%">
<tr>
<th>Symbol(s)</th>
<th>Tensor Rank</th>
<th>Symmetric</th>
<th>Antisymmetric</th>
</tr>
<tr>
<td>x, y, z</td>
<td>0</td>
<td>N/A</td>
<td>N/A</td>
</tr>
<tr>
<td>v, u</td>
<td>1</td>
<td>N/A</td>
<td>N/A</td>
</tr>
<tr>
<td>A, B, C</td>
<td>2</td>
<td>Yes</td>
<td>No</td>
</tr>
<tr>
<td>W, V</td>
<td>2</td>
<td>No</td>
<td>Yes</td>
</tr>
<tr>
<td>T, S</td>
<td>3</td>
<td>Any</td>
<td>Any</td>
</tr>
</table>
<br><br>
Special symbols:<br>
<ul class="dashed">
<li> <code>I_{dims}</code> (rank-two identity tensor, A.K.A. identity matrix)</li>
</ul>
<br>
Operations:<br>
<ul class="dashed">
<li> <code>_tprod(x, y) = numpy_backend.tensordot(x, y, axes=0) (tensor product)</code></li> <br>
<li> <code>symm(x) = x + x.T</code></li><br>
<li> <code>antisymm(x) = x - x.T</code></li>
</ul>
<br>
<hr style="border: 1px dashed black">
<br>
Input tensors: <br>
<code>{_get_type_symbols(input_types, input_state)}</code>
<br>
<hr style="border: 1px dashed black">
<br>
Scalar invariants:<br><br>
"""
    return style + string
def _get_header_plaintext(dims, input_types, input_state):
    """Build the plain-text preamble of an invariant description.

    It contains a legend of the symbol conventions and lists the input tensors
    using the symbols assigned by ``_get_type_symbols``, which also updates the
    counters on ``input_state``.

    :param dims: number of spatial dimensions (used for the identity symbol ``I_dims``)
    :param input_types: the input TensorTypes
    :param input_state: counter object passed through to ``_get_type_symbols``
    :return: the plain-text header string
    """
    return f"""
Legend:
------------------------------------------------
Symbol(s) | Tensor Rank | Symmetric | Antisymmetric
------------------------------------------------
x, y, z | 0 | N/A | N/A
------------------------------------------------
v, u | 1 | N/A | N/A
------------------------------------------------
A, B, C | 2 | Yes | No
------------------------------------------------
W, V | 2 | No | Yes
------------------------------------------------
T, S | 3 | Any | Any
------------------------------------------------
Special symbols:
- I_{dims} (rank-two identity tensor, A.K.A. identity matrix)
Operations:
- _tprod(x, y) = numpy_backend.tensordot(x, y, axes=0) (tensor product)
- symm(x) = x + x.T
- antisymm(x) = x - x.T
------------------------------------------------
Input tensors:
{_get_type_symbols(input_types, input_state)}
------------------------------------------------
Scalar invariants:
"""
def _get_symbol_replacements(
    names: Union[List[str], Tuple[str]],
    types: Iterable[Union[TensorType, Tuple[TensorType, int]]],
    symbol_map: List[List[str]],
) -> List[Tuple[str, str]]:
    """Pair each type's default (first) symbol with the name chosen for that input.

    :param names: one display name per entry of ``types`` (as produced by
        ``_infer_param_names``)
    :param types: TensorTypes, possibly tagged as ``(TensorType, index)``
    :param symbol_map: per-category symbol lists indexed by ``_get_symbol_id``
    :return: the distinct ``(default_symbol, name)`` pairs where the two differ,
        sorted so that applying them one after another is deterministic
    """
    replacements = set()
    for i, ttype in enumerate(types):
        if not isinstance(ttype, TensorType):
            # tagged entry: (TensorType, position within a run of equal types)
            ttype, _ = ttype
        candidate = symbol_map[_get_symbol_id(ttype)][0]
        # When names come from _infer_param_names, names[i] already reflects both
        # a user-given name and the position of a tagged entry.
        map_to = names[i]
        if candidate != map_to:
            replacements.add((candidate, map_to))
    return sorted(replacements)
def _replace_symbols(
    names: Union[List[str], Tuple[str]],
    types: Iterable[Union[TensorType, Tuple[TensorType, int]]],
    symbol_map: List[List[str]],
    line: str,
) -> str:
    """Rewrite default symbols in ``line`` to the names chosen for the inputs.

    Each ``(old, new)`` pair from ``_get_symbol_replacements`` is applied
    twice: once verbatim and once in lower case.
    """
    for old, new in _get_symbol_replacements(names, types, symbol_map):
        line = line.replace(old, new)
        line = line.replace(old.lower(), new.lower())
    return line
def _format_invt_function(f, types, symbol_map, html):
    """Render a human-readable description of the invariant function ``f``.

    The description reads ``<input names> -> [<f's parameter names>] -> <expr>``,
    where ``<expr>`` comes from the first source line containing ``return``
    (see ``_format_return_line``). If ``f``'s stripped source has more than
    one line, ``where`` and the preceding body lines are appended.

    :param f: the invariant function (its source is read with ``inspect``)
    :param types: the (possibly tagged) input TensorTypes of ``f``
    :param symbol_map: per-category display symbols
    :param html: wrap the result in ``<code>`` tags when truthy
    :return: the description string
    """
    param_names = _infer_param_names(f, types, symbol_map)
    lines = _strip_function_header(inspect.getsourcelines(f)[0])
    if len(lines) == 0:
        return (
            f"<code>{param_names} -> {param_names}</code>"
            if html
            else f"{param_names} -> {param_names}"
        )
    if len(lines) == 1:
        # A single line is already a string; joining it would split it into characters.
        body = []
        retline = lines[0]
    else:
        # the first line mentioning "return" starts the result; fall back to the last line
        ret_idx = next((i for i, line in enumerate(lines) if "return" in line), -1)
        retline = "\n".join(lines[ret_idx:])
        body = lines[:ret_idx]
    body = ",\n".join(_format_body(body)).rstrip(",\n")
    signature = "[" + ", ".join(f"'{p}'" for p in inspect.signature(f).parameters) + "]"
    func_descr = f"{param_names} -> {signature} -> "
    func_descr += _format_return_line(retline.lstrip(" "))
    if len(lines) > 1:
        func_descr += "where\n" + body
    if html:
        return "<code>" + func_descr + "</code>"
    return func_descr
def _format_body(body):
return [x.lstrip(" ").rstrip("\n") for x in body]
def _get_num_retvals(f, input_types, backend):
    """Count the scalar values produced by the scalar-invariant function ``f``.

    ``f`` is probed with placeholder arrays from ``TensorType.get_array_like``,
    built with the same ``backend`` that is passed to ``f`` so the two agree.

    :param f: scalar-invariant function called as ``f(*inputs, backend)``
    :param input_types: the (possibly tagged) input TensorTypes of ``f``
    :param backend: numpy-like backend object
    :return: the number of elements in ``f``'s result
    :rtype: int
    """
    inputs = [_untagged_value(t).get_array_like(backend=backend) for t in input_types]
    result = f(*inputs, backend)
    # Count from the shape: ``size`` is a method (not an int) on torch tensors.
    return int(np.prod(result.shape))
def _get_num_retvals_form_invt(f, input_types, backend):
    """Count the form invariants produced by ``f`` (the length of its result).

    ``f`` is probed with placeholder arrays from ``TensorType.get_array_like``,
    built with the same ``backend`` that is passed to ``f`` so the two agree.

    :param f: form-invariant function called as ``f(*inputs, backend)``
    :param input_types: the (possibly tagged) input TensorTypes of ``f``
    :param backend: numpy-like backend object
    :return: ``len`` of ``f``'s result
    :rtype: int
    """
    inputs = [_untagged_value(t).get_array_like(backend=backend) for t in input_types]
    return len(f(*inputs, backend))
def _format_return_line(line):
if "array(" in line:
idx = line.find("array(") + len("array(")
line = line.rstrip("\n").rstrip(")")[idx:] + "\n"
elif line.startswith("return "):
return line[len("return ") :]
return line
def _strip_function_header(f_lines):
return f_lines[2:]
def _get_symbol_id(ttype):
if ttype.order < 2:
return ttype.order
if ttype.order == 2:
return 2 if ttype.symmetric else 3
return 4
def _infer_param_names(f, input_types, symbol_map):
    """Choose a display name for each input type.

    A user-given ``name`` wins. Otherwise an untagged TensorType (the first
    of its kind) gets its category's first symbol, and a tagged
    ``(TensorType, k)`` gets symbol ``k``. ``f`` is not used.
    """

    def _display_name(entry):
        if isinstance(entry, TensorType):
            ttype, slot = entry, 0
        else:
            ttype, slot = entry
        if ttype.name == "":
            return symbol_map[_get_symbol_id(ttype)][slot]
        return ttype.name

    return [_display_name(entry) for entry in input_types]
def _get_type_symbols(ts, state):
    """Assign a display symbol (or the user-given name) to each input type.

    Default symbols are handed out in order per category: x, y, z for
    scalars; v, u for vectors; A, B, C for symmetric rank-2; W, V for other
    rank-2; T, S for rank-3. Counters on ``state`` record how many of each
    category were seen (``n_rank_*``) and how many carried their own name
    (``n_rank_*_n``).

    :raises ValueError: when a category runs out of symbols
    :raises KeyError: for tensors of order greater than 3
    """

    def _claim(t, count_attr, named_attr, symbols, overflow_msg):
        # Advance the category counter, then hand out the next default symbol
        # unless the tensor already has a name.
        position = getattr(state, count_attr) + 1
        setattr(state, count_attr, position)
        if not 1 <= position <= len(symbols):
            raise ValueError(overflow_msg)
        if t.name == "":
            return symbols[position - 1]
        setattr(state, named_attr, getattr(state, named_attr) + 1)
        return t.name

    def _rank_0(t):
        return _claim(
            t,
            "n_rank_0",
            "n_rank_0_n",
            ("x", "y", "z"),
            "Currently can only represent 3 scalars in the input types!",
        )

    def _rank_1(t):
        msg = "Currently can only represent 2 vectors in the input types!"
        # Unlike the other categories, the vector counter is left untouched
        # when no symbol is available.
        if state.n_rank_1 not in (0, 1):
            raise ValueError(msg)
        return _claim(t, "n_rank_1", "n_rank_1_n", ("v", "u"), msg)

    def _rank_2(t):
        if t.symmetric:
            return _claim(
                t,
                "n_rank_2_s",
                "n_rank_2_s_n",
                ("A", "B", "C"),
                "Currently can only represent 3 symmetric rank-two tensors in the input types!",
            )
        return _claim(
            t,
            "n_rank_2_a",
            "n_rank_2_a_n",
            ("W", "V"),
            "Currently can only represent 2 antisymmetric rank-two tensors in the input types!",
        )

    def _rank_3(t):
        return _claim(
            t,
            "n_rank_3",
            "n_rank_3_n",
            ("T", "S"),
            "Currently can only represent 2 rank-three tensors in the input types!",
        )

    handlers = {0: _rank_0, 1: _rank_1, 2: _rank_2, 3: _rank_3}
    assigned = []
    for entry in ts:
        ttype = _untagged_value(entry)
        assigned.append(handlers[ttype.order](ttype))
    return assigned
# tags chains of the same value T (e.g. T,T,T,T,... becomes T,(T,1),(T,2),...)
def _tag_value_chains(indexable):
val = list(indexable)
N = len(indexable)
i = 1
while i < N:
if val[i - 1] == val[i]:
# chain starts here
T = val[i]
n = 1
while val[i] == T:
val[i] = (T, n)
n += 1
i += 1
if i >= N:
break
else:
i += 1
return tuple(val)
def _get_sorted_and_indices(l, reverse=True):
if len(l) == 0:
return (), ()
srtd = sorted(enumerate(l), key=itemgetter(1), reverse=reverse)
return tuple(x[0] for x in srtd), tuple(x[1] for x in srtd)
def _nameless(x):
    """Return a TensorType equal to ``x`` but with the name left at its default.

    Only ``order``, ``shape``, ``symmetric`` and ``antisymmetric`` are copied.
    """
    return TensorType(
        order=x.order,
        shape=x.shape,
        symmetric=x.symmetric,
        antisymmetric=x.antisymmetric,
    )
def _untagged_value(val, nameless=False):
    """Strip a chain tag of the form ``(TensorType, n)`` if one is present.

    :param val: Either a TensorType or a ``(TensorType, tag)`` pair.
    :param nameless: If True, the returned type also has its name cleared.
    :return: The underlying TensorType.
    """
    base = val if isinstance(val, TensorType) else val[0]
    if nameless:
        return _nameless(base)
    return base
def _untagged_tuple(tpl, nameless=False):
    """Apply ``_untagged_value`` to each element of ``tpl``; return a tuple."""
    untagged = [_untagged_value(item, nameless) for item in tpl]
    return tuple(untagged)
def _1darr(x, backend):
return backend.atleast_1d(x)
def _get_invariant_functions(dims, group, input_types, output_type):
    """Look up the scalar-invariant and form-invariant functions for a case.

    The lookup walks ``_scalar_invariant_function_table`` as
    ``dims -> group -> number of inputs -> input types``. That yields a pair
    ``(scalar_func, outputs_table)``, and ``outputs_table`` is then indexed
    by ``output_type``.

    :param dims: Number of spatial dimensions.
    :param group: Key identifying the symmetry group.
    :param input_types: Tuple of input TensorTypes, used as a table key.
    :param output_type: TensorType of the desired output.
    :return: ``(scalar_func, output_basis_func)``. ``output_basis_func`` is
        None when no form-invariant function is registered for
        ``output_type``.
    :raises ValueError: If ``dims``, ``group``, the number of inputs or the
        input-type combination is not present in the table.
    """
    missing_dims_err_str = f"The number of dimensions you passed to get_scalar_invariant_function through the info parameter ({dims}) is not supported. Supported numbers of dimensions are currently: (2,3)"
    missing_group_err_str = f"The symmetry group you passed in (of type {type(group)}) is not among the currently supported set of symmetry groups (O(2),SO(2),O(3),SO(3))"
    missing_num_inputs_err_str = f"The number of inputs you passed ({len(input_types)}) is not among the currently-supported set of input sizes (1,2,3,4)."
    missing_inputs_err_str = f"The input types you passed ({input_types}) are not among the currently-supported set of input types in {dims} dimensions under the symmetry group {group}."
    missing_outputs_err_str = f"The output type you passed ({output_type}) along with the combination of input types you passed ({input_types}) does not have a form-invariant function implemented at the time, or it does not exist."

    # Any failure along the input path is a hard error; _get_item raises a
    # ValueError carrying the matching message, which propagates unchanged.
    group_table = _get_item(
        _scalar_invariant_function_table, dims, missing_dims_err_str
    )
    inputs_table = _get_item(group_table, group, missing_group_err_str)
    inputs_table = _get_item(
        inputs_table, len(input_types), missing_num_inputs_err_str
    )
    scalar_func, outputs_table = _get_item(
        inputs_table, input_types, missing_inputs_err_str
    )

    # A missing form-invariant is not an error: report it as None.
    try:
        output_basis_func = _get_item(
            outputs_table, output_type, missing_outputs_err_str
        )
    except ValueError:
        output_basis_func = None
    return scalar_func, output_basis_func
def _get_item(table, key, err_msg):
val = None
try:
val = table[key]
except KeyError:
raise ValueError(err_msg)
return val
def register_invariant_functions(
    info: InvariantInfo,
    scalar_invariant_func,
    form_invariant_func,
    overwrite_existing=False,
    jit=False,
    backend=None,
):
    """Register a scalar and form-invariant computing function for a given
    InvariantInfo.

    The functions are stored in ``_scalar_invariant_function_table`` under
    ``dims -> group -> number of inputs -> input types``, with the input
    types sorted in descending order. This is the same path that
    ``_get_invariant_functions`` reads. Missing levels of the table are
    created as needed.

    :param info: an InvariantInfo containing the relevant information about the
        inputs and outputs of functions with this symmetry.
    :type info: InvariantInfo
    :param scalar_invariant_func: a function that returns a single array
        containing the scalar invariants for the inputs.
    :type scalar_invariant_func: Callable
    :param form_invariant_func: a function that returns a Python list of arrays
        representing the form-invariants for the inputs. May be None; it is
        then stored as-is and not jitted.
    :type form_invariant_func: Callable
    :param overwrite_existing: if True, and the InvariantInfo you pass describes an
        existing set of invariants, replace those with your function. You should
        NEVER set this to True unless you really know what you're doing. If you want to
        overwrite one function but not the other (e.g. insert a form-invariant for a scenario
        where the scalar invariant already exists), you can also pass a pair of bools, one for
        the scalar invariant function and one for the form invariant function. defaults to False
    :type overwrite_existing: Union[bool, Tuple[bool, bool]], optional
    :param jit: if True, call backend.jit() on scalar_invariant_func() and form_invariant_func() if it exists, defaults to False
    :type jit: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch',
        or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :return: None, makes your functions available to get_invariant_functions()
    """
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)
    if jit:
        scalar_invariant_func = backend.jit(scalar_invariant_func)
        # The form-invariant function is optional; only jit it when given.
        if form_invariant_func is not None:
            form_invariant_func = backend.jit(form_invariant_func)

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    output_type = info.output_type
    # Table keys list input types in descending sorted order.
    input_types = tuple(sorted(input_types, reverse=True))
    N = len(input_types)

    if isinstance(overwrite_existing, (int, bool)):
        overwrite_existing_scalar = overwrite_existing_form = overwrite_existing
    else:
        # If it's not a single bool/int, it must be a (scalar, form) pair;
        # unpacking raises for anything else.
        overwrite_existing_scalar, overwrite_existing_form = overwrite_existing

    # Walk (or create) dims -> group -> N. New levels start out EMPTY so the
    # membership check below correctly treats this input combination as new.
    group_table = _scalar_invariant_function_table.setdefault(dims, {})
    n_input_table = group_table.setdefault(group, {})
    inputs_table = n_input_table.setdefault(N, {})

    if input_types in inputs_table:
        # Reuse the existing outputs table whether or not the scalar function
        # is replaced, so form-invariants can still be added (e.g. with
        # overwrite_existing=(False, True)).
        outputs_table = inputs_table[input_types][1]
        if not overwrite_existing_scalar:
            logger.warning(
                f"Scalar invariants for the inputs combination {input_types} already exists! Skipping this replacement for now. To overwrite this function, pass overwrite_existing=True to register_invariant_functions(). ONLY do this if you know what you are doing!"
            )
        else:
            inputs_table[input_types] = (scalar_invariant_func, outputs_table)
    else:
        outputs_table = {}
        inputs_table[input_types] = (scalar_invariant_func, outputs_table)

    if output_type in outputs_table and not overwrite_existing_form:
        logger.warning(
            f"Form-invariant functions for this inputs and outputs combination (inputs {input_types}, output {output_type}) already exist! Skipping this replacement for now. To overwrite this function, pass overwrite_existing=True to register_invariant_functions(). ONLY do this if you know what you're doing!"
        )
    else:
        outputs_table[output_type] = form_invariant_func
"""
Because Python's parsing rules are kinda weird when dictionaries are involved, all of the invariant-
calculating functions have to be first. If you're reading this for the first time, skip down to the
tables below first.
"""
def _2d_hmt_1_r3_o_r3(T, backend):
    """2-D hemitropic rank-3 form-invariants of a rank-3 ``T``: ``[T, T @ eps]``.

    ``eps`` is ``levi_civita(2, backend)``.
    """
    eps = levi_civita(2, backend)
    return [T, T @ eps]


def _2d_hmt_1sr2_o_s2(A, backend):
    """2-D hemitropic symmetric rank-2 form-invariants of a symmetric ``A``.

    Returns ``[A, A @ eps - eps @ A]`` with ``eps = levi_civita(2, backend)``.
    """
    eps = levi_civita(2, backend)
    return [A, A @ eps - eps @ A]


def _2d_hmt_1_vec_o_vec(v, backend):
    """2-D hemitropic vector form-invariants of a vector: ``[v, eps @ v]``."""
    eps = levi_civita(2, backend)
    return [v, eps @ v]


def _2d_hmt_1_vec_o_s2(v, backend):
    """2-D hemitropic symmetric rank-2 form-invariants of a vector ``v``.

    Returns ``[outer(v, v), outer(v, eps v) + outer(eps v, v)]``.
    """
    rotated = levi_civita(2, backend) @ v
    v_v = _tprod(v, v, backend)
    mixed = _tprod(v, rotated, backend) + _tprod(rotated, v, backend)
    return [v_v, mixed]


def _2d_hmt_1_vec_o_s3(v, backend):
    """2-D hemitropic symmetric rank-3 form-invariant of a vector: ``[v (x) v (x) v]``."""
    v_v = _tprod(v, v, backend)
    return [_tprod(v, v_v, backend)]
def _2d_hmt_1_vec_o_r3(v, backend):
    """2-D hemitropic rank-3 form-invariants of one vector ``v``.

    Returns ``[v(x)v(x)v, v(x)v(x)(eps v), v(x)I, (eps v)(x)I]`` where
    ``eps = levi_civita(2, backend)``, ``I`` is the 2x2 identity and ``(x)``
    is the tensor product computed by ``_tprod``.
    """
    # _tprod requires the backend argument; it was previously omitted here,
    # which raised TypeError on every call.
    vv = _tprod(v, v, backend)
    ev = levi_civita(2, backend) @ v
    eye = backend.eye(2)
    return [
        _tprod(v, vv, backend),
        _tprod(vv, ev, backend),
        _tprod(v, eye, backend),
        _tprod(ev, eye, backend),
    ]
def _2d_iso_1_vec_o_vec(v, backend):
return [v]
def _2d_iso_1_vec_o_s2(v, backend):
return [_tprod(v, v, backend)]
def _2d_iso_1_vec_o_r3(v, backend):
return [_tprod(v, _tprod(v, v, backend), backend)] + tbrace(
_tprod(v, backend.eye(2), backend), backend
)
def _2d_iso_1r3_o_3(T, backend):
return [T]
def _2d_iso_1sr2_o_s2(A, backend):
return [A]
def _2d_iso_1ar2_o_a2(W, backend):
return [W]
def _2d_iso_2sr2_o_a2(A, B, backend):
return [A @ B - B @ A]
def _2d_iso_1sr2_1ar2_o_s2(A, W, backend):
return [A @ W - W @ A]
def _2d_iso_1sr2_1_vec_o_vec(A, v, backend):
return [A @ v]
def _2d_iso_1sr2_1_vec_o_a2(A, v, backend):
av = A @ v
return [_tprod(v, av, backend) - A @ _tprod(v, v, backend)]
def _2d_iso_1sr2_1_vec_o_s3(A, v, backend):
return tbrace(_tprod(A @ v, backend.eye(2), backend), backend)
def _2d_iso_1ar2_1_vec_o_s3(W, v, backend):
return [symm_q3(_tprod(v, _tprod(v, W @ v, backend), backend), backend)]
def _2d_iso_2_vec_o_s2(v, u, backend):
return [_tprod(v, u, backend) + _tprod(u, v, backend)]
def _2d_iso_2_vec_o_a2(v, u, backend):
return [_tprod(v, u, backend) - _tprod(u, v, backend)]
def _2d_iso_2_vec_o_s3(v, u, backend):
return [symm_q3(_tprod(v, _tprod(v, u, backend), backend), backend)]
def _2d_iso_1r3_1_vec_o_vec(T, v, backend):
return [tv(T, v, backend)]
def _2d_iso_1r3_1_vec_o_s2(T, v, backend):
return [Tv(T, v, backend)]
def _2d_iso_1r3_1_vec_o_a2(T, v, backend):
    """2-D isotropic antisymmetric form-invariant of a rank-3 ``T`` and a vector ``v``.

    Returns ``[v(x)t - t(x)v]`` with ``t = tv(T, v, backend)``.
    """
    # Pass the backend to tv, matching every other tv(...) call in this module.
    ttv = tv(T, v, backend)
    return [_tprod(v, ttv, backend) - _tprod(ttv, v, backend)]
def _2d_iso_1r3_1sr2_o_vec(T, A, backend):
    """2-D isotropic vector form-invariants of a rank-3 ``T`` and a symmetric ``A``.

    Returns ``[tA(T, A)]``.
    """
    # Form-invariant functions return a list of basis tensors (as every other
    # *_o_* function here does), so wrap the single vector in a list.
    return [tA(T, A, backend)]


def _2d_iso_1r3_1sr2_o_s2(T, A, backend):
    """2-D isotropic symmetric form-invariant: ``[t(x)t]`` with ``t = tA(T, A)``."""
    # tA takes the backend explicitly (see _2d_iso_1r3_1sr2_o_vec).
    tta = tA(T, A, backend)
    return [_tprod(tta, tta, backend)]


def _2d_iso_1r3_1sr2_o_a2(T, A, backend):
    """2-D isotropic antisymmetric form-invariant: ``[t(x)(At) - (At)(x)t]``, ``t = tA(T, A)``."""
    tta = tA(T, A, backend)
    ata = A @ tta
    return [_tprod(tta, ata, backend) - _tprod(ata, tta, backend)]


def _2d_iso_1r3_1sr2_o_3(T, A, backend):
    """2-D isotropic rank-3 form-invariants of a rank-3 ``T`` and a symmetric ``A``.

    Returns ``symm_q3((At)(x)A)`` followed by the elements of
    ``tbrace(t(x)I)``, where ``t = tA(T, A)``.
    """
    tta = tA(T, A, backend)
    eye = backend.eye(2)
    return [symm_q3(_tprod(A @ tta, A, backend), backend)] + tbrace(
        _tprod(tta, eye, backend), backend
    )
def _2d_iso_2_r3_o_a2(T, S, backend):
return TinnerS(T, S, backend)
def _3d_hmt_1sr2_o_s2(A, backend):
return [A, A @ A]
def _3d_hmt_1ar2_o_s2(W, backend):
return [W @ W]
def _3d_hmt_1ar2_o_a2(W, backend):
return [W]
def _3d_hmt_1ar2_o_vec(W, backend):
return [axial_vector(W, backend)]
def _3d_hmt_1vec_o_s2(v, backend):
return [backend.tensordot(v, v, axes=0)]
def _3d_hmt_1vec_o_a2(v, backend):
return [backend.einsum("ijk,k -> ij", levi_civita(3, backend), v, optimize=True)]
def _3d_hmt_1vec_o_vec(v, backend):
return [v]
def _3d_hmt_2sr2_o_s2(A, B, backend):
AB = A @ B
BA = B @ A
return [AB + BA, A @ AB + BA @ A, AB @ B + B @ BA]
def _3d_hmt_2sr2_o_a2(A, B, backend):
AB = A @ B
BA = B @ A
BAA = BA @ A
BBA = B @ BA
return [
AB - BA,
A @ AB - BAA,
AB @ B - BBA,
A @ BAA - A @ A @ BA,
B @ AB @ B - BBA @ B,
]
def _3d_hmt_2sr2_o_vec(A, B, backend):
AB = A @ B
ABB = AB @ B
return [
axial_vector(AB, backend),
axial_vector(A @ AB, backend),
axial_vector(ABB, backend),
axial_vector(AB @ A @ A, backend),
axial_vector(B @ ABB, backend),
]
def _3d_hmt_1s1ar2_o_s2(A, W, backend):
AW = A @ W
WA = W @ A
W2 = W @ W
return [AW - WA, AW @ W + W @ WA, WA @ W2 - W2 @ AW, A @ AW - WA @ A]
def _3d_empty(backend):
return backend.array([])
def _2d_empty(backend):
return backend.array([])
def _2d_idf(backend):
return backend.eye(2)
def _3d_idf(backend):
return backend.eye(3)
def _3d_hmt_1s1ar2_o_a2(A, W, backend):
AW = A @ W
WA = W @ A
return [AW + WA, AW @ W - W @ WA]
def _3d_hmt_1s1ar2_o_vec(A, W, backend):
    """3-D hemitropic vector form-invariants of a symmetric ``A`` and an antisymmetric ``W``.

    Returns axial vectors of ``AW``, ``A^2 W``, ``A W^2``, ``A W A^2`` and
    ``W A W^2``.

    NOTE(review): the body previously referenced an undefined name ``B`` and
    raised NameError on every call. ``W`` (the only other tensor argument)
    is used in its place. Confirm this basis against the intended
    representation theorem.
    """
    AW = A @ W
    return [
        axial_vector(AW, backend),
        axial_vector(A @ AW, backend),
        axial_vector(AW @ W, backend),
        axial_vector(AW @ A @ A, backend),
        axial_vector(W @ AW @ W, backend),
    ]
def _3d_hmt_2ar2_o_s2(W, V, backend):
WV = W @ V
VW = V @ W
return [WV + VW, W @ WV - VW @ W, WV @ V - V @ VW]
def _3d_hmt_2ar2_o_a2(W, V, backend):
return [W @ V - V @ W]
def _3d_hmt_2ar2_o_vec(W, V, backend):
return [axial_vector(W @ V, backend)]
def _3d_hmt_1ar21vec_o_s2(W, v, backend):
ev = backend.einsum("ijk,k -> ij", levi_civita(3, backend), v)
wev = W @ ev
return [
symm(_tprod(v, W @ v, backend), backend),
symm(wev, backend),
symm(W @ wev, backend),
]
def _3d_hmt_1ar21vec_o_a2(W, v, backend):
return [
antisymm(W @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend)
]
def _3d_hmt_1ar21vec_o_vec(W, v, backend):
return [W @ v]
def _3d_hmt_1sr21vec_o_s2(A, v, backend):
return [
symm(_tprod(v, A @ v, backend), backend),
symm(A @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend),
symm(_tprod(v, backend.cross(v, A @ v), backend), backend),
]
def _3d_hmt_1sr21vec_o_a2(A, v, backend):
return [
antisymm(
A @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend
),
antisymm(_tprod(v, A @ v, backend), backend),
]
def _3d_hmt_1sr21vec_o_vec(A, v, backend):
Av = A @ v
return [Av, backend.cross(v, Av)]
def _3d_iso_1sr21vec_o_s2(A, v, backend):
av = A @ v
return [
symm(_tprod(v, av, backend), backend),
symm(_tprod(v, A @ av, backend), backend),
]
def _3d_iso_1sr21vec_o_a2(A, v, backend):
av = A @ v
aav = A @ av
return [
antisymm(_tprod(v, av, backend), backend),
antisymm(_tprod(v, aav, backend), backend),
antisymm(_tprod(av, aav, backend), backend),
]
def _3d_iso_1sr21vec_o_vec(A, v, backend):
av = A @ v
return [av, A @ av]
def _3d_iso_2vec_o_s2(v, u, backend):
return [symm(_tprod(v, u, backend), backend)]
def _3d_iso_2vec_o_a2(v, u, backend):
return [antisymm(_tprod(v, u, backend), backend)]
def _3d_hmt_2vec_o_s2(v, u, backend):
return [symm(_tprod(v, u, backend), backend)]
def _3d_hmt_2vec_o_a2(v, u, backend):
return [antisymm(_tprod(v, u, backend), backend)]
def _3d_hmt_2vec_o_vec(v, u, backend):
return [backend.cross(v, u)]
def _3d_hmt_3sr2_o_s2(A, B, C, backend):
AB = A @ B
BC = B @ C
AC = A @ C
A2 = A @ A
B2 = B @ B
C2 = C @ C
ABC = AB @ C
return [ABC, A @ ABC, B2 @ AC, C2 @ AB, A2 @ B2 @ C, B2 @ C2 @ A, C2 @ A2 @ B]
def _3d_hmt_3sr2_o_a2(A, B, C, backend):
AB = A @ B
BC = B @ C
AC = A @ C
CB = C @ B
return [AB @ C - CB @ A + BC @ A - A @ CB + C @ AB - B @ AC]
def _3d_hmt_3sr2_o_vec(A, B, C, backend):
BC = B @ C
return [axial_vector(A @ BC + BC @ A + C @ A @ B, backend)]
def _tprod(x, y, backend):
return backend.tensordot(x, y, axes=0)
def _3d_iso_2sr21vec_o_a2(A, B, v, backend):
    """3-D isotropic antisymmetric form-invariant of two symmetric tensors and a vector.

    Returns ``[antisymm((Av)(x)(Bv)) + antisymm(v(x)(antisymm(AB) v))]``.
    """
    # antisymm takes the backend explicitly, as in every other call here.
    ab_skew = antisymm(A @ B, backend)
    return [
        antisymm(_tprod(A @ v, B @ v, backend), backend)
        + antisymm(_tprod(v, ab_skew @ v, backend), backend)
    ]
def _3d_iso_2sr21vec_o_vec(A, B, v, backend):
av = A @ v
bv = B @ v
return [A @ bv - B @ av]
def _3d_iso_1s1ar21vec_o_a2(A, W, v, backend):
wv = W @ v
av = A @ v
return [A @ wv + W @ av]
def _3d_iso_2ar21vec_o_vec(W, V, v, backend):
vv = V @ v
wv = W @ v
return [W @ vv - V @ wv]
def _3d_iso_1ar22vec_o_s2(W, v, u, backend):
vtu = antisymm(_tprod(v, u, backend), backend)
return [W @ vtu + vtu @ W]
def _3d_iso_1ar22vec_o_a2(W, v, u, backend):
vtu = antisymm(_tprod(v, u, backend), backend)
return [W @ vtu - vtu @ W]
def _3d_iso_1sr22vec_o_s2(A, v, u, backend):
vtu = antisymm(_tprod(v, u, backend), backend)
return [A @ vtu - vtu @ A]
def _3d_iso_1sr22vec_o_a2(A, v, u, backend):
vtu = antisymm(_tprod(v, u, backend), backend)
return [A @ vtu + vtu @ A]
def _3d_isotropic_2s_2_vec(A, B, v, u, backend):
return backend.inner(v, commutator_action(A, B, u))
def _3d_isotropic_1s_1a_2_vec(A, W, v, u, backend):
return backend.inner(v, anticommutator_action(A, W, u))
def _3d_isotropic_2a_2_vec(W, V, v, u, backend):
return backend.inner(v, commutator_action(W, V, u))
def _3d_hemitropic_1_vec(v, backend):
return backend.inner(v, v)
def _3d_isotropic_1s_rank_2_1_vec(A, v, backend):
Av = A @ v
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, Av)),
backend.atleast_1d(backend.inner(v, A @ Av)),
]
)
def _3d_isotropic_1a_rank_2_1_vec(W, v, backend):
return backend.inner(v, W @ (W @ v))
def _3d_isotropic_2_vec(v, u, backend):
return backend.inner(v, u)
def _3d_isotropic_2s_rank_2_1_vec(A, B, v, backend):
bv = B @ v
return backend.inner(v, A @ bv)
def _3d_isotropic_1s_1a_rank_2_1_vec(A, W, v, backend):
Wv = W @ v
AWv = A @ Wv
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, AWv)),
backend.atleast_1d(backend.inner(v, A @ AWv)),
backend.atleast_1d(backend.inner(v, W @ (A @ (W @ Wv)))),
]
)
def _3d_isotropic_2a_rank_2_1_vec(W, V, v, backend):
Vv = V @ v
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, W @ Vv)),
backend.atleast_1d(backend.inner(v, W @ (W @ Vv))),
backend.atleast_1d(backend.inner(v, W @ (V @ Vv))),
]
)
def _3d_isotropic_1s_rank_2_2_vec(A, v, u, backend):
Au = A @ u
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, Au)),
backend.atleast_1d(backend.inner(v, A @ Au)),
]
)
def _3d_isotropic_1a_rank_2_2_vec(W, v, u, backend):
Wu = W @ u
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, Wu)),
backend.atleast_1d(backend.inner(v, W @ Wu)),
]
)
def _3d_hemitropic_1s_rank_2(A, backend):
A2 = A @ A
return backend.concatenate(
[
backend.atleast_1d(backend.trace(A)),
backend.atleast_1d(backend.trace(A2)),
backend.atleast_1d(backend.trace(A @ A2)),
]
)
def _3d_hemitropic_1a_rank_2(W, backend):
return backend.trace(W @ W)
def _3d_hemitropic_1_vec(v, backend):
return backend.inner(v, v)
def _3d_hemitropic_2s_rank_2(A, B, backend):
A2 = A @ A
B2 = B @ B
return backend.concatenate(
[
backend.atleast_1d(backend.trace(A @ B)),
backend.atleast_1d(backend.trace(A2 @ B)),
backend.atleast_1d(backend.trace(A @ B2)),
backend.atleast_1d(backend.trace(A2 @ B2)),
]
)
def _3d_hemitropic_2a_rank_2(W, V, backend):
return backend.trace(W @ V)
def _3d_hemitropic_1s_1a_rank_2(A, W, backend):
A2 = A @ A
W2 = W @ W
A2W2 = A2 @ W2
return backend.concatenate(
[
backend.atleast_1d(backend.trace(A @ W2)),
backend.atleast_1d(backend.trace(A2W2)),
backend.atleast_1d(backend.trace(A2W2 @ A @ W)),
]
)
def _3d_hemitropic_1s_rank_2_1_vec(A, v, backend):
av = A @ v
aav = A @ av
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, A @ v)),
backend.atleast_1d(backend.inner(v, aav)),
backend.atleast_1d(scalar_triple_prod(v, av, aav, backend)),
]
)
def _3d_hemitropic_1a_rank_2_1_vec(W, v, backend):
return backend.inner(v, axial_vector(W, backend))
def _3d_hemitropic_2_vec(v, u, backend):
return backend.inner(v, u)
def _3d_hemitropic_3s_rank_2(A, B, C, backend):
return backend.trace(A @ B @ C)
def _3d_hemitropic_2s_1a_rank_2(A, B, W, backend):
AB = A @ B
BW = B @ W
return backend.concatenate(
[
backend.atleast_1d(backend.trace(AB @ W)),
backend.atleast_1d(backend.trace(A @ AB @ W)),
backend.atleast_1d(backend.trace(AB @ BW)),
backend.atleast_1d(backend.trace(A @ W @ W @ BW)),
]
)
def _3d_hemitropic_1s_2a_rank_2(A, W, V, backend):
AW = A @ W
AWV = AW @ V
return backend.concatenate(
[
backend.atleast_1d(backend.trace(AWV)),
backend.atleast_1d(backend.trace(AW @ W @ V)),
backend.atleast_1d(backend.trace(AWV @ V)),
]
)
def _3d_hemitropic_3a_rank_2(W, V, U, backend):
return backend.trace(W @ V @ U)
def _3d_hemitropic_2s_rank_2_1_vec(A, B, v, backend):
AB = A @ B
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, axial_vector(AB, backend))),
backend.atleast_1d(backend.inner(v, axial_vector(A @ AB, backend))),
backend.atleast_1d(backend.inner(v, axial_vector(AB @ B, backend))),
backend.atleast_1d(scalar_triple_prod(v, A @ v, B @ v, backend)),
]
)
def _3d_hemitropic_1s_1a_rank_2_1_vec(A, W, v, backend):
    """Hemitropic scalar invariants of (symmetric A, antisymmetric W, vector v).

    Returns the 1-D array ``[v.(AW)v, v.axial(AW), v.axial(AWW)]``.
    """
    AW = A @ W
    return backend.concatenate(
        [
            backend.atleast_1d(backend.inner(v, AW @ v)),
            # axial_vector takes the backend explicitly, as elsewhere in this
            # module; it was previously omitted in these two calls.
            backend.atleast_1d(backend.inner(v, axial_vector(AW, backend))),
            backend.atleast_1d(backend.inner(v, axial_vector(AW @ W, backend))),
        ]
    )
def _3d_hemitropic_2a_rank_2_1_vec(W, V, v, backend):
return backend.inner(v, axial_vector(W @ V, backend))
def _3d_hemitropic_1s_rank_2_2_vec(A, v, u, backend):
au = A @ u
return backend.concatenate(
[
backend.atleast_1d(backend.inner(v, au)),
backend.atleast_1d(scalar_triple_prod(v, u, A @ v, backend)),
backend.atleast_1d(scalar_triple_prod(v, u, au, backend)),
]
)
def _3d_hemitropic_1a_rank_2_2_vec(W, v, u, backend):
return backend.inner(v, W @ u)
def _3d_hemitropic_3_vec(v, u, w, backend):
return scalar_triple_prod(v, u, w, backend)
"""
2-d versions of the above (and more) where applicable
"""
def _2d_isotropic_1s_rank_2(A, backend):
return backend.concatenate(
[backend.atleast_1d(backend.trace(A)), backend.atleast_1d(backend.trace(A @ A))]
)
def _2d_isotropic_1a_rank_2(W, backend):
return backend.trace(W @ W)
def _2d_isotropic_1_vec(v, backend):
return backend.inner(v, v)
def _2d_isotropic_1_rank_3(T, backend):
return TcontrS(T, T, backend)
def _2d_isotropic_2s_rank_2(A, B, backend):
return backend.trace(A @ B)
def _2d_isotropic_2a_rank_2(W, V, backend):
return backend.trace(W @ V)
def _2d_isotropic_1s_rank_2_1_vec(A, v, backend):
return backend.inner(v, A @ v)
def _2d_isotropic_2_vec(v, u, backend):
return backend.inner(v, u)
def _2d_isotropic_2_rank_3(T, S, backend):
return TcontrS(T, S, backend)
def _2d_isotropic_1_rank_3_1_vec(T, v, backend):
return backend.inner(v, tv(T, v, backend))
def _2d_isotropic_1_rank_3_1s_rank_2(T, A, backend):
    """2-D isotropic scalar invariant of a rank-3 ``T`` and a symmetric ``A``.

    Returns ``t . (A t)`` with ``t = tA(T, A, backend)``.
    """
    ta = tA(T, A, backend)
    # Use the computed vector ``ta``, not the helper function ``tA``.
    return backend.inner(ta, A @ ta)
def _2d_isotropic_1s_1a_rank_2(A, W, backend):
return backend.array([])
def _2d_isotropic_1a_rank_2_1_vec(W, v, backend):
return backend.array([])
def _2d_hemitropic_1a_rank_2(W, backend):
return backend.trace(levi_civita(2, backend) @ W)
"""
These are tables mapping output tensor types to functions that compute the form-invariants that
make up the output basis.
"""
# Output tables: each maps an output TensorType (order, shape, symmetric,
# antisymmetric) to the function that returns the list of form-invariant
# basis tensors for that output.
# --- 2-D hemitropic (SO(2)) output tables ---
_2d_hemitropic_1s_rank_2_outputs = {
    TensorType(2, (2, 2), True, False): _2d_hmt_1sr2_o_s2,
}
# No form-invariants are registered for a single antisymmetric input.
_2d_hemitropic_1a_rank_2_outputs = {}
_2d_hemitropic_1_vec_outputs = {
    TensorType(1, (2,), False, False): _2d_hmt_1_vec_o_vec,
    # symmetric 2x2 output
    TensorType(
        2,
        (
            2,
            2,
        ),
        True,
        False,
    ): _2d_hmt_1_vec_o_s2,
    TensorType(3, (2, 2, 2), True, False): _2d_hmt_1_vec_o_s3,
    TensorType(3, (2, 2, 2), False, False): _2d_hmt_1_vec_o_r3,
}
_2d_hemitropic_1_rank_3_outputs = {
    TensorType(3, (2, 2, 2), False, False): _2d_hmt_1_r3_o_r3
}
# --- 2-D isotropic (O(2)) output tables ---
_2d_isotropic_2s_rank_2_outputs = {
    TensorType(2, (2, 2), False, True): _2d_iso_2sr2_o_a2,
}
_2d_isotropic_2a_rank_2_outputs = {}
_2d_isotropic_1s_1a_rank_2_outputs = {
    TensorType(2, (2, 2), True, False): _2d_iso_1sr2_1ar2_o_s2,
}
_2d_isotropic_1s_rank_2_1_vec_outputs = {
    TensorType(1, (2,), False, False): _2d_iso_1sr2_1_vec_o_vec,
    TensorType(2, (2, 2), False, True): _2d_iso_1sr2_1_vec_o_a2,
    TensorType(3, (2, 2, 2), True, False): _2d_iso_1sr2_1_vec_o_s3,
}
_2d_isotropic_1a_rank_2_1_vec_outputs = {
    TensorType(3, (2, 2, 2), True, False): _2d_iso_1ar2_1_vec_o_s3,
}
_2d_isotropic_2_vec_outputs = {
    TensorType(2, (2, 2), True, False): _2d_iso_2_vec_o_s2,
    TensorType(2, (2, 2), False, True): _2d_iso_2_vec_o_a2,
    TensorType(3, (2, 2, 2), True, False): _2d_iso_2_vec_o_s3,
}
_2d_isotropic_2_rank_3_outputs = {
    TensorType(2, (2, 2), False, True): _2d_iso_2_r3_o_a2,
}
_2d_isotropic_1_rank_3_1_vec_outputs = {
    TensorType(1, (2,), False, False): _2d_iso_1r3_1_vec_o_vec,
    TensorType(2, (2, 2), True, False): _2d_iso_1r3_1_vec_o_s2,
    TensorType(2, (2, 2), False, True): _2d_iso_1r3_1_vec_o_a2,
}
_2d_isotropic_1_rank_3_1s_rank_2_outputs = {
    TensorType(1, (2,), False, False): _2d_iso_1r3_1sr2_o_vec,
    TensorType(2, (2, 2), True, False): _2d_iso_1r3_1sr2_o_s2,
    TensorType(2, (2, 2), False, True): _2d_iso_1r3_1sr2_o_a2,
    TensorType(3, (2, 2, 2), False, False): _2d_iso_1r3_1sr2_o_3,
}
_2d_isotropic_1_vec_outputs = {
    TensorType(1, (2,), False, False): _2d_iso_1_vec_o_vec,
    TensorType(2, (2, 2), True, False): _2d_iso_1_vec_o_s2,
    TensorType(3, (2, 2, 2), False, False): _2d_iso_1_vec_o_r3,
}
_2d_isotropic_1_rank_3_outputs = {
    TensorType(3, (2, 2, 2), False, False): _2d_iso_1r3_o_3,
}
_2d_isotropic_1s_rank_2_outputs = {
    TensorType(2, (2, 2), True, False): _2d_iso_1sr2_o_s2,
    # (antisymmetric rank-2 output for one symmetric input: not implemented)
    # TensorType(2,(2,2),False,True) : ,
}
_2d_isotropic_1a_rank_2_outputs = {
    TensorType(2, (2, 2), False, True): _2d_iso_1ar2_o_a2,
}
# Identity-only output tables. _3d_idf / _2d_idf take only ``backend`` and
# return the identity matrix.
_3d_identity_outputs = {
    TensorType(2, (3, 3), True, False): _3d_idf,
}
_2d_identity_outputs = {
    TensorType(2, (2, 2), True, False): _2d_idf,
}
# --- 3-D output tables ---
# Commented-out entries are placeholders for basis functions that are not
# implemented. Their keys use the flags a real entry would need:
# (True, False) for symmetric rank-2 "s2" outputs and (False, True) for
# antisymmetric "a2" outputs.
_3d_hemitropic_1s_rank_2_outputs = {
    # rhs reads like "3-d hemitropic (1 symmetric rank 2) -> (symmetric rank 2) outputs function"
    TensorType(2, (3, 3), True, False): _3d_hmt_1sr2_o_s2,
}
_3d_hemitropic_1a_rank_2_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1ar2_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1ar2_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1ar2_o_vec,
}
_3d_hemitropic_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1vec_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1vec_o_vec,
}
_3d_isotropic_1s_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_iso_1sr21vec_o_vec,
}
_3d_isotropic_2_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_2vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_2vec_o_a2,
}
_3d_hemitropic_3s_rank_2_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_3sr2_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_3sr2_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_3sr2_o_vec,
}
_3d_hemitropic_3a_rank_2_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_3ar2_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_3ar2_o_a2,
    # TensorType(1,(3,),False,False) : _3d_hmt_3ar2_o_vec,
}
_3d_hemitropic_2s_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : None,#_3d_hmt_2sr21vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_2sr21vec_o_a2,
    # TensorType(1,(3,),False,False) : None,
}
_3d_hemitropic_1s_1a_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : None,#_3d_hmt_1s1ar21vec_o_s2,
    # TensorType(2,(3,3),False,True) : None,#_3d_hmt_1s1ar21vec_o_a2,
    # TensorType(1,(3,),False,False) : None#_3d_hmt_1s1ar21vec_o_vec,
}
_3d_hemitropic_2a_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_2ar21vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_hmt_2ar21vec_o_vec,
}
_3d_hemitropic_1s_rank_2_2_vec_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_1sr22vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_1sr22vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_hmt_1sr22vec_o_vec,
}
_3d_hemitropic_1a_rank_2_2_vec_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_1ar22vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_1ar22vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_hmt_1ar22vec_o_vec,
}
_3d_hemitropic_3_vec_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_3vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_3vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_hmt_3vec_o_vec,
}
_3d_isotropic_2s_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : None,#_3d_hmt_2sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_2sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_iso_2sr21vec_o_vec,
}
_3d_isotropic_1s_1a_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : None,#_3d_hmt_1s1ar21vec_o_s2,
    # NOTE(review): _3d_iso_1s1ar21vec_o_a2 computes A(Wv) + W(Av), which is a
    # vector rather than an antisymmetric rank-2 tensor; confirm the intent.
    TensorType(2, (3, 3), False, True): _3d_iso_1s1ar21vec_o_a2,
    # TensorType(1,(3,),False,False) : None#_3d_hmt_1s1ar21vec_o_vec,
}
_3d_isotropic_2a_rank_2_1_vec_outputs = {
    # TensorType(2,(3,3),True,False) : _3d_hmt_2ar21vec_o_s2,
    # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_iso_2ar21vec_o_vec,
}
_3d_isotropic_1a_rank_2_2_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_1ar22vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_1ar22vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_iso_1ar22vec_o_vec
}
_3d_isotropic_1s_rank_2_2_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_1sr22vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_1sr22vec_o_a2,
    # TensorType(1,(3,),False,False) : _3d_iso_1sr22vec_o_vec
}
"""
These tables map specific combinations of input tensor types to a pair containing a function
to compute the input scalar invariants and one of the output tables from above.
"""
# Three-input table for 3-D hemitropic functions. Each key is a tuple of input
# TensorTypes and each value is ``(scalar_invariant_function, outputs_table)``.
# Keys list the input types in descending sorted order (symmetric rank-2 before
# antisymmetric rank-2 before vectors), matching the
# ``sorted(..., reverse=True)`` normalization in register_invariant_functions.
_3d_hemitropic_three_input_table = {
    (
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), True, False),
    ): (_3d_hemitropic_3s_rank_2, _3d_hemitropic_3s_rank_2_outputs),
    # NOTE(review): the disabled entries below have scalar functions defined
    # above, but the ``*_outputs`` tables they reference do not appear to be
    # defined; confirm before re-enabling.
    # (TensorType(2,(3,3),True,False),
    # TensorType(2,(3,3),True,False),
    # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_2s_1a_rank_2,
    # _3d_hemitropic_2s_1a_rank_2_outputs),
    # (TensorType(2,(3,3),True,False),
    # TensorType(2,(3,3),False,True),
    # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_1s_2a_rank_2,
    # _3d_hemitropic_1s_2a_rank_2_outputs),
    # (TensorType(2,(3,3),False,True),
    # TensorType(2,(3,3),False,True),
    # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_3a_rank_2,
    # _3d_hemitropic_3a_rank_2_outputs),
    (
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), True, False),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_2s_rank_2_1_vec, _3d_hemitropic_2s_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_1s_1a_rank_2_1_vec, _3d_hemitropic_1s_1a_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), False, True),
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_2a_rank_2_1_vec, _3d_hemitropic_2a_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), True, False),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_1s_rank_2_2_vec, _3d_hemitropic_1s_rank_2_2_vec_outputs),
    (
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_1a_rank_2_2_vec, _3d_hemitropic_1a_rank_2_2_vec_outputs),
    (
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_hemitropic_3_vec, _3d_hemitropic_3_vec_outputs),
}
# Possible future entries for _3d_isotropic_three_input_table below (three
# rank-2 inputs). They are shown as (None, None) placeholders because no
# isotropic functions exist for them yet.
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(2,(3,3),False,True)) : (None,None),
# (TensorType(2,(3,3),False,True),
# TensorType(2,(3,3),False,True),
# TensorType(2,(3,3),False,True)) : (None,None),
# Three-input table for 3-D isotropic functions: (input types) ->
# (scalar_invariant_function, outputs_table), with keys in descending sorted
# order.
_3d_isotropic_three_input_table = {
    (
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), True, False),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_2s_rank_2_1_vec, _3d_isotropic_2s_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), True, False),
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_1s_1a_rank_2_1_vec, _3d_isotropic_1s_1a_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), False, True),
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_2a_rank_2_1_vec, _3d_isotropic_2a_rank_2_1_vec_outputs),
    (
        TensorType(2, (3, 3), True, False),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_1s_rank_2_2_vec, _3d_isotropic_1s_rank_2_2_vec_outputs),
    (
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_1a_rank_2_2_vec, _3d_isotropic_1a_rank_2_2_vec_outputs),
}
# _3d_isotropic_four_input_table = {
# '''(TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (_3d_isotropic_2s_2_vec,
# _3d_isotropic_2s_2_vec_outputs),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),True,False)) : (_3d_isotropic_1s_1a_2_vec,
# _3d_isotropic_1s_1a_2_vec_outputs),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (_3d_isotropic_2a_2_vec,
# _3d_isotropic_2a_2_vec_outputs),'''
# }
# _3d_hemitropic_four_input_table = {
# '''(TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),True,False)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (None,None),'''
# }
_3d_isotropic_1_vec_outputs = {
    # O(3) output form-invariant generators for a single vector input, keyed by
    # output TensorType. For the output types listed here the O(3) and SO(3)
    # generators coincide, so the hemitropic helpers are reused. There is no
    # antisymmetric rank-2 entry: an antisymmetric tensor built from one vector
    # (e.g. eps . v) changes sign under improper rotations, so it is not an
    # O(3) form-invariant.
    TensorType(2, (3, 3), True, False): _3d_hmt_1vec_o_s2,
    TensorType(1, (3,), False, False): _3d_hmt_1vec_o_vec,
}
# _3d_isotropic_1_vec_outputs is defined first so the table below can use it.
_3d_isotropic_single_input_table = {
    # Maps a 1-tuple of input TensorTypes to (scalar-invariant function,
    # output form-invariant table) for the O(3) (isotropic) group.
    # Read "_3d_hemitropic_1s_rank_2" as "scalar invariants in 3-D for one
    # symmetric rank-2 tensor"; "_1a_" is the antisymmetric analogue.
    # NOTE(review): the rank-2 entries reuse the SO(3) functions and output
    # tables; confirm those output tables contain no pseudo-tensor outputs
    # (e.g. an axial vector), which would not be O(3)-equivariant.
    (TensorType(2, (3, 3), True, False),): (
        _3d_hemitropic_1s_rank_2,
        _3d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), False, True),): (
        _3d_hemitropic_1a_rank_2,
        _3d_hemitropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (3,), False, False),): (
        # The scalar invariants of a single vector are the same for O(3) and
        # SO(3), but the output generators must come from the O(3)-specific
        # table rather than the SO(3) one.
        _3d_hemitropic_1_vec,
        _3d_isotropic_1_vec_outputs,
    ),
    (TensorType(0, (), False, False),): (
        # a bare scalar is its own invariant; its only scalar output generator is 1
        lambda x, backend: x,
        {TensorType(0, (), False, False): lambda x, backend: 1.0},
    ),
}
# Output form-invariant generators for two symmetric rank-2 inputs, keyed by
# output TensorType (shared by the SO(3) and O(3) two-input tables).
_3d_hemitropic_2s_rank_2_outputs = {
    TensorType(2, (3, 3), symmetric=True): _3d_hmt_2sr2_o_s2,  # symmetric output
    TensorType(2, (3, 3), antisymmetric=True): _3d_hmt_2sr2_o_a2,  # antisymmetric output
    TensorType(1, (3,)): _3d_hmt_2sr2_o_vec,  # vector output
}
# Output form-invariant generators for one symmetric plus one antisymmetric
# rank-2 input, keyed by output TensorType (shared by the SO(3) and O(3)
# two-input tables).
_3d_hemitropic_1s_1a_rank_2_outputs = {
    TensorType(2, (3, 3), symmetric=True): _3d_hmt_1s1ar2_o_s2,  # symmetric output
    TensorType(2, (3, 3), antisymmetric=True): _3d_hmt_1s1ar2_o_a2,  # antisymmetric output
    TensorType(1, (3,)): _3d_hmt_1s1ar2_o_vec,  # vector output
}
# Output form-invariant generators for two antisymmetric rank-2 inputs, keyed
# by output TensorType (shared by the SO(3) and O(3) two-input tables).
_3d_hemitropic_2a_rank_2_outputs = {
    TensorType(2, (3, 3), symmetric=True): _3d_hmt_2ar2_o_s2,  # symmetric output
    TensorType(2, (3, 3), antisymmetric=True): _3d_hmt_2ar2_o_a2,  # antisymmetric output
    TensorType(1, (3,)): _3d_hmt_2ar2_o_vec,  # vector output
}
# SO(3) output form-invariant generators for one symmetric rank-2 tensor plus
# one vector, keyed by output TensorType.
_3d_hemitropic_1s_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), symmetric=True): _3d_hmt_1sr21vec_o_s2,  # symmetric output
    TensorType(2, (3, 3), antisymmetric=True): _3d_hmt_1sr21vec_o_a2,  # antisymmetric output
    TensorType(1, (3,)): _3d_hmt_1sr21vec_o_vec,  # vector output
}
# SO(3) output form-invariant generators for one antisymmetric rank-2 tensor
# plus one vector, keyed by output TensorType.
# NOTE(review): every entry points at the *symmetric*-input helpers
# (_3d_hmt_1sr21vec_o_*), exactly like _3d_hemitropic_1s_rank_2_1_vec_outputs.
# This looks like a copy-paste slip; confirm whether dedicated
# antisymmetric-input helpers exist, or that the symmetric ones are valid for
# an antisymmetric first argument.
_3d_hemitropic_1a_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1sr21vec_o_vec,
}
# O(3) output form-invariant generators for one antisymmetric rank-2 tensor
# plus one vector, keyed by output TensorType.
# NOTE(review): the entries use the symmetric-input helpers
# (_3d_iso_1sr21vec_o_*) despite the "1a" name; possibly a copy-paste slip.
# Confirm these helpers are valid for an antisymmetric first argument.
_3d_isotropic_1a_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_iso_1sr21vec_o_vec,
}
# SO(3) output form-invariant generators for two vector inputs, keyed by output
# TensorType.
_3d_hemitropic_2_vec_outputs = {
    TensorType(2, (3, 3), symmetric=True): _3d_hmt_2vec_o_s2,  # symmetric output
    TensorType(2, (3, 3), antisymmetric=True): _3d_hmt_2vec_o_a2,  # antisymmetric output
    TensorType(1, (3,)): _3d_hmt_2vec_o_vec,  # vector output
}
# SO(3) (hemitropic) invariants for two 3-D inputs. Keys are 2-tuples of input
# TensorTypes; values are (scalar-invariant function, output form-invariant
# table). Read "_3d_hemitropic_2s_rank_2" as "scalar invariants in 3-D for two
# symmetric rank-2 tensors"; "1s_1a" means one symmetric and one antisymmetric
# tensor, and "1_vec"/"2_vec" count the vector inputs.
_3d_hemitropic_two_input_table = {
    # rank-2 / rank-2
    (
        TensorType(2, (3, 3), symmetric=True),
        TensorType(2, (3, 3), symmetric=True),
    ): (_3d_hemitropic_2s_rank_2, _3d_hemitropic_2s_rank_2_outputs),
    (
        TensorType(2, (3, 3), symmetric=True),
        TensorType(2, (3, 3), antisymmetric=True),
    ): (_3d_hemitropic_1s_1a_rank_2, _3d_hemitropic_1s_1a_rank_2_outputs),
    (
        TensorType(2, (3, 3), antisymmetric=True),
        TensorType(2, (3, 3), antisymmetric=True),
    ): (_3d_hemitropic_2a_rank_2, _3d_hemitropic_2a_rank_2_outputs),
    # rank-2 / vector
    (TensorType(2, (3, 3), symmetric=True), TensorType(1, (3,))): (
        _3d_hemitropic_1s_rank_2_1_vec,
        _3d_hemitropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (3, 3), antisymmetric=True), TensorType(1, (3,))): (
        _3d_hemitropic_1a_rank_2_1_vec,
        _3d_hemitropic_1a_rank_2_1_vec_outputs,
    ),
    # vector / vector
    (TensorType(1, (3,)), TensorType(1, (3,))): (
        _3d_hemitropic_2_vec,
        _3d_hemitropic_2_vec_outputs,
    ),
}
# O(2) (isotropic) invariants for two 2-D inputs. Keys are 2-tuples of input
# TensorTypes; values are (scalar-invariant function, output form-invariant
# table).
_2d_isotropic_two_input_table = {
    # rank-2 / rank-2
    (
        TensorType(2, (2, 2), symmetric=True),
        TensorType(2, (2, 2), symmetric=True),
    ): (_2d_isotropic_2s_rank_2, _2d_isotropic_2s_rank_2_outputs),
    (
        TensorType(2, (2, 2), antisymmetric=True),
        TensorType(2, (2, 2), antisymmetric=True),
    ): (_2d_isotropic_2a_rank_2, _2d_isotropic_2a_rank_2_outputs),
    (
        TensorType(2, (2, 2), symmetric=True),
        TensorType(2, (2, 2), antisymmetric=True),
    ): (_2d_isotropic_1s_1a_rank_2, _2d_isotropic_1s_1a_rank_2_outputs),
    # rank-2 / vector
    (TensorType(2, (2, 2), symmetric=True), TensorType(1, (2,))): (
        _2d_isotropic_1s_rank_2_1_vec,
        _2d_isotropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (2, 2), antisymmetric=True), TensorType(1, (2,))): (
        _2d_isotropic_1a_rank_2_1_vec,
        _2d_isotropic_1a_rank_2_1_vec_outputs,
    ),
    # vector / vector
    (TensorType(1, (2,)), TensorType(1, (2,))): (
        _2d_isotropic_2_vec,
        _2d_isotropic_2_vec_outputs,
    ),
    # pairs involving a general (no symmetry flags) rank-3 tensor
    (TensorType(3, (2, 2, 2)), TensorType(3, (2, 2, 2))): (
        _2d_isotropic_2_rank_3,
        _2d_isotropic_2_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2)), TensorType(1, (2,))): (
        _2d_isotropic_1_rank_3_1_vec,
        _2d_isotropic_1_rank_3_1_vec_outputs,
    ),
    (TensorType(3, (2, 2, 2)), TensorType(2, (2, 2), symmetric=True)): (
        _2d_isotropic_1_rank_3_1s_rank_2,
        _2d_isotropic_1_rank_3_1s_rank_2_outputs,
    ),
}
# O(3) (isotropic) invariants for two 3-D inputs. Keys are 2-tuples of input
# TensorTypes; values are (scalar-invariant function, output form-invariant
# table). The rank-2/rank-2 entries reuse the SO(3) ("hemitropic") functions
# and output tables; the entries involving vectors use O(3)-specific ones.
# NOTE(review): confirm the reused SO(3) output tables contain no pseudo-vector
# outputs, which would not be O(3)-equivariant.
_3d_isotropic_two_input_table = {
    # rank-2 / rank-2 (shared with SO(3))
    (
        TensorType(2, (3, 3), symmetric=True),
        TensorType(2, (3, 3), symmetric=True),
    ): (_3d_hemitropic_2s_rank_2, _3d_hemitropic_2s_rank_2_outputs),
    (
        TensorType(2, (3, 3), symmetric=True),
        TensorType(2, (3, 3), antisymmetric=True),
    ): (_3d_hemitropic_1s_1a_rank_2, _3d_hemitropic_1s_1a_rank_2_outputs),
    (
        TensorType(2, (3, 3), antisymmetric=True),
        TensorType(2, (3, 3), antisymmetric=True),
    ): (_3d_hemitropic_2a_rank_2, _3d_hemitropic_2a_rank_2_outputs),
    # rank-2 / vector
    (TensorType(2, (3, 3), symmetric=True), TensorType(1, (3,))): (
        _3d_isotropic_1s_rank_2_1_vec,
        _3d_isotropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (3, 3), antisymmetric=True), TensorType(1, (3,))): (
        _3d_isotropic_1a_rank_2_1_vec,
        _3d_isotropic_1a_rank_2_1_vec_outputs,
    ),
    # vector / vector
    (TensorType(1, (3,)), TensorType(1, (3,))): (
        _3d_isotropic_2_vec,
        _3d_isotropic_2_vec_outputs,
    ),
}
# SO(3) (hemitropic) invariants for a single 3-D input. Keys are 1-tuples of
# input TensorTypes; values are (scalar-invariant function, output
# form-invariant table). Read "_3d_hemitropic_1s_rank_2" as "scalar invariants
# in 3-D for one symmetric rank-2 tensor"; "_1a_" is the antisymmetric analogue.
_3d_hemitropic_single_input_table = {
    (TensorType(2, (3, 3), symmetric=True),): (
        _3d_hemitropic_1s_rank_2,
        _3d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), antisymmetric=True),): (
        _3d_hemitropic_1a_rank_2,
        _3d_hemitropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (3,)),): (_3d_hemitropic_1_vec, _3d_hemitropic_1_vec_outputs),
    # a bare scalar is its own invariant; its only scalar output generator is 1
    (TensorType(0, ()),): (
        lambda x, backend: x,
        {TensorType(0, ()): lambda x, backend: 1.0},
    ),
}
# O(2) (isotropic) invariants for a single 2-D input. Keys are 1-tuples of
# input TensorTypes; values are (scalar-invariant function, output
# form-invariant table).
_2d_isotropic_single_input_table = {
    (TensorType(2, (2, 2), symmetric=True),): (
        _2d_isotropic_1s_rank_2,
        _2d_isotropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), antisymmetric=True),): (
        _2d_isotropic_1a_rank_2,
        _2d_isotropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (2,)),): (_2d_isotropic_1_vec, _2d_isotropic_1_vec_outputs),
    # A rank-3 input is handled the same way whether it is flagged general,
    # totally symmetric or totally antisymmetric, so all three flag
    # combinations are registered (in that order) against the same pair.
    **{
        (TensorType(3, (2, 2, 2), symmetric=sym, antisymmetric=anti),): (
            _2d_isotropic_1_rank_3,
            _2d_isotropic_1_rank_3_outputs,
        )
        for sym, anti in ((False, False), (True, False), (False, True))
    },
    # a bare scalar is its own invariant
    (TensorType(0, ()),): (
        lambda x, backend: x,
        # NOTE(review): this scalar output generator returns [1.0], while the
        # other single-input tables return 1.0; confirm which form callers expect.
        {TensorType(0, ()): lambda x, backend: [1.0]},
    ),
}
# SO(2) (hemitropic) invariants for a single 2-D input. Keys are 1-tuples of
# input TensorTypes; values are (scalar-invariant function, output
# form-invariant table). Several scalar-invariant functions are shared with
# O(2); the output tables are SO(2)-specific.
_2d_hemitropic_single_input_table = {
    (TensorType(2, (2, 2), symmetric=True),): (
        _2d_isotropic_1s_rank_2,
        _2d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), antisymmetric=True),): (
        _2d_hemitropic_1a_rank_2,
        _2d_hemitropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (2,)),): (_2d_isotropic_1_vec, _2d_hemitropic_1_vec_outputs),
    # A rank-3 input is handled the same way whether it is flagged general,
    # totally symmetric or totally antisymmetric, so all three flag
    # combinations are registered (in that order) against the same pair.
    **{
        (TensorType(3, (2, 2, 2), symmetric=sym, antisymmetric=anti),): (
            _2d_isotropic_1_rank_3,
            _2d_hemitropic_1_rank_3_outputs,
        )
        for sym, anti in ((False, False), (True, False), (False, True))
    },
    # a bare scalar is its own invariant; its only scalar output generator is 1
    (TensorType(0, ()),): (
        lambda x, backend: x,
        {TensorType(0, ()): lambda x, backend: 1.0},
    ),
}
# Zero-input tables: the single entry is keyed by the empty input-type tuple.
_3d_zero_input_table = {
    (): (_3d_empty, _3d_identity_outputs),
}
_2d_zero_input_table = {
    (): (_2d_empty, _2d_identity_outputs),
}
"""
These tables map the number of inputs to the table above containing scalar invariant calculation
functions and output form-invariant calculation tables
"""
# O(2): number of inputs -> input-type table. Position in the tuple is the
# number of inputs (keys 0, 1, 2). Tables for 3 and 4 inputs
# (_2d_isotropic_three_input_table, _2d_isotropic_four_input_table) are not
# provided.
_2d_isotropic_scalar_input_table = dict(
    enumerate(
        (
            _2d_zero_input_table,
            _2d_isotropic_single_input_table,
            _2d_isotropic_two_input_table,
        )
    )
)
# SO(2): number of inputs -> input-type table. Position in the tuple is the
# number of inputs (keys 0, 1). No table for 2, 3 or 4 inputs is registered
# for SO(2) here.
_2d_hemitropic_scalar_input_table = dict(
    enumerate(
        (
            _2d_zero_input_table,
            _2d_hemitropic_single_input_table,
        )
    )
)
# SO(3): number of inputs -> (input-type tuple -> invariant functions) table.
# Position in the tuple is the number of inputs (keys 0..3); a four-input
# table (_3d_hemitropic_four_input_table) is not provided.
_3d_hemitropic_scalar_input_table = dict(
    enumerate(
        (
            _3d_zero_input_table,
            _3d_hemitropic_single_input_table,
            _3d_hemitropic_two_input_table,
            _3d_hemitropic_three_input_table,
        )
    )
)
# O(3): number of inputs -> (input-type tuple -> invariant functions) table.
# Position in the tuple is the number of inputs (keys 0..3); a four-input
# table (_3d_isotropic_four_input_table) is not provided.
_3d_isotropic_scalar_input_table = dict(
    enumerate(
        (
            _3d_zero_input_table,
            _3d_isotropic_single_input_table,
            _3d_isotropic_two_input_table,
            _3d_isotropic_three_input_table,
        )
    )
)
"""
For each supported symmetry group, these tables map the group name to one of the tables above. That table maps the number of inputs to a third table, which maps a tuple of input types to a (scalar-invariant calculation function, output form-invariant function table) pair.
"""
# 2-D group table: symmetry-group name -> (number of inputs -> input-type table).
# "O(2)" selects the isotropic tables and "SO(2)" the hemitropic ones.
_2d_invariant_group_table = {
    "O(2)": _2d_isotropic_scalar_input_table,
    "SO(2)": _2d_hemitropic_scalar_input_table,
}
# 3-D group table: "O(3)" selects the isotropic tables and "SO(3)" the
# hemitropic ones.
_3d_invariant_group_table = {
    "O(3)": _3d_isotropic_scalar_input_table,
    "SO(3)": _3d_hemitropic_scalar_input_table,
}
"""
The first layer of the hash table onion. Maps the number of spatial dimensions to a group table
"""
# Outermost lookup, indexed as
#   [spatial dim][group name][number of inputs][tuple of input TensorTypes]
# which yields a (scalar-invariant function, output form-invariant table) pair.
_scalar_invariant_function_table = dict(
    zip((2, 3), (_2d_invariant_group_table, _3d_invariant_group_table))
)