Source code for crikit.invariants.invariants

import jax
from jax import jit
import sys
import importlib
from functools import partial
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple, List, Tuple, Optional, Union, Iterable, TypeVar
from crikit.logging import logger
from .utils import set_backend
from .utils import (
    get_backend,
    symm,
    antisymm,
    commutator_action,
    anticommutator_action,
    scalar_triple_prod,
    powerset,
    levi_civita,
    _eps_vec_action,
    axial_vector,
    _tprod,
    symm_q4,
    symm_q3,
    tA,
    tv,
    Tv,
    TinnerS,
    TcontrS,
    TW,
    tbrace,
    _3d_rotation_matrix,
    near,
    is_symm,
    is_antisymm,
)
from pyadjoint_utils.jax_adjoint import ndarray as jax_ndarray

try:
    from pyadjoint_utils.torch_adjoint import Tensor as torch_tensor_t
except ImportError:

    class torch_tensor_t:
        pass


from pyadjoint_utils.numpy_backend import (
    get_default_backend,
    get_backend,
    set_default_backend,
)
import inspect
from operator import itemgetter
import itertools

# Generic placeholder type variable named "ndarray". It is not referenced by
# the code visible in this part of the module. NOTE(review): confirm whether
# other parts of the file use it before removing it.
ndarray = TypeVar("ndarray")


class TensorType(NamedTuple):
    """Describes the order, shape and (anti)symmetry of a tensor.

    ``name`` is purely descriptive. It takes part in tuple equality, but it is
    ignored by :meth:`__hash__` and by the ordering used to sort input types.
    """

    order: int = 0
    shape: tuple = ()
    symmetric: bool = False
    antisymmetric: bool = False
    name: str = ""

    @staticmethod
    def make_scalar(name: str = ""):
        """Returns a TensorType representing a scalar

        :param name: The name of the scalar, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a scalar
        :rtype: TensorType
        """
        return TensorType(0, (), False, False, name)

    @staticmethod
    def make_vector(spatial_dims: int, name: str = ""):
        """Returns a TensorType representing a vector

        :param spatial_dims: The number of spatial dimensions
        :type spatial_dims: int
        :param name: The name of the vector, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a vector
        :rtype: TensorType
        """
        return TensorType(1, (spatial_dims,), False, False, name)

    @staticmethod
    def make_symmetric(order: int, spatial_dims: int, name: str = ""):
        """Returns a TensorType representing a symmetric tensor

        :param order: The order of the tensor
        :type order: int
        :param spatial_dims: How many spatial dimensions?
        :type spatial_dims: int
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :return: A TensorType representing a symmetric order-``order`` tensor
            in `spatial_dims` spatial dimensions.
        :rtype: TensorType
        """
        return TensorType(order, order * (spatial_dims,), True, False, name)

    @staticmethod
    def make_antisymmetric(order: int, spatial_dims: int, name: str = ""):
        """Returns a :class:`TensorType` representing an antisymmetric tensor

        :param order: The order of the tensor
        :type order: int
        :param spatial_dims: How many spatial dimensions?
        :type spatial_dims: int
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :return: A TensorType representing an antisymmetric order-``order``
            tensor in `spatial_dims` spatial dimensions.
        :rtype: TensorType
        """
        return TensorType(order, order * (spatial_dims,), False, True, name)

    @staticmethod
    def from_array(X, symmetric=False, antisymmetric=False, name: str = ""):
        """Creates a :class:`TensorType` representing a particular array

        :param X: The array
        :type X: Union[jnp.ndarray,np.ndarray,pyadjoint_utils.jax_adjoint.ndarray]
        :param symmetric: Is the array symmetric? defaults to False
        :type symmetric: bool, optional
        :param antisymmetric: Is the array antisymmetric? defaults to False
        :type antisymmetric: bool, optional
        :param name: The name of the tensor, defaults to ''
        :type name: str, optional
        :raises RuntimeError: if both ``symmetric`` and ``antisymmetric`` are True
        :returns: A :class:`TensorType` representing ``X``
        :rtype: TensorType
        """
        shp = X.shape
        if symmetric and antisymmetric:
            raise RuntimeError(
                "Cannot create a TensorType from an array that is both symmetric and antisymmetric! (That's a contradiction; pass True to only ONE of symmetric or antisymmetric in TensorType.from_array)"
            )
        return TensorType(len(shp), shp, symmetric, antisymmetric, name)

    def get_symmetrizer(self):
        """Returns a function that takes in a tensor of order ``self.order`` and
        makes it symmetric.

        :raises NotImplementedError: if ``self.order`` is greater than 4
        :return: The symmetrizer for a tensor of order ``self.order``
        :rtype: function
        """
        if self.order <= 1:
            # Scalars and vectors are trivially symmetric.
            return lambda x: x
        if self.order == 2:
            return symm
        if self.order == 3:
            return symm_q3
        if self.order == 4:
            return symm_q4
        raise NotImplementedError(
            f"No symmetrizer is available for tensors of order {self.order}"
        )

    def zeros_like(self) -> jnp.ndarray:
        """Returns an array of zeros of the shape ``self.shape``.

        :returns: jnp.zeros(self.shape)
        :rtype: jnp.ndarray
        """
        return jnp.zeros(shape=self.shape)

    def tensor_space_dimension(self) -> int:
        """Returns the dimension (as a vector space) of the tensor space
        containing tensors of this shape.

        :raises NotImplementedError: for (anti)symmetric tensors of order other than 2
        :returns: dimension of the tensor space containing this TensorType
        :rtype: int
        """
        if not self.shape:
            # A scalar (shape ``()``) spans a one-dimensional space.
            return 1
        N = max(self.shape)
        q = self.order
        if self.symmetric:
            if q == 2:
                return (N * (N + 1)) // 2
            raise NotImplementedError
        if self.antisymmetric:
            if q == 2:
                return (N * (N - 1)) // 2
            raise NotImplementedError
        return int(np.prod(self.shape))

    def __hash__(self):
        # ``name`` is intentionally excluded, so types differing only by name
        # share a hash (still consistent with tuple equality).
        return hash(tuple((self.order, self.shape, self.symmetric, self.antisymmetric)))

    def __ge__(self, other):
        if isinstance(other, TensorType):
            return self._selfge(other)
        if isinstance(other, (tuple, int)):
            # Any TensorType sorts at or above plain tuples and ints.
            return True
        return False

    def __lt__(self, other):
        # NOTE: defined as the negation of __ge__, so this is not a strict
        # weak ordering: a general (neither symmetric nor antisymmetric)
        # type compares "less than" an equal type.
        return not self.__ge__(other)

    def _selfge(self, other):
        """Order by tensor order first, then symmetric > antisymmetric > general."""
        if self.order != other.order:
            return self.order > other.order
        if self.symmetric:
            return bool(other.symmetric or other.antisymmetric)
        if self.antisymmetric:
            return bool(other.antisymmetric and not other.symmetric)
        return False

    def get_array_like(self, backend=None) -> jnp.ndarray:
        """
        Constructs an example array with the right shape and symmetry

        :param backend: the numpy-like backend used to build the array,
            defaults to the CRIKit default backend
        :return: a tensor with the right shape and symmetry
        :rtype: jnp.ndarray
        """
        if backend is None:
            backend = get_default_backend()
        if self.symmetric and self.order == 2:
            return backend.eye(max(self.shape))
        elif self.order == 1:
            return backend.zeros(self.shape)
        arr = backend.zeros(self.shape, dtype=jnp.float32)
        # Put a single 1 in the last entry, then (anti)symmetrize it if needed.
        upper_right_corner = tuple(max(0, s - 1) for s in self.shape)
        arr = backend.index_update(arr, upper_right_corner, 1.0)
        if self.symmetric:
            return symm(arr, backend=backend)
        elif self.antisymmetric:
            return antisymm(arr, backend=backend)
        else:
            return arr
class LeviCivitaType(TensorType):
    """A class that represents the Levi-Civita tensor

    Only orders 2 and 3 are supported. The shape is taken from the array that
    ``levi_civita`` builds with the default backend, and the type is always
    flagged antisymmetric (and never symmetric).
    """

    def __new__(cls, order):
        for supported_order in (2, 3):
            if order == supported_order:
                break
        else:
            raise NotImplementedError
        eps = levi_civita(supported_order, get_default_backend())
        return super().__new__(cls, supported_order, eps.shape, False, True)
def type_from_array(
    X, rtol: float = 1.0e-5, name: str = "", backend=None
) -> TensorType:
    """
    Like :meth:`TensorType.from_array`, but tries to detect if the array is
    symmetric or antisymmetric (or is the Levi-Civita tensor).

    :param X: An array (JAX or numpy)
    :type X: Union[jnp.ndarray, np.ndarray, pyadjoint_utils.jax_adjoint.ndarray]
    :param rtol: The relative tolerance to use when determining if X is
        symmetric, antisymmetric, or both, defaults to 1.0e-5
    :type rtol: float
    :param name: The name of the tensor, defaults to ''
    :type name: str, optional
    :param backend: the numpy-like backend used for the checks, defaults to
        the CRIKit default backend
    :raises NotImplementedError: for arrays of order greater than 3
    :return: An appropriate TensorType instance
    :rtype: TensorType
    """
    if backend is None:
        backend = get_default_backend()
    if isinstance(X, jax_ndarray):
        X = X.unwrap(True)
    if backend.isscalar(X) or X.size == 1:
        return TensorType.make_scalar(name)

    shp = X.shape
    order = len(shp)
    if order == 1:
        return TensorType(order, shp, False, False, name)
    if order == 2:
        eps_ij = levi_civita(2, backend)
        # Only compare against epsilon when the shapes agree.
        if shp == eps_ij.shape and near(X, eps_ij, rtol):
            return LeviCivitaType(order)
        # NOTE(review): an array that is numerically both symmetric and
        # antisymmetric (e.g. all zeros) may get both flags set here.
        return TensorType(order, shp, is_symm(X, rtol), is_antisymm(X, rtol), name)
    if order == 3:
        # NOTE: total antisymmetry of a 3rd order tensor only matters for
        # comparison with the Levi-Civita tensor.
        eps_ijk = levi_civita(3, backend)
        if shp == eps_ijk.shape and near(X, eps_ijk, rtol):
            return LeviCivitaType(order)
        return TensorType(order, shp, near(X, symm_q3(X), rtol), False, name)
    raise NotImplementedError(
        f"type_from_array only supports arrays of order <= 3, not order {order}"
    )
def evaluate_invariant_function(
    phi_i, scalar_invariant_func, form_invariant_func, *input_tensors, backend=None
):
    """Evaluate ``sum_k phi_i[k](scalar invariants) * G_k`` for the given inputs.

    :param phi_i: one scalar function per form invariant ``G_k``; each takes
        the full array of scalar invariants
    :param scalar_invariant_func: computes the scalar invariants of the inputs
        (e.g. the first function returned by :func:`get_invariant_functions`)
    :param form_invariant_func: computes the stacked form invariants of the
        inputs (e.g. the second function returned by :func:`get_invariant_functions`)
    :param input_tensors: the input tensors
    :param backend: numpy-like backend providing ``array`` and ``tensordot``,
        defaults to the CRIKit default backend
    :raises ValueError: if the number of ``phi_i`` differs from the number of
        form invariants
    :return: the output tensor
    """
    scalar_invariants = scalar_invariant_func(*input_tensors)
    form_invariants = form_invariant_func(*input_tensors)
    n_form = len(form_invariants)
    if len(phi_i) != n_form:
        raise ValueError(
            f"Must pass {n_form} scalar functions (one per form invariant), not {len(phi_i)}"
        )
    return _evaluate_invariant(form_invariants, scalar_invariants, phi_i, backend=backend)


def _evaluate_invariant(form_invariants, scalar_invariants, phi_i, backend=None):
    """Contract the coefficients ``phi(scalar_invariants)`` with the form invariants."""
    if backend is None:
        backend = get_default_backend()
    coefficients = backend.array([phi(scalar_invariants) for phi in phi_i])
    # Sum over the leading (form-invariant) axis; works for any output rank.
    return backend.tensordot(coefficients, form_invariants, axes=1)
class InvariantInfo(NamedTuple):
    """A class that contains relevant information for computing invariants.

    For example, for a hemitropic CR in 3 spatial dimensions taking a symmetric
    and an antisymmetric tensor as inputs and outputs a symmetric second order
    tensor: ::

        info = InvariantInfo(3,
                             (TensorType.make_symmetric(2,3),
                              TensorType.make_antisymmetric(2,3),
                              LeviCivitaType(3)
                             ),
                             TensorType.make_symmetric(2,3)
                            )
    """

    spatial_dims: int
    input_types: Tuple[TensorType, ...]
    output_type: TensorType

    def get_group_symbol(self, sanitize_input_types: bool = False):
        """
        Returns a symbol representing the group this instance represents.

        :param sanitize_input_types: if True, this function will also return the
            input types without the Levi-Civita symbol, if it exists, default False
        :type sanitize_input_types: bool, optional
        :raises ValueError: if ``spatial_dims`` is less than 2
        :return: a string whose value is either :math:`O(2)`, :math:`SO(2)`,
            :math:`O(3)`, or :math:`SO(3)` (None for other dimensions); if
            ``sanitize_input_types`` is True, a tuple ``(group, input_types)``
        :rtype: str or tuple
        """
        if self.spatial_dims <= 1:
            raise ValueError(
                f"Cannot get an orthogonal group symbol for {self.spatial_dims} spatial dimensions!"
            )
        eps_id = next(
            (
                i
                for i, tp in enumerate(self.input_types)
                if isinstance(tp, LeviCivitaType)
            ),
            None,
        )
        contains_eps = eps_id is not None

        grp = None
        if self.spatial_dims == 2:
            grp = "SO(2)" if contains_eps else "O(2)"
        elif self.spatial_dims == 3:
            grp = "SO(3)" if contains_eps else "O(3)"

        if not sanitize_input_types:
            return grp
        # NOTE: only the first LeviCivitaType is removed. TODO: handle
        # multiple instances of a LeviCivitaType.
        input_types = tuple(self.input_types)
        if contains_eps:
            input_types = input_types[:eps_id] + input_types[eps_id + 1 :]
        return grp, input_types

    @staticmethod
    def from_arrays(output_example, *args, **kwargs):
        """Constructs an :class:`InvariantInfo` from arrays representing the
        output and inputs

        :param output_example: an array of the correct shape and
            symmetry/antisymmetry of the desired output
        :type output_example: Union[jnp.ndarray]
        :param args: an example of each of the input tensors
        :type args: Iterable[Union[jnp.ndarray]]
        :param rtol: the relative tolerance for detecting symmetry/antisymmetry
            and the Levi-Civita symbol, defaults to 1.0e-5
        :type rtol: float
        :raises ValueError: if the inputs disagree on the number of spatial
            dimensions, or if every input and the output are scalars
        :return: an InvariantInfo with the correct spatial dims (inferred from
            the first non-scalar input, or from the output if all inputs are
            scalars) and correct input_types for your inputs and output
        :rtype: InvariantInfo
        """
        rtol = kwargs.get("rtol", 1.0e-5)
        types = tuple(type_from_array(x, rtol=rtol) for x in args)
        output_type = type_from_array(output_example, rtol=rtol)
        spatial_dims = next(
            (max(t.shape) for t in types + (output_type,) if len(t.shape) > 0),
            None,
        )
        if spatial_dims is None:
            raise ValueError(
                "Cannot infer the number of spatial dimensions: every input and the output is a scalar"
            )
        for t in types:
            if len(t.shape) > 0 and max(t.shape) != spatial_dims:
                raise ValueError(
                    f"Input type {t} has {max(t.shape)} spatial dimensions, but the first non-scalar type has {spatial_dims} spatial dimensions"
                )
        return InvariantInfo(spatial_dims, types, output_type)
# needed for static_argnums parameter of jax.jit
class HashableDict(dict):
    """A read-only ``dict`` that can be hashed.

    The hash is derived from the set of ``(key, value)`` pairs. It is therefore
    independent of insertion order and does not require the keys to be mutually
    orderable; the values must be hashable. Every mutating ``dict`` method
    raises ``TypeError``, so the hash cannot change after construction.
    """

    def _key(self):
        return frozenset(self.items())

    def __hash__(self):
        return hash(self._key())

    def __setitem__(self, key, value):
        raise TypeError(f"{type(self)} does not support assignment!")

    def _readonly(self, *args, **kwargs):
        raise TypeError(f"{type(self)} does not support mutation!")

    # dict's constructor fills the mapping at the C level, so blocking these
    # methods does not interfere with HashableDict(mapping) construction.
    __delitem__ = _readonly
    __ior__ = _readonly
    clear = _readonly
    pop = _readonly
    popitem = _readonly
    setdefault = _readonly
    update = _readonly
def get_invariant_functions(
    info: InvariantInfo,
    suppress_warning_print: Optional[bool] = False,
    fail_on_warning: Optional[bool] = False,
    backend: Optional[str] = None,
):
    """
    This function builds two functions, one to compute the scalar invariants,
    and one to compute the form invariants.

    :param info: an InvariantInfo instance
    :type info: InvariantInfo
    :param suppress_warning_print: if True, don't print out warnings (this
        typically would be used if you get a warning about scalar or
        form-invariants not being available for a specific subset of the input
        types, and you know that this isn't a problem, e.g. because no such
        invariants exist for that subset), defaults to False
    :type suppress_warning_print: bool, optional
    :param fail_on_warning: if True, a subset of input types for which no
        scalar-invariant function is available raises the underlying
        ValueError instead of only being logged. A subset that has scalar
        invariants but no form-invariant function is still only logged.
        Defaults to False
    :type fail_on_warning: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch',
        or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :raises TypeError: if ``info`` is not an :class:`InvariantInfo`
    :return: a tuple of two functions, the first of which generates the input
        scalar invariants (and places them into a backend numpy array), and
        the second of which generates the output form-invariant basis. Both
        raise ValueError when called with the wrong number of arguments.
    :rtype: tuple
    """
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)
    if not isinstance(info, InvariantInfo):
        raise TypeError(
            f"First parameter of get_invariant_functions ({info}) is of type {type(info)}, not InvariantInfo!"
        )

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    # A symmetric rank-2 output always gets the identity as an extra form invariant.
    # TODO: in 2D an antisymmetric rank-2 output should presumably also get
    # eps_ij as a form invariant; that is not implemented.
    include_identity = info.output_type.order == 2 and info.output_type.symmetric
    output_type = _nameless(info.output_type)

    # Sort the inputs (descending) and tag repeated types so that each one is
    # a distinct key mapping back to its position in the caller's arguments.
    input_sort_idxs, input_types = _get_sorted_and_indices(input_types, reverse=True)
    input_types = _tag_value_chains(tuple(input_types))
    input_sort_map = HashableDict(zip(input_types, input_sort_idxs))

    scalar_invts = []
    form_invts = []
    for sub_types in powerset(input_types, exclude_empty_set=True):
        try:
            new_sub_types = tuple(sorted(_untagged_tuple(sub_types), reverse=True))
            si_func, fi_func = _get_invariant_functions(
                dims, group, _untagged_tuple(new_sub_types, nameless=True), output_type
            )
        except ValueError:
            if not suppress_warning_print:
                logger.warning(
                    f"we don't have functions to compute scalar invariants for all subsets of your input basis! "
                    f"Specifically, we do not have invariants for the input TensorType subset {_untagged_tuple(sub_types)}. "
                    "Proceed with caution, as your basis may be incomplete! \n"
                    "This may not be a problem--invariants don't always exist for any arbitrary subset of every set of input types--"
                    "but if you expect invariants to exist for this input subset (and you have checked that that assumption is justified), "
                    "we may be missing those functions. "
                )
            if fail_on_warning:
                raise
            continue
        scalar_invts.append((si_func, sub_types))
        if fi_func:
            form_invts.append((fi_func, sub_types))
        elif not suppress_warning_print:
            logger.warning(
                f"we don't have functions to compute form-invariants for all subsets of your input basis! "
                f"Specifically, we do not have invariants for the input TensorType subset {_untagged_tuple(sub_types)}. "
                "Proceed with caution, as your basis may be incomplete! "
                "This may not be a problem--invariants don't always exist for any arbitrary subset of every set of input types--"
                "but if you expect invariants to exist for this input subset (and you have checked that that assumption is justified), "
                "we may be missing those functions. "
            )

    scalar_invts = tuple(scalar_invts)
    form_invts = tuple(form_invts)

    # NOTE: the helpers below first build a list of indices and then a list
    # of inputs. That works with jax.jit, whereas passing a generator such as
    # fi_func(args[imap[sub_t]] for sub_t in sub_types) makes JAX raise.
    def unified_scalar_invariant_func(imap, sinvts, backend, *args):
        invts = []
        for si_func, sub_types in sinvts:
            idx = [imap[sub_t] for sub_t in sub_types]
            inputs = [args[j] for j in idx]
            invts.append(_1darr(si_func(*inputs, backend), backend))
        return backend.concatenate(invts)

    def scalar_invariant_func(*args):
        if len(args) != len(input_types):
            raise ValueError(
                f"Wrong number of arguments passed to scalar_invariant_func! Expected {len(input_types)}, but got {len(args)}!"
            )
        return unified_scalar_invariant_func(input_sort_map, scalar_invts, backend, *args)

    def unified_form_invariant_func(imap, frminvts, backend, *args):
        finvts = []
        for fi_func, sub_types in frminvts:
            idx = [imap[sub_t] for sub_t in sub_types]
            inputs = [args[j] for j in idx]
            finvts += fi_func(*inputs, backend)
        return finvts

    def form_invariant_func(*args):
        if len(args) != len(input_types):
            raise ValueError(
                f"Wrong number of arguments passed to form_invariant_func! Expected {len(input_types)}, but got {len(args)}!"
            )
        # Stacking happens in Python so the identity can be prepended.
        finvts = unified_form_invariant_func(input_sort_map, form_invts, backend, *args)
        if include_identity:
            return backend.stack([backend.eye(dims)] + finvts)
        return backend.stack(finvts)

    return scalar_invariant_func, form_invariant_func
def get_invariant_descriptions(
    info: InvariantInfo,
    suppress_warning_print: Optional[bool] = False,
    fail_on_warning: Optional[bool] = False,
    html: Optional[bool] = None,
    ipython: Optional[bool] = None,
    backend: Optional[str] = None,
):
    """
    This function builds a string description of the scalar and form invariants
    that you would get from :func:`get_invariant_functions` with the same
    arguments you pass in here.

    :param info: an InvariantInfo instance
    :type info: InvariantInfo
    :param suppress_warning_print: if True, don't print out warnings (this
        typically would be used if you get a warning about scalar or
        form-invariants not being available for a specific subset of the input
        types, and you know that this isn't a problem, e.g. because no such
        invariants exist for that subset), defaults to False
    :type suppress_warning_print: bool, optional
    :param fail_on_warning: if True, a subset of input types for which no
        scalar-invariant function is available raises the underlying
        ValueError instead of only being logged, defaults to False
    :type fail_on_warning: bool, optional
    :param html: Return HTML instead of a plain string description? Useful for
        use inside Jupyter notebooks. By default HTML is used only when running
        under IPython.
    :type html: bool, optional
    :param ipython: Is this being used in ipython mode? (e.g. in a Jupyter
        notebook) By default, tries to guess whether or not you are. If the
        default behavior is undesirable, set this parameter manually.
    :type ipython: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch',
        or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :raises TypeError: if ``info`` is not an :class:`InvariantInfo`
    :return: a string describing the invariants; in IPython mode the
        description is passed to IPython's ``display`` and its result is returned
    :rtype: str
    """
    if not isinstance(info, InvariantInfo):
        raise TypeError(
            f"First parameter of get_invariant_descriptions ({info}) is of type {type(info)}, not InvariantInfo!"
        )
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)

    class function_state:
        # Number of input tensors of each kind seen so far.
        n_rank_0 = 0
        n_rank_1 = 0
        n_rank_2_s = 0
        n_rank_2_a = 0
        n_rank_3 = 0
        # The _n counters count how many of those already carry a name.
        n_rank_0_n = 0
        n_rank_1_n = 0
        n_rank_2_s_n = 0
        n_rank_2_a_n = 0
        n_rank_3_n = 0

    if _executing_in_ipython():
        # Inside IPython, default to rich HTML display unless told otherwise.
        ipython = True if ipython is None else ipython
        html = ipython if html is None else html

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    include_identity = info.output_type.order == 2 and info.output_type.symmetric
    output_type = _nameless(info.output_type)

    # Building the header assigns a symbol to each input and records in
    # input_state how many named inputs exist; the symbol lists below use that.
    input_state = function_state()
    header = (
        _get_header_html(dims, input_types, input_state)
        if html
        else _get_header_plaintext(dims, input_types, input_state)
    )

    _, input_types = _get_sorted_and_indices(input_types, reverse=True)
    input_types = _tag_value_chains(tuple(input_types))

    scalar_invts = []
    form_invts = []
    for sub_types in powerset(input_types, exclude_empty_set=True):
        try:
            sorted_sub_types = tuple(sorted(sub_types, reverse=True))
            si_func, fi_func = _get_invariant_functions(
                dims, group, _untagged_tuple(sorted_sub_types, nameless=True), output_type
            )
        except ValueError:
            if not suppress_warning_print:
                logger.warning(
                    f"we don't have functions to compute scalar invariants for all subsets of your input basis! "
                    f"Specifically, we do not have invariants for the input TensorType subset {_untagged_tuple(sub_types)}. "
                    "Proceed with caution, as your basis may be incomplete! \n"
                    "This may not be a problem--invariants don't always exist for any arbitrary subset of every set of input types--"
                    "but if you expect invariants to exist for this input subset (and you have checked that that assumption is justified), "
                    "we may be missing those functions. "
                )
            if fail_on_warning:
                raise
            continue
        scalar_invts.append((si_func, sorted_sub_types))
        if fi_func:
            form_invts.append((fi_func, sorted_sub_types))
        elif not suppress_warning_print:
            logger.warning(
                f"we don't have functions to compute form-invariants for all subsets of your input basis! "
                f"Specifically, we do not have invariants for the input TensorType subset {_untagged_tuple(sorted_sub_types)}. "
                "Proceed with caution, as your basis may be incomplete! "
                "This may not be a problem--invariants don't always exist for any arbitrary subset of every set of input types--"
                "but if you expect invariants to exist for this input subset (and you have checked that that assumption is justified), "
                "we may be missing those functions. "
            )

    # Named inputs consume the first default symbols of their kind.
    scalar_symbols = ["x", "y", "z"][input_state.n_rank_0_n :]
    vector_symbols = ["v", "u"][input_state.n_rank_1_n :]
    symm_symbols = ["A", "B", "C"][input_state.n_rank_2_s_n :]
    antisymm_symbols = ["W", "V"][input_state.n_rank_2_a_n :]
    r3_symbols = ["T", "S"][input_state.n_rank_3_n :]
    # if you change this, also change _get_symbol_id()
    symbol_map = [
        scalar_symbols,
        vector_symbols,
        symm_symbols,
        antisymm_symbols,
        r3_symbols,
    ]

    def _range_prefix(start, stride):
        """Label for the output indices ``start .. start+stride-1``."""
        if stride == 1:
            prefix = f"{start} : "
        else:
            prefix = f"({start}:{start + stride - 1}) : "
        return "<code>" + prefix + "</code>" if html else prefix

    scalar_invt_descrs = []
    range_start = 0
    for si_func, sub_types in scalar_invts:
        stride = _get_num_retvals(si_func, sub_types, backend)
        scalar_invt_descrs.append(
            _range_prefix(range_start, stride)
            + _format_invt_function(si_func, sub_types, symbol_map, html)
        )
        range_start += stride
    si_descr = _apply_final_formatting("\n\n".join(scalar_invt_descrs), html)

    form_invt_descrs = []
    range_start = 0
    if include_identity:
        form_invt_descrs.append(
            f"<code>0 : [] -> I_{dims}</code>" if html else f"0 : [] -> I_{dims}"
        )
        range_start = 1
    for fi_func, sub_types in form_invts:
        stride = _get_num_retvals_form_invt(fi_func, sub_types, backend)
        form_invt_descrs.append(
            _range_prefix(range_start, stride)
            + _format_invt_function(fi_func, sub_types, symbol_map, html)
        )
        range_start += stride
    fi_line = (
        '<br><br><hr style="border: 1px dashed black"><br>'
        if html
        else "\n\n------------------------------------------------\n"
    )
    fi_line += "Form Invariants:\n\n"
    fi_descr = _apply_final_formatting(fi_line + "\n\n".join(form_invt_descrs), html)

    description = header + si_descr + fi_descr
    if html:
        description += "</html>"
    if ipython:
        from IPython.display import HTML, display

        return display(HTML(description) if html else description)
    return description
def _executing_in_ipython(): try: sh = get_ipython().__class__.__name__ return True except NameError: # get_ipython() doesn't exist, so we proabably aren't in an IPython environment return False def _apply_final_formatting(string, html): if html: return string.replace("\n", "<br>") return string def _get_header_html(dims, input_types, input_state): style = """ <html> <head> <style> table, th, td { border: 1px solid black; } ul { margin: 0; } ul.dashed { list-style-type: none; } ul.dashed > li { text-indent: -5px; } ul.dashed > li:before { content: "-"; text-indent: -5px; } </style> </head> <body> <p>Legend</p> """ string = f""" <table style="width:100%"> <tr> <th>Symbol(s)</th> <th>Tensor Rank</th> <th>Symmetric</th> <th>Antisymmetric</th> </tr> <tr> <td>x, y, z</td> <td>0</td> <td>N/A</td> <td>N/A</td> </tr> <tr> <td>v, u</td> <td>1</td> <td>N/A</td> <td>N/A</td> </tr> <tr> <td>A, B, C</td> <td>2</td> <td>Yes</td> <td>No</td> </tr> <tr> <td>W, V</td> <td>2</td> <td>No</td> <td>Yes</td> </tr> <tr> <td>T, S</td> <td>3</td> <td>Any</td> <td>Any</td> </tr> </table> <br><br> Special symbols:<br> <ul class="dashed"> <li> <code>I_{dims}</code> (rank-two identify tensor, A.K.A. 
identity matrix)</li> </ul> <br> Operations:<br> <ul class="dashed"> <li> <code>_tprod(x, y) = numpy_backend.tensordot(x, y, axes=0) (tensor product)</code></li> <br> <li> <code>symm(x) = x + x.T</code></li><br> <li> <code>antisymm(x) = x - x.T</code></li> </ul> <br> <hr style="border: 1px dashed black"> <br> Input tensors: <br> <code>{_get_type_symbols(input_types, input_state)}</code> <br> <hr style="border: 1px dashed black"> <br> Scalar invariants:<br><br> """ return style + string def _get_header_plaintext(dims, input_types, input_state): return f""" Legend: ------------------------------------------------ Symbol(s) | Tensor Rank | Symmetric | Antisymmetric ------------------------------------------------ x, y, z | 0 | N/A | N/A ------------------------------------------------ v, u | 1 | N/A | N/A ------------------------------------------------ A, B, C | 2 | Yes | No ------------------------------------------------ W, V | 2 | No | Yes ------------------------------------------------ T, S | 3 | Any | Any ------------------------------------------------ Special symbols: - I_{dims} (rank-two identify tensor, A.K.A. 
identity matrix) Operations: - _tprod(x, y) = numpy_backend.tensordot(x, y, axes=0) (tensor product) - symm(x) = x + x.T - antisymm(x) = x - x.T ------------------------------------------------ Input tensors: {_get_type_symbols(input_types, input_state)} ------------------------------------------------ Scalar invariants: """ def _get_symbol_replacements( names: Union[List[str], Tuple[str]], types: Iterable[Union[TensorType, Tuple[TensorType, int]]], symbol_map: List[List[str]], ) -> List[Tuple[str, str]]: # names are returned by _infer_param_names() and are the same length as types replacements = set() for i, ttype in enumerate(types): if not isinstance(ttype, TensorType): ttype, j = ttype tid = _get_symbol_id(ttype) map_to = symbol_map[tid][j] else: tid = _get_symbol_id(ttype) map_to = names[i] candidate = symbol_map[tid][0] if candidate != map_to: replacements.add((candidate, map_to)) return replacements def _replace_symbols( names: Union[List[str], Tuple[str]], types: Iterable[Union[TensorType, Tuple[TensorType, int]]], symbol_map: List[List[str]], line: str, ) -> str: replacements = _get_symbol_replacements(names, types, symbol_map) for replacement in replacements: line = line.replace(*replacement).replace( replacement[0].lower(), replacement[1].lower() ) return line def _format_invt_function(f, types, symbol_map, html): param_names = _infer_param_names(f, types, symbol_map) lines = _strip_function_header(inspect.getsourcelines(f)[0]) if len(lines) == 0: return ( f"<code>{param_names} -> {param_names}</code>" if html else f"{param_names} -> {param_names}" ) elif len(lines) == 1: body = [] retline = lines[0] else: retline = -1 for i, line in enumerate(lines): if "return" in line: retline = i break retline, body = lines[retline:], lines[:retline] retline = "\n".join(retline) body = _format_body(body) # retline = _replace_symbols(param_names, types, symbol_map, # _format_return_line(retline.lstrip(' '))) body = ",\n".join(body).rstrip(",\n") func_descr = 
f"{param_names} -> " replacement_str = "[" for r in inspect.signature(f).parameters: replacement_str += "'" + r + "', " replacement_str = replacement_str.rstrip(", ") + "]" func_descr += replacement_str + " -> " func_descr += _format_return_line(retline.lstrip(" ")) if len(lines) > 1: func_descr += "where\n" + body if html: return "<code>" + func_descr + "</code>" return func_descr def _format_body(body): return [x.lstrip(" ").rstrip("\n") for x in body] def _get_num_retvals(f, input_types, backend): inputs = [_untagged_value(t).get_array_like() for t in input_types] return f(*inputs, backend).size def _get_num_retvals_form_invt(f, input_types, backend): inputs = [_untagged_value(t).get_array_like() for t in input_types] return len(f(*inputs, backend)) def _format_return_line(line): if "array(" in line: idx = line.find("array(") + len("array(") line = line.rstrip("\n").rstrip(")")[idx:] + "\n" elif line.startswith("return "): return line[len("return ") :] return line def _strip_function_header(f_lines): return f_lines[2:] def _get_symbol_id(ttype): if ttype.order < 2: return ttype.order if ttype.order == 2: return 2 if ttype.symmetric else 3 return 4 def _infer_param_names(f, input_types, symbol_map): names = [] for ttype in input_types: if isinstance(ttype, TensorType): # if it's a pure TensorType, it must be the first of its type, so get the first symbol name = ( symbol_map[_get_symbol_id(ttype)][0] if ttype.name == "" else ttype.name ) names.append(name) else: ttype, i = ttype name = ( symbol_map[_get_symbol_id(ttype)][i] if ttype.name == "" else ttype.name ) names.append(name) return names def _get_type_symbols(ts, state): def _handle_rank_0(t): state.n_rank_0 += 1 if state.n_rank_0 == 1: if t.name == "": return "x" else: state.n_rank_0_n += 1 return t.name elif state.n_rank_0 == 2: if t.name == "": return "y" else: state.n_rank_0_n += 1 return t.name elif state.n_rank_0 == 3: if t.name == "": return "z" else: state.n_rank_0_n += 1 return t.name else: raise 
ValueError( "Currently can only represent 3 scalars in the input types!" ) def _handle_rank_1(t): # global n_rank_1 if state.n_rank_1 == 0: state.n_rank_1 = 1 if t.name == "": return "v" else: state.n_rank_1_n += 1 return t.name elif state.n_rank_1 == 1: state.n_rank_1 = 2 if t.name == "": return "u" else: state.n_rank_1_n += 1 return t.name else: raise ValueError( "Currently can only represent 2 vectors in the input types!" ) def _handle_rank_2(t): # global n_rank_2_s if t.symmetric: state.n_rank_2_s += 1 if state.n_rank_2_s == 1: if t.name == "": return "A" else: state.n_rank_2_s_n += 1 return t.name elif state.n_rank_2_s == 2: if t.name == "": return "B" else: state.n_rank_2_s_n += 1 return t.name elif state.n_rank_2_s == 3: if t.name == "": return "C" else: state.n_rank_2_s_n += 1 return t.name else: raise ValueError( "Currently can only represent 3 symmetric rank-two tensors in the input types!" ) else: state.n_rank_2_a += 1 if state.n_rank_2_a == 1: if t.name == "": return "W" else: state.n_rank_2_a_n += 1 return t.name elif state.n_rank_2_a == 2: if t.name == "": return "V" else: state.n_rank_2_a_n += 1 return t.name else: raise ValueError( "Currently can only represent 2 antisymmetric rank-two tensors in the input types!" ) def _handle_rank_3(t): state.n_rank_3 += 1 if state.n_rank_3 == 1: if t.name == "": return "T" else: state.n_rank_3_n += 1 return t.name elif state.n_rank_3 == 2: if t.name == "": return "S" else: state.n_rank_3_n += 1 return t.name else: raise ValueError( "Currently can only represent 2 rank-three tensors in the input types!" ) handler_map = { 0: lambda x: _handle_rank_0(x), 1: lambda x: _handle_rank_1(x), 2: lambda x: _handle_rank_2(x), 3: lambda x: _handle_rank_3(x), } dispatcher = lambda t: handler_map[t.order](t) return [dispatcher(_untagged_value(t)) for t in ts] # tags chains of the same value T (e.g. T,T,T,T,... becomes T,(T,1),(T,2),...) 
def _tag_value_chains(indexable): val = list(indexable) N = len(indexable) i = 1 while i < N: if val[i - 1] == val[i]: # chain starts here T = val[i] n = 1 while val[i] == T: val[i] = (T, n) n += 1 i += 1 if i >= N: break else: i += 1 return tuple(val) def _get_sorted_and_indices(l, reverse=True): if len(l) == 0: return (), () srtd = sorted(enumerate(l), key=itemgetter(1), reverse=reverse) return tuple(x[0] for x in srtd), tuple(x[1] for x in srtd) def _nameless(x): return TensorType(x.order, x.shape, x.symmetric, x.antisymmetric) def _untagged_value(val, nameless=False): if isinstance(val, TensorType): return _nameless(val) if nameless else val return _nameless(val[0]) if nameless else val[0] def _untagged_tuple(tpl, nameless=False): return tuple(_untagged_value(x, nameless) for x in tpl) def _1darr(x, backend): return backend.atleast_1d(x) def _get_invariant_functions(dims, group, input_types, output_type): missing_dims_err_str = f"The number of dimensions you passed to get_scalar_invariant_function through the info parameter ({dims}) is not supported. Supported numbers of dimensions are currently: (2,3)" missing_group_err_str = f"The symmetry group you passed in (of type {type(group)}) is not among the currently supported set of symmetry groups (O(2),SO(2),O(3),SO(3))" missing_num_inputs_err_str = f"The number of inputs you passed ({len(input_types)}) is not among the currently-supported set of input sizes (1,2,3,4)." missing_inputs_err_str = f"The input types you passed ({input_types}) are not among the currently-supported set of input types in {dims} dimensions under the symmetry group {group}." missing_outputs_err_str = f"The output type you passed ({output_type}) along with the combination of input types you passed ({input_types}) does not have a form-invariant function implemented at the time, or it does not exist." 
try: group_table = _get_item( _scalar_invariant_function_table, dims, missing_dims_err_str ) inputs_table = _get_item(group_table, group, missing_group_err_str) inputs_table = _get_item( inputs_table, len(input_types), missing_num_inputs_err_str ) scalar_func, outputs_table = _get_item( inputs_table, input_types, missing_inputs_err_str ) except ValueError as v: raise v try: # get form-invariants output_basis_func = _get_item( outputs_table, output_type, missing_outputs_err_str ) return scalar_func, output_basis_func except ValueError as v: return scalar_func, None def _get_item(table, key, err_msg): val = None try: val = table[key] except KeyError: raise ValueError(err_msg) return val
def register_invariant_functions(
    info: InvariantInfo,
    scalar_invariant_func,
    form_invariant_func,
    overwrite_existing=False,
    jit=False,
    backend=None,
):
    """Register a scalar and form-invariant computing function for a given InvariantInfo.

    :param info: an InvariantInfo containing the relevant information about the inputs and outputs of functions with this symmetry.
    :type info: InvariantInfo
    :param scalar_invariant_func: a function that returns a single jax.numpy.ndarray containing the scalar invariants for the inputs.
    :type scalar_invariant_func: Callable
    :param form_invariant_func: a function that returns a Python list of jax.numpy.ndarray instances representing the form-invariants for the inputs
    :type form_invariant_func: Callable
    :param overwrite_existing: if True, and the InvariantInfo you pass describes an existing set of invariants, replace those with your function. You should NEVER set this to True unless you really know what you're doing. If you want to overwrite one function but not the other (e.g. insert a form-invariant for a scenario where the scalar invariant already exists), you can also pass a pair of bools, one for the scalar invariant function and one for the form invariant function. defaults to False
    :type overwrite_existing: Union[bool, Tuple[bool, bool]], optional
    :param jit: if True, call backend.jit() on scalar_invariant_func() and form_invariant_func() if it exists, defaults to False
    :type jit: bool, optional
    :param backend: Which numpy backend to use? Must be one of 'jax', 'torch', or 'numpy'. Defaults to whichever is the CRIKit default numpy backend
    :type backend: str, optional
    :return: None, makes your functions available to get_invariant_functions()
    """
    if backend is None:
        backend = get_default_backend()
    else:
        backend = get_backend(backend)
    if jit:
        scalar_invariant_func = backend.jit(scalar_invariant_func)
        # The form-invariant function is optional; only compile it if given.
        if form_invariant_func is not None:
            form_invariant_func = backend.jit(form_invariant_func)

    dims = info.spatial_dims
    group, input_types = info.get_group_symbol(sanitize_input_types=True)
    output_type = info.output_type
    # Store under the canonical (descending-sorted) ordering of the inputs.
    input_types = tuple(sorted(input_types, reverse=True))

    # Walk the nested tables
    #   dims -> group -> number of inputs -> input types -> (scalar func, outputs)
    # creating EMPTY levels on demand. (Pre-populating a new level with the
    # entry being registered would make it look like it "already exists".)
    group_table = _scalar_invariant_function_table.setdefault(dims, {})
    n_input_table = group_table.setdefault(group, {})
    inputs_table = n_input_table.setdefault(len(input_types), {})

    if isinstance(overwrite_existing, (int, bool)):
        overwrite_existing_scalar = overwrite_existing_form = overwrite_existing
    else:
        # if it's not a single bool/int, it must be a pair of them (or else we need to
        # throw an exception anyway, might as well let it be the default one from unpacking
        # a non-iterable object)
        overwrite_existing_scalar, overwrite_existing_form = overwrite_existing

    if input_types in inputs_table:
        # Reuse the existing outputs table so that form-invariants can still be
        # added even when the scalar-invariant function is kept.
        _, outputs_table = inputs_table[input_types]
        if overwrite_existing_scalar:
            inputs_table[input_types] = (scalar_invariant_func, outputs_table)
        else:
            logger.warning(
                f"Scalar invariants for the inputs combination {input_types} already exists! Skipping this replacement for now. To overwrite this function, pass overwrite_existing=True to register_invariant_functions(). ONLY do this if you know what you are doing!"
            )
    else:
        outputs_table = {}
        inputs_table[input_types] = (scalar_invariant_func, outputs_table)

    if output_type in outputs_table and not overwrite_existing_form:
        logger.warning(
            f"Form-invariant functions for this inputs and outputs combination (inputs {input_types}, output {output_type}) already exist! Skipping this replacement for now. To overwrite this function, pass overwrite_existing=True to register_invariant_functions(). ONLY do this if you know what you're doing!"
        )
    else:
        outputs_table[output_type] = form_invariant_func
""" Because Python's parsing rules are kinda weird when dictionaries are involved, all of the invariant- calculating functions have to be first. If you're reading this for the first time, skip down to the tables below first. """ def _2d_hmt_1_r3_o_r3(T, backend): return [T, T @ levi_civita(2, backend)] def _2d_hmt_1sr2_o_s2(A, backend): return [A, A @ levi_civita(2, backend) - levi_civita(2, backend) @ A] def _2d_hmt_1_vec_o_vec(v, backend): return [v, levi_civita(2, backend) @ v] def _2d_hmt_1_vec_o_s2(v, backend): ev = levi_civita(2, backend) @ v return [_tprod(v, v, backend), _tprod(v, ev, backend) + _tprod(ev, v, backend)] def _2d_hmt_1_vec_o_s3(v, backend): return [_tprod(v, _tprod(v, v, backend), backend)] def _2d_hmt_1_vec_o_r3(v, backend): vv = _tprod(v, v) ev = levi_civita(2, backend) @ v return [ _tprod(v, vv, backend), _tprod(vv, ev, backend), _tprod(v, backend.eye(2), backend), _tprod(ev, backend.eye(2), backend), ] def _2d_iso_1_vec_o_vec(v, backend): return [v] def _2d_iso_1_vec_o_s2(v, backend): return [_tprod(v, v, backend)] def _2d_iso_1_vec_o_r3(v, backend): return [_tprod(v, _tprod(v, v, backend), backend)] + tbrace( _tprod(v, backend.eye(2), backend), backend ) def _2d_iso_1r3_o_3(T, backend): return [T] def _2d_iso_1sr2_o_s2(A, backend): return [A] def _2d_iso_1ar2_o_a2(W, backend): return [W] def _2d_iso_2sr2_o_a2(A, B, backend): return [A @ B - B @ A] def _2d_iso_1sr2_1ar2_o_s2(A, W, backend): return [A @ W - W @ A] def _2d_iso_1sr2_1_vec_o_vec(A, v, backend): return [A @ v] def _2d_iso_1sr2_1_vec_o_a2(A, v, backend): av = A @ v return [_tprod(v, av, backend) - A @ _tprod(v, v, backend)] def _2d_iso_1sr2_1_vec_o_s3(A, v, backend): return tbrace(_tprod(A @ v, backend.eye(2), backend), backend) def _2d_iso_1ar2_1_vec_o_s3(W, v, backend): return [symm_q3(_tprod(v, _tprod(v, W @ v, backend), backend), backend)] def _2d_iso_2_vec_o_s2(v, u, backend): return [_tprod(v, u, backend) + _tprod(u, v, backend)] def _2d_iso_2_vec_o_a2(v, u, backend): 
return [_tprod(v, u, backend) - _tprod(u, v, backend)] def _2d_iso_2_vec_o_s3(v, u, backend): return [symm_q3(_tprod(v, _tprod(v, u, backend), backend), backend)] def _2d_iso_1r3_1_vec_o_vec(T, v, backend): return [tv(T, v, backend)] def _2d_iso_1r3_1_vec_o_s2(T, v, backend): return [Tv(T, v, backend)] def _2d_iso_1r3_1_vec_o_a2(T, v, backend): ttv = tv(T, v) return [_tprod(v, ttv, backend) - _tprod(ttv, v, backend)] def _2d_iso_1r3_1sr2_o_vec(T, A, backend): return tA(T, A, backend) def _2d_iso_1r3_1sr2_o_s2(T, A, backend): tta = tA(T, A) return [_tprod(tta, tta, backend)] def _2d_iso_1r3_1sr2_o_a2(T, A, backend): tta = tA(T, A) ata = A @ tta return [_tprod(tta, ata, backend) - _tprod(ata, tta, backend)] def _2d_iso_1r3_1sr2_o_3(T, A, backend): tta = tA(T, A) return [symm_q3(_tprod(A @ tta, A, backend), backend)] + tbrace( _tprod(tta, backend.eye(2), backend), backend ) def _2d_iso_2_r3_o_a2(T, S, backend): return TinnerS(T, S, backend) def _3d_hmt_1sr2_o_s2(A, backend): return [A, A @ A] def _3d_hmt_1ar2_o_s2(W, backend): return [W @ W] def _3d_hmt_1ar2_o_a2(W, backend): return [W] def _3d_hmt_1ar2_o_vec(W, backend): return [axial_vector(W, backend)] def _3d_hmt_1vec_o_s2(v, backend): return [backend.tensordot(v, v, axes=0)] def _3d_hmt_1vec_o_a2(v, backend): return [backend.einsum("ijk,k -> ij", levi_civita(3, backend), v, optimize=True)] def _3d_hmt_1vec_o_vec(v, backend): return [v] def _3d_hmt_2sr2_o_s2(A, B, backend): AB = A @ B BA = B @ A return [AB + BA, A @ AB + BA @ A, AB @ B + B @ BA] def _3d_hmt_2sr2_o_a2(A, B, backend): AB = A @ B BA = B @ A BAA = BA @ A BBA = B @ BA return [ AB - BA, A @ AB - BAA, AB @ B - BBA, A @ BAA - A @ A @ BA, B @ AB @ B - BBA @ B, ] def _3d_hmt_2sr2_o_vec(A, B, backend): AB = A @ B ABB = AB @ B return [ axial_vector(AB, backend), axial_vector(A @ AB, backend), axial_vector(ABB, backend), axial_vector(AB @ A @ A, backend), axial_vector(B @ ABB, backend), ] def _3d_hmt_1s1ar2_o_s2(A, W, backend): AW = A @ W WA = W @ A W2 = W @ W 
return [AW - WA, AW @ W + W @ WA, WA @ W2 - W2 @ AW, A @ AW - WA @ A] def _3d_empty(backend): return backend.array([]) def _2d_empty(backend): return backend.array([]) def _2d_idf(backend): return backend.eye(2) def _3d_idf(backend): return backend.eye(3) def _3d_hmt_1s1ar2_o_a2(A, W, backend): AW = A @ W WA = W @ A return [AW + WA, AW @ W - W @ WA] def _3d_hmt_1s1ar2_o_vec(A, W, backend): AB = A @ B return [ axial_vector(AB, backend), axial_vector(A @ AB, backend), axial_vector(AB @ B, backend), axial_vector(AB @ A @ A, backend), axial_vector(B @ AB @ B, backend), ] def _3d_hmt_2ar2_o_s2(W, V, backend): WV = W @ V VW = V @ W return [WV + VW, W @ WV - VW @ W, WV @ V - V @ VW] def _3d_hmt_2ar2_o_a2(W, V, backend): return [W @ V - V @ W] def _3d_hmt_2ar2_o_vec(W, V, backend): return [axial_vector(W @ V, backend)] def _3d_hmt_1ar21vec_o_s2(W, v, backend): ev = backend.einsum("ijk,k -> ij", levi_civita(3, backend), v) wev = W @ ev return [ symm(_tprod(v, W @ v, backend), backend), symm(wev, backend), symm(W @ wev, backend), ] def _3d_hmt_1ar21vec_o_a2(W, v, backend): return [ antisymm(W @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend) ] def _3d_hmt_1ar21vec_o_vec(W, v, backend): return [W @ v] def _3d_hmt_1sr21vec_o_s2(A, v, backend): return [ symm(_tprod(v, A @ v, backend), backend), symm(A @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend), symm(_tprod(v, backend.cross(v, A @ v), backend), backend), ] def _3d_hmt_1sr21vec_o_a2(A, v, backend): return [ antisymm( A @ backend.einsum("ijk,k -> ij", levi_civita(3, backend), v), backend ), antisymm(_tprod(v, A @ v, backend), backend), ] def _3d_hmt_1sr21vec_o_vec(A, v, backend): Av = A @ v return [Av, backend.cross(v, Av)] def _3d_iso_1sr21vec_o_s2(A, v, backend): av = A @ v return [ symm(_tprod(v, av, backend), backend), symm(_tprod(v, A @ av, backend), backend), ] def _3d_iso_1sr21vec_o_a2(A, v, backend): av = A @ v aav = A @ av return [ antisymm(_tprod(v, av, backend), backend), 
antisymm(_tprod(v, aav, backend), backend), antisymm(_tprod(av, aav, backend), backend), ] def _3d_iso_1sr21vec_o_vec(A, v, backend): av = A @ v return [av, A @ av] def _3d_iso_2vec_o_s2(v, u, backend): return [symm(_tprod(v, u, backend), backend)] def _3d_iso_2vec_o_a2(v, u, backend): return [antisymm(_tprod(v, u, backend), backend)] def _3d_hmt_2vec_o_s2(v, u, backend): return [symm(_tprod(v, u, backend), backend)] def _3d_hmt_2vec_o_a2(v, u, backend): return [antisymm(_tprod(v, u, backend), backend)] def _3d_hmt_2vec_o_vec(v, u, backend): return [backend.cross(v, u)] def _3d_hmt_3sr2_o_s2(A, B, C, backend): AB = A @ B BC = B @ C AC = A @ C A2 = A @ A B2 = B @ B C2 = C @ C ABC = AB @ C return [ABC, A @ ABC, B2 @ AC, C2 @ AB, A2 @ B2 @ C, B2 @ C2 @ A, C2 @ A2 @ B] def _3d_hmt_3sr2_o_a2(A, B, C, backend): AB = A @ B BC = B @ C AC = A @ C CB = C @ B return [AB @ C - CB @ A + BC @ A - A @ CB + C @ AB - B @ AC] def _3d_hmt_3sr2_o_vec(A, B, C, backend): BC = B @ C return [axial_vector(A @ BC + BC @ A + C @ A @ B, backend)] def _tprod(x, y, backend): return backend.tensordot(x, y, axes=0) def _3d_iso_2sr21vec_o_a2(A, B, v, backend): return [ antisymm(_tprod(A @ v, B @ v, backend), backend) + antisymm(_tprod(v, antisymm(A @ B) @ v, backend), backend) ] def _3d_iso_2sr21vec_o_vec(A, B, v, backend): av = A @ v bv = B @ v return [A @ bv - B @ av] def _3d_iso_1s1ar21vec_o_a2(A, W, v, backend): wv = W @ v av = A @ v return [A @ wv + W @ av] def _3d_iso_2ar21vec_o_vec(W, V, v, backend): vv = V @ v wv = W @ v return [W @ vv - V @ wv] def _3d_iso_1ar22vec_o_s2(W, v, u, backend): vtu = antisymm(_tprod(v, u, backend), backend) return [W @ vtu + vtu @ W] def _3d_iso_1ar22vec_o_a2(W, v, u, backend): vtu = antisymm(_tprod(v, u, backend), backend) return [W @ vtu - vtu @ W] def _3d_iso_1sr22vec_o_s2(A, v, u, backend): vtu = antisymm(_tprod(v, u, backend), backend) return [A @ vtu - vtu @ A] def _3d_iso_1sr22vec_o_a2(A, v, u, backend): vtu = antisymm(_tprod(v, u, backend), backend) 
return [A @ vtu + vtu @ A] def _3d_isotropic_2s_2_vec(A, B, v, u, backend): return backend.inner(v, commutator_action(A, B, u)) def _3d_isotropic_1s_1a_2_vec(A, W, v, u, backend): return backend.inner(v, anticommutator_action(A, W, u)) def _3d_isotropic_2a_2_vec(W, V, v, u, backend): return backend.inner(v, commutator_action(W, V, u)) def _3d_hemitropic_1_vec(v, backend): return backend.inner(v, v) def _3d_isotropic_1s_rank_2_1_vec(A, v, backend): Av = A @ v return backend.concatenate( [ backend.atleast_1d(backend.inner(v, Av)), backend.atleast_1d(backend.inner(v, A @ Av)), ] ) def _3d_isotropic_1a_rank_2_1_vec(W, v, backend): return backend.inner(v, W @ (W @ v)) def _3d_isotropic_2_vec(v, u, backend): return backend.inner(v, u) def _3d_isotropic_2s_rank_2_1_vec(A, B, v, backend): bv = B @ v return backend.inner(v, A @ bv) def _3d_isotropic_1s_1a_rank_2_1_vec(A, W, v, backend): Wv = W @ v AWv = A @ Wv return backend.concatenate( [ backend.atleast_1d(backend.inner(v, AWv)), backend.atleast_1d(backend.inner(v, A @ AWv)), backend.atleast_1d(backend.inner(v, W @ (A @ (W @ Wv)))), ] ) def _3d_isotropic_2a_rank_2_1_vec(W, V, v, backend): Vv = V @ v return backend.concatenate( [ backend.atleast_1d(backend.inner(v, W @ Vv)), backend.atleast_1d(backend.inner(v, W @ (W @ Vv))), backend.atleast_1d(backend.inner(v, W @ (V @ Vv))), ] ) def _3d_isotropic_1s_rank_2_2_vec(A, v, u, backend): Au = A @ u return backend.concatenate( [ backend.atleast_1d(backend.inner(v, Au)), backend.atleast_1d(backend.inner(v, A @ Au)), ] ) def _3d_isotropic_1a_rank_2_2_vec(W, v, u, backend): Wu = W @ u return backend.concatenate( [ backend.atleast_1d(backend.inner(v, Wu)), backend.atleast_1d(backend.inner(v, W @ Wu)), ] ) def _3d_hemitropic_1s_rank_2(A, backend): A2 = A @ A return backend.concatenate( [ backend.atleast_1d(backend.trace(A)), backend.atleast_1d(backend.trace(A2)), backend.atleast_1d(backend.trace(A @ A2)), ] ) def _3d_hemitropic_1a_rank_2(W, backend): return backend.trace(W @ W) def 
_3d_hemitropic_1_vec(v, backend): return backend.inner(v, v) def _3d_hemitropic_2s_rank_2(A, B, backend): A2 = A @ A B2 = B @ B return backend.concatenate( [ backend.atleast_1d(backend.trace(A @ B)), backend.atleast_1d(backend.trace(A2 @ B)), backend.atleast_1d(backend.trace(A @ B2)), backend.atleast_1d(backend.trace(A2 @ B2)), ] ) def _3d_hemitropic_2a_rank_2(W, V, backend): return backend.trace(W @ V) def _3d_hemitropic_1s_1a_rank_2(A, W, backend): A2 = A @ A W2 = W @ W A2W2 = A2 @ W2 return backend.concatenate( [ backend.atleast_1d(backend.trace(A @ W2)), backend.atleast_1d(backend.trace(A2W2)), backend.atleast_1d(backend.trace(A2W2 @ A @ W)), ] ) def _3d_hemitropic_1s_rank_2_1_vec(A, v, backend): av = A @ v aav = A @ av return backend.concatenate( [ backend.atleast_1d(backend.inner(v, A @ v)), backend.atleast_1d(backend.inner(v, aav)), backend.atleast_1d(scalar_triple_prod(v, av, aav, backend)), ] ) def _3d_hemitropic_1a_rank_2_1_vec(W, v, backend): return backend.inner(v, axial_vector(W, backend)) def _3d_hemitropic_2_vec(v, u, backend): return backend.inner(v, u) def _3d_hemitropic_3s_rank_2(A, B, C, backend): return backend.trace(A @ B @ C) def _3d_hemitropic_2s_1a_rank_2(A, B, W, backend): AB = A @ B BW = B @ W return backend.concatenate( [ backend.atleast_1d(backend.trace(AB @ W)), backend.atleast_1d(backend.trace(A @ AB @ W)), backend.atleast_1d(backend.trace(AB @ BW)), backend.atleast_1d(backend.trace(A @ W @ W @ BW)), ] ) def _3d_hemitropic_1s_2a_rank_2(A, W, V, backend): AW = A @ W AWV = AW @ V return backend.concatenate( [ backend.atleast_1d(backend.trace(AWV)), backend.atleast_1d(backend.trace(AW @ W @ V)), backend.atleast_1d(backend.trace(AWV @ V)), ] ) def _3d_hemitropic_3a_rank_2(W, V, U, backend): return backend.trace(W @ V @ U) def _3d_hemitropic_2s_rank_2_1_vec(A, B, v, backend): AB = A @ B return backend.concatenate( [ backend.atleast_1d(backend.inner(v, axial_vector(AB, backend))), backend.atleast_1d(backend.inner(v, axial_vector(A @ AB, 
backend))), backend.atleast_1d(backend.inner(v, axial_vector(AB @ B, backend))), backend.atleast_1d(scalar_triple_prod(v, A @ v, B @ v, backend)), ] ) def _3d_hemitropic_1s_1a_rank_2_1_vec(A, W, v, backend): AW = A @ W return backend.concatenate( [ backend.atleast_1d(backend.inner(v, AW @ v)), backend.atleast_1d(backend.inner(v, axial_vector(AW))), backend.atleast_1d(backend.inner(v, axial_vector(AW @ W))), ] ) def _3d_hemitropic_2a_rank_2_1_vec(W, V, v, backend): return backend.inner(v, axial_vector(W @ V, backend)) def _3d_hemitropic_1s_rank_2_2_vec(A, v, u, backend): au = A @ u return backend.concatenate( [ backend.atleast_1d(backend.inner(v, au)), backend.atleast_1d(scalar_triple_prod(v, u, A @ v, backend)), backend.atleast_1d(scalar_triple_prod(v, u, au, backend)), ] ) def _3d_hemitropic_1a_rank_2_2_vec(W, v, u, backend): return backend.inner(v, W @ u) def _3d_hemitropic_3_vec(v, u, w, backend): return scalar_triple_prod(v, u, w, backend) """ 2-d versions of the above (and more) where applicable """ def _2d_isotropic_1s_rank_2(A, backend): return backend.concatenate( [backend.atleast_1d(backend.trace(A)), backend.atleast_1d(backend.trace(A @ A))] ) def _2d_isotropic_1a_rank_2(W, backend): return backend.trace(W @ W) def _2d_isotropic_1_vec(v, backend): return backend.inner(v, v) def _2d_isotropic_1_rank_3(T, backend): return TcontrS(T, T, backend) def _2d_isotropic_2s_rank_2(A, B, backend): return backend.trace(A @ B) def _2d_isotropic_2a_rank_2(W, V, backend): return backend.trace(W @ V) def _2d_isotropic_1s_rank_2_1_vec(A, v, backend): return backend.inner(v, A @ v) def _2d_isotropic_2_vec(v, u, backend): return backend.inner(v, u) def _2d_isotropic_2_rank_3(T, S, backend): return TcontrS(T, S, backend) def _2d_isotropic_1_rank_3_1_vec(T, v, backend): return backend.inner(v, tv(T, v, backend)) def _2d_isotropic_1_rank_3_1s_rank_2(T, A, backend): ta = tA(T, A) return backend.inner(tA, A @ tA) def _2d_isotropic_1s_1a_rank_2(A, W, backend): return 
backend.array([]) def _2d_isotropic_1a_rank_2_1_vec(W, v, backend): return backend.array([]) def _2d_hemitropic_1a_rank_2(W, backend): return backend.trace(levi_civita(2, backend) @ W) """ These are tables mapping output tensor types to functions that compute the form-invariants that make up the output basis. """ _2d_hemitropic_1s_rank_2_outputs = { TensorType(2, (2, 2), True, False): _2d_hmt_1sr2_o_s2, } _2d_hemitropic_1a_rank_2_outputs = {} _2d_hemitropic_1_vec_outputs = { TensorType(1, (2,), False, False): _2d_hmt_1_vec_o_vec, TensorType( 2, ( 2, 2, ), True, False, ): _2d_hmt_1_vec_o_s2, TensorType(3, (2, 2, 2), True, False): _2d_hmt_1_vec_o_s3, TensorType(3, (2, 2, 2), False, False): _2d_hmt_1_vec_o_r3, } _2d_hemitropic_1_rank_3_outputs = { TensorType(3, (2, 2, 2), False, False): _2d_hmt_1_r3_o_r3 } _2d_isotropic_2s_rank_2_outputs = { TensorType(2, (2, 2), False, True): _2d_iso_2sr2_o_a2, } _2d_isotropic_2a_rank_2_outputs = {} _2d_isotropic_1s_1a_rank_2_outputs = { TensorType(2, (2, 2), True, False): _2d_iso_1sr2_1ar2_o_s2, } _2d_isotropic_1s_rank_2_1_vec_outputs = { TensorType(1, (2,), False, False): _2d_iso_1sr2_1_vec_o_vec, TensorType(2, (2, 2), False, True): _2d_iso_1sr2_1_vec_o_a2, TensorType(3, (2, 2, 2), True, False): _2d_iso_1sr2_1_vec_o_s3, } _2d_isotropic_1a_rank_2_1_vec_outputs = { TensorType(3, (2, 2, 2), True, False): _2d_iso_1ar2_1_vec_o_s3, } _2d_isotropic_2_vec_outputs = { TensorType(2, (2, 2), True, False): _2d_iso_2_vec_o_s2, TensorType(2, (2, 2), False, True): _2d_iso_2_vec_o_a2, TensorType(3, (2, 2, 2), True, False): _2d_iso_2_vec_o_s3, } _2d_isotropic_2_rank_3_outputs = { TensorType(2, (2, 2), False, True): _2d_iso_2_r3_o_a2, } _2d_isotropic_1_rank_3_1_vec_outputs = { TensorType(1, (2,), False, False): _2d_iso_1r3_1_vec_o_vec, TensorType(2, (2, 2), True, False): _2d_iso_1r3_1_vec_o_s2, TensorType(2, (2, 2), False, True): _2d_iso_1r3_1_vec_o_a2, } _2d_isotropic_1_rank_3_1s_rank_2_outputs = { TensorType(1, (2,), False, False): 
_2d_iso_1r3_1sr2_o_vec, TensorType(2, (2, 2), True, False): _2d_iso_1r3_1sr2_o_s2, TensorType(2, (2, 2), False, True): _2d_iso_1r3_1sr2_o_a2, TensorType(3, (2, 2, 2), False, False): _2d_iso_1r3_1sr2_o_3, } _2d_isotropic_1_vec_outputs = { TensorType(1, (2,), False, False): _2d_iso_1_vec_o_vec, TensorType(2, (2, 2), True, False): _2d_iso_1_vec_o_s2, TensorType(3, (2, 2, 2), False, False): _2d_iso_1_vec_o_r3, } _2d_isotropic_1_rank_3_outputs = { TensorType(3, (2, 2, 2), False, False): _2d_iso_1r3_o_3, } _2d_isotropic_1s_rank_2_outputs = { TensorType(2, (2, 2), True, False): _2d_iso_1sr2_o_s2, # TensorType(2,(2,2),False,True) : , } _2d_isotropic_1a_rank_2_outputs = { TensorType(2, (2, 2), False, True): _2d_iso_1ar2_o_a2, } _3d_identity_outputs = { TensorType(2, (3, 3), True, False): _3d_idf, } _2d_identity_outputs = { TensorType(2, (2, 2), True, False): _2d_idf, } _3d_hemitropic_1s_rank_2_outputs = { # rhs reads like "3-d hemitropic (1 symmetric rank 2) -> (symmetric rank 2) outputs function" TensorType(2, (3, 3), True, False): _3d_hmt_1sr2_o_s2, } _3d_hemitropic_1a_rank_2_outputs = { TensorType(2, (3, 3), True, False): _3d_hmt_1ar2_o_s2, TensorType(2, (3, 3), False, True): _3d_hmt_1ar2_o_a2, TensorType(1, (3,), False, False): _3d_hmt_1ar2_o_vec, } _3d_hemitropic_1_vec_outputs = { TensorType(2, (3, 3), True, False): _3d_hmt_1vec_o_s2, TensorType(2, (3, 3), False, True): _3d_hmt_1vec_o_a2, TensorType(1, (3,), False, False): _3d_hmt_1vec_o_vec, } _3d_isotropic_1s_rank_2_1_vec_outputs = { TensorType(2, (3, 3), True, False): _3d_iso_1sr21vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_1sr21vec_o_a2, TensorType(1, (3,), False, False): _3d_iso_1sr21vec_o_vec, } _3d_isotropic_2_vec_outputs = { TensorType(2, (3, 3), True, False): _3d_iso_2vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_2vec_o_a2, } _3d_hemitropic_3s_rank_2_outputs = { TensorType(2, (3, 3), True, False): _3d_hmt_3sr2_o_s2, TensorType(2, (3, 3), False, True): _3d_hmt_3sr2_o_a2, TensorType(1, (3,), 
False, False): _3d_hmt_3sr2_o_vec, } _3d_hemitropic_3a_rank_2_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_3ar2_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_3ar2_o_a2, # TensorType(1,(3,),False,False) : _3d_hmt_3ar2_o_vec, } _3d_hemitropic_2s_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : None,#_3d_hmt_2sr21vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_2sr21vec_o_a2, # TensorType(1,(3,),False,False) : None, } _3d_hemitropic_1s_1a_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : None,#_3d_hmt_1s1ar21vec_o_s2, # TensorType(2,(3,3),False,True) : None,#_3d_hmt_1s1ar21vec_o_a2, # TensorType(1,(3,),False,False) : None#_3d_hmt_1s1ar21vec_o_vec, } _3d_hemitropic_2a_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_a2, # TensorType(1,(3,),False,False) : _3d_hmt_2ar21vec_o_vec, } _3d_hemitropic_1s_rank_2_2_vec_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_1sr22vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_1sr22vec_o_a2, # TensorType(1,(3,),False,False) : _3d_hmt_1sr22vec_o_vec, } _3d_hemitropic_1a_rank_2_2_vec_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_1ar22vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_1ar22vec_o_a2, # TensorType(1,(3,),False,False) : _3d_hmt_1ar22vec_o_vec, } _3d_hemitropic_3_vec_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_3vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_3vec_o_a2, # TensorType(1,(3,),False,False) : _3d_hmt_3vec_o_vec, } _3d_isotropic_2s_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : None,#_3d_hmt_2sr21vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_2sr21vec_o_a2, TensorType(1, (3,), False, False): _3d_iso_2sr21vec_o_vec, } _3d_isotropic_1s_1a_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : None,#_3d_hmt_1s1ar21vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_1s1ar21vec_o_a2, # TensorType(1,(3,),False,False) : 
None#_3d_hmt_1s1ar21vec_o_vec, } _3d_isotropic_2a_rank_2_1_vec_outputs = { # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_s2, # TensorType(2,(3,3),False,True) : _3d_hmt_2ar21vec_o_a2, TensorType(1, (3,), False, False): _3d_iso_2ar21vec_o_vec, } _3d_isotropic_1a_rank_2_2_vec_outputs = { TensorType(2, (3, 3), True, False): _3d_iso_1ar22vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_1ar22vec_o_a2, # TensorType(1,(3,),False,False) : _3d_iso_1ar22vec_o_vec } _3d_isotropic_1s_rank_2_2_vec_outputs = { TensorType(2, (3, 3), True, False): _3d_iso_1sr22vec_o_s2, TensorType(2, (3, 3), False, True): _3d_iso_1sr22vec_o_a2, # TensorType(1,(3,),False,False) : _3d_iso_1sr22vec_o_vec } """ These tables map specific combinations of input tensor types to a pair containing a function to compute the input scalar invariants and one of the output tables from above. """ _3d_hemitropic_three_input_table = { ( TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False), ): (_3d_hemitropic_3s_rank_2, _3d_hemitropic_3s_rank_2_outputs), # (TensorType(2,(3,3),True,False), # TensorType(2,(3,3),True,False), # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_2s_1a_rank_2, # _3d_hemitropic_2s_1a_rank_2_outputs), # (TensorType(2,(3,3),True,False), # TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_1s_2a_rank_2, # _3d_hemitropic_1s_2a_rank_2_outputs), # (TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True)) : (_3d_hemitropic_3a_rank_2, # _3d_hemitropic_3a_rank_2_outputs), ( TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False), ): (_3d_hemitropic_2s_rank_2_1_vec, _3d_hemitropic_2s_rank_2_1_vec_outputs), ( TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False), ): (_3d_hemitropic_1s_1a_rank_2_1_vec, _3d_hemitropic_1s_1a_rank_2_1_vec_outputs), ( TensorType(2, 
(3, 3), False, True), TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False), ): (_3d_hemitropic_2a_rank_2_1_vec, _3d_hemitropic_2a_rank_2_1_vec_outputs), ( TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False), TensorType(1, (3,), False, False), ): (_3d_hemitropic_1s_rank_2_2_vec, _3d_hemitropic_1s_rank_2_2_vec_outputs), ( TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False), TensorType(1, (3,), False, False), ): (_3d_hemitropic_1a_rank_2_2_vec, _3d_hemitropic_1a_rank_2_2_vec_outputs), ( TensorType(1, (3,), False, False), TensorType(1, (3,), False, False), TensorType(1, (3,), False, False), ): (_3d_hemitropic_3_vec, _3d_hemitropic_3_vec_outputs), } # future possible elements of the below table. # (TensorType(2,(3,3),True,False), # TensorType(2,(3,3),True,False), # TensorType(2,(3,3),True,False)) : (None,None), # (TensorType(2,(3,3),True,False), # TensorType(2,(3,3),True,False), # TensorType(2,(3,3),False,True)) : (None,None), # (TensorType(2,(3,3),True,False), # TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True)) : (None,None), # (TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True), # TensorType(2,(3,3),False,True)) : (None,None), _3d_isotropic_three_input_table = { ( TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False), ): (_3d_isotropic_2s_rank_2_1_vec, _3d_isotropic_2s_rank_2_1_vec_outputs), ( TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False), ): (_3d_isotropic_1s_1a_rank_2_1_vec, _3d_isotropic_1s_1a_rank_2_1_vec_outputs), ( TensorType(2, (3, 3), False, True), TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False), ): (_3d_isotropic_2a_rank_2_1_vec, _3d_isotropic_2a_rank_2_1_vec_outputs), ( TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False), TensorType(1, (3,), False, False), ): (_3d_isotropic_1s_rank_2_2_vec, 
_3d_isotropic_1s_rank_2_2_vec_outputs),
    # (The lines above close an entry of a three-input table whose beginning
    # lies before this section; the entry below keys one antisymmetric
    # rank-2 tensor plus two vectors.)
    (
        TensorType(2, (3, 3), False, True),
        TensorType(1, (3,), False, False),
        TensorType(1, (3,), False, False),
    ): (_3d_isotropic_1a_rank_2_2_vec, _3d_isotropic_1a_rank_2_2_vec_outputs),
}

# Four-input tables are not implemented; the drafts below are kept for reference.
# _3d_isotropic_four_input_table = {
# '''(TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (_3d_isotropic_2s_2_vec,
# _3d_isotropic_2s_2_vec_outputs),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),True,False)) : (_3d_isotropic_1s_1a_2_vec,
# _3d_isotropic_1s_1a_2_vec_outputs),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (_3d_isotropic_2a_2_vec,
# _3d_isotropic_2a_2_vec_outputs),'''
# }
# _3d_hemitropic_four_input_table = {
# '''(TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),True,False),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),True,False)) : (None,None),
# (TensorType(2,(3,3),True,False),
# TensorType(2,(3,3),False,True),
# TensorType(1,(3,),False,False),
# TensorType(1,(3,),False,False)) : (None,None),'''
# }

# O(3), single input: maps a 1-tuple of input TensorTypes to a pair
# (scalar-invariant function, {output TensorType: form-invariant function}).
# Every entry reuses the SO(3) ("hemitropic") functions and output tables.
# Naming: "_3d_hemitropic_1s_rank_2" reads as "scalar invariants in 3-D for one
# symmetric rank-2 input"; "_1a_" is the antisymmetric analogue.
_3d_isotropic_single_input_table = {
    # maps input types to invariant calculation functions
    (TensorType(2, (3, 3), True, False),): (
        _3d_hemitropic_1s_rank_2,
        _3d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), False, True),): (
        _3d_hemitropic_1a_rank_2,
        _3d_hemitropic_1a_rank_2_outputs,
    ),
    # NOTE(review): this uses the hemitropic vector output table, while
    # `_3d_isotropic_1_vec_outputs` (defined just below, after this table)
    # is not referenced here -- confirm which one O(3) should use.
    (TensorType(1, (3,), False, False),): (
        _3d_hemitropic_1_vec,
        _3d_hemitropic_1_vec_outputs,
    ),
    # A scalar input is passed through unchanged; its only scalar-valued
    # output uses the constant form-invariant 1.0.
    (TensorType(0, (), False, False),): (
        lambda x, backend: x,
        {TensorType(0, (), False, False): lambda x, backend: 1.0},
    ),
}

# Output form-invariant tables: map an output TensorType to the function that
# builds the form-invariant basis for that output type.
_3d_isotropic_1_vec_outputs = {
    # where they both exist, hemitropic and isotropic form-invariants are the same in this case
    # (this table has no antisymmetric rank-2 output entry, unlike
    # the hemitropic two-vector table below)
    TensorType(2, (3, 3), True, False): _3d_hmt_1vec_o_s2,
    TensorType(1, (3,), False, False): _3d_hmt_1vec_o_vec,
}
_3d_hemitropic_2s_rank_2_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_2sr2_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_2sr2_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_2sr2_o_vec,
}
_3d_hemitropic_1s_1a_rank_2_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1s1ar2_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1s1ar2_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1s1ar2_o_vec,
}
_3d_hemitropic_2a_rank_2_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_2ar2_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_2ar2_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_2ar2_o_vec,
}
_3d_hemitropic_1s_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1sr21vec_o_vec,
}
# NOTE(review): identical values to `_3d_hemitropic_1s_rank_2_1_vec_outputs`
# (the symmetric-tensor "1s" functions) even though this table is for an
# antisymmetric rank-2 input -- possible copy-paste; confirm intended.
_3d_hemitropic_1a_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_1sr21vec_o_vec,
}
# NOTE(review): the value names (`_3d_iso_1sr21vec_*`) suggest the symmetric
# ("1s") case although this table is for an antisymmetric input -- confirm.
_3d_isotropic_1a_rank_2_1_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_iso_1sr21vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_iso_1sr21vec_o_a2,
    TensorType(1, (3,), False, False): _3d_iso_1sr21vec_o_vec,
}
_3d_hemitropic_2_vec_outputs = {
    TensorType(2, (3, 3), True, False): _3d_hmt_2vec_o_s2,
    TensorType(2, (3, 3), False, True): _3d_hmt_2vec_o_a2,
    TensorType(1, (3,), False, False): _3d_hmt_2vec_o_vec,
}

# SO(3), two inputs: maps an ordered pair of input TensorTypes to
# (scalar-invariant function, output table). Only one ordering of each mixed
# pair is listed (symmetric before antisymmetric, rank-2 before vector);
# presumably callers canonicalise the input order -- TODO confirm.
_3d_hemitropic_two_input_table = {
    # read "_3d_hemitropic_2s_rank_2" as "a function giving scalar invariants for SO(3) in 3-D where the inputs are two symmetric rank-two tensors"
    (TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False)): (
        _3d_hemitropic_2s_rank_2,
        _3d_hemitropic_2s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), False, True)): (
        _3d_hemitropic_1s_1a_rank_2,
        _3d_hemitropic_1s_1a_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), False, True), TensorType(2, (3, 3), False, True)): (
        _3d_hemitropic_2a_rank_2,
        _3d_hemitropic_2a_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False)): (
        _3d_hemitropic_1s_rank_2_1_vec,
        _3d_hemitropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False)): (
        _3d_hemitropic_1a_rank_2_1_vec,
        _3d_hemitropic_1a_rank_2_1_vec_outputs,
    ),
    (TensorType(1, (3,), False, False), TensorType(1, (3,), False, False)): (
        _3d_hemitropic_2_vec,
        _3d_hemitropic_2_vec_outputs,
    ),
}

# O(2), two inputs: same layout as above for 2-D tensors, including rank-3
# inputs (keyed with symmetric=False, antisymmetric=False only).
_2d_isotropic_two_input_table = {
    (TensorType(2, (2, 2), True, False), TensorType(2, (2, 2), True, False)): (
        _2d_isotropic_2s_rank_2,
        _2d_isotropic_2s_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), False, True), TensorType(2, (2, 2), False, True)): (
        _2d_isotropic_2a_rank_2,
        _2d_isotropic_2a_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), True, False), TensorType(2, (2, 2), False, True)): (
        _2d_isotropic_1s_1a_rank_2,
        _2d_isotropic_1s_1a_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), True, False), TensorType(1, (2,), False, False)): (
        _2d_isotropic_1s_rank_2_1_vec,
        _2d_isotropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (2, 2), False, True), TensorType(1, (2,), False, False)): (
        _2d_isotropic_1a_rank_2_1_vec,
        _2d_isotropic_1a_rank_2_1_vec_outputs,
    ),
    (TensorType(1, (2,), False, False), TensorType(1, (2,), False, False)): (
        _2d_isotropic_2_vec,
        _2d_isotropic_2_vec_outputs,
    ),
    (TensorType(3, (2, 2, 2), False, False), TensorType(3, (2, 2, 2), False, False)): (
        _2d_isotropic_2_rank_3,
        _2d_isotropic_2_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2), False, False), TensorType(1, (2,), False, False)): (
        _2d_isotropic_1_rank_3_1_vec,
        _2d_isotropic_1_rank_3_1_vec_outputs,
    ),
    (TensorType(3, (2, 2, 2), False, False), TensorType(2, (2, 2), True, False)): (
        _2d_isotropic_1_rank_3_1s_rank_2,
        _2d_isotropic_1_rank_3_1s_rank_2_outputs,
    ),
}

# O(3), two inputs: the rank-2/rank-2 pairs reuse the SO(3) functions and
# output tables; only the pairs involving a vector use isotropic-specific ones.
# NOTE(review): the reused hemitropic output tables include vector outputs --
# confirm they are valid O(3) form-invariants.
_3d_isotropic_two_input_table = {
    # read "_3d_hemitropic_2s_rank_2" as "a function giving scalar invariants for SO(3) in 3-D where the inputs are two symmetric rank-two tensors"
    (TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), True, False)): (
        _3d_hemitropic_2s_rank_2,
        _3d_hemitropic_2s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), True, False), TensorType(2, (3, 3), False, True)): (
        _3d_hemitropic_1s_1a_rank_2,
        _3d_hemitropic_1s_1a_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), False, True), TensorType(2, (3, 3), False, True)): (
        _3d_hemitropic_2a_rank_2,
        _3d_hemitropic_2a_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), True, False), TensorType(1, (3,), False, False)): (
        _3d_isotropic_1s_rank_2_1_vec,
        _3d_isotropic_1s_rank_2_1_vec_outputs,
    ),
    (TensorType(2, (3, 3), False, True), TensorType(1, (3,), False, False)): (
        _3d_isotropic_1a_rank_2_1_vec,
        _3d_isotropic_1a_rank_2_1_vec_outputs,
    ),
    (TensorType(1, (3,), False, False), TensorType(1, (3,), False, False)): (
        _3d_isotropic_2_vec,
        _3d_isotropic_2_vec_outputs,
    ),
}

# SO(3), single input: identical entries to `_3d_isotropic_single_input_table`.
_3d_hemitropic_single_input_table = {
    # maps input types to invariant calculation functions
    # read "_3d_hemitropic_1s_rank_2" as "a function giving scalar invariants for SO(3) in 3-D where the input is one symmetric rank-two tensor", and likewise with "_1a_" and antisymmetric
    (TensorType(2, (3, 3), True, False),): (
        _3d_hemitropic_1s_rank_2,
        _3d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (3, 3), False, True),): (
        _3d_hemitropic_1a_rank_2,
        _3d_hemitropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (3,), False, False),): (
        _3d_hemitropic_1_vec,
        _3d_hemitropic_1_vec_outputs,
    ),
    (TensorType(0, (), False, False),): (
        lambda x, backend: x,
        {TensorType(0, (), False, False): lambda x, backend: 1.0},
    ),
}

# O(2), single input.
_2d_isotropic_single_input_table = {
    (TensorType(2, (2, 2), True, False),): (
        _2d_isotropic_1s_rank_2,
        _2d_isotropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), False, True),): (
        _2d_isotropic_1a_rank_2,
        _2d_isotropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (2,), False, False),): (
        _2d_isotropic_1_vec,
        _2d_isotropic_1_vec_outputs,
    ),
    # doesn't matter if the rank-3 is totally symmetric/antisymmetric or not, they're all the same
    (TensorType(3, (2, 2, 2), False, False),): (
        _2d_isotropic_1_rank_3,
        _2d_isotropic_1_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2), True, False),): (
        _2d_isotropic_1_rank_3,
        _2d_isotropic_1_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2), False, True),): (
        _2d_isotropic_1_rank_3,
        _2d_isotropic_1_rank_3_outputs,
    ),
    # NOTE(review): this scalar output returns the list [1.0], whereas every
    # other scalar entry in this section returns the bare float 1.0 -- confirm.
    (TensorType(0, (), False, False),): (
        lambda x, backend: x,
        {TensorType(0, (), False, False): lambda x, backend: [1.0]},
    ),
}

# SO(2), single input: mixes O(2) scalar-invariant functions with SO(2)
# output tables (and an SO(2)-specific function for the antisymmetric case).
_2d_hemitropic_single_input_table = {
    # some of the entries are the same for hemi vs isotropic
    (TensorType(2, (2, 2), True, False),): (
        _2d_isotropic_1s_rank_2,
        _2d_hemitropic_1s_rank_2_outputs,
    ),
    (TensorType(2, (2, 2), False, True),): (
        _2d_hemitropic_1a_rank_2,
        _2d_hemitropic_1a_rank_2_outputs,
    ),
    (TensorType(1, (2,), False, False),): (
        _2d_isotropic_1_vec,
        _2d_hemitropic_1_vec_outputs,
    ),
    # doesn't matter if the rank-3 is totally symmetric/antisymmetric or not, they're all the same
    (TensorType(3, (2, 2, 2), False, False),): (
        _2d_isotropic_1_rank_3,
        _2d_hemitropic_1_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2), True, False),): (
        _2d_isotropic_1_rank_3,
        _2d_hemitropic_1_rank_3_outputs,
    ),
    (TensorType(3, (2, 2, 2), False, True),): (
        _2d_isotropic_1_rank_3,
        _2d_hemitropic_1_rank_3_outputs,
    ),
    # a scalar
    (TensorType(0, (), False, False),): (
        lambda x, backend: x,
        {TensorType(0, (), False, False): lambda x, backend: 1.0},
    ),
}

# Zero inputs: the empty tuple key maps to the "empty" scalar-invariant
# function and the identity output table for that dimension.
_3d_zero_input_table = {(): (_3d_empty, _3d_identity_outputs)}
_2d_zero_input_table = {(): (_2d_empty, _2d_identity_outputs)}
# NOTE: the bare string below has no runtime effect; it describes the
# per-group "*_scalar_input_table" dicts that follow it in the module.
""" These tables map the number of inputs to the table above containing scalar invariant calculation functions and output form-invariant calculation tables """
# ---------------------------------------------------------------------------
# Number-of-inputs tables.
#
# For one spatial dimension and one symmetry group, each table maps the number
# of inputs to a table that maps tuples of input ``TensorType``s to
# ``(scalar_invariant_function, output_form_invariant_table)`` pairs.
# Commented-out entries mark input counts that are not supported yet.
#
# These descriptions are ``#:`` comments rather than bare string statements:
# a bare string after an assignment is a no-op at runtime, and Sphinx autodoc
# attaches it to the *preceding* variable instead of the table it describes.
# ---------------------------------------------------------------------------

#: O(2): number of inputs (0, 1 or 2) -> input-type table.
_2d_isotropic_scalar_input_table = {
    0: _2d_zero_input_table,
    1: _2d_isotropic_single_input_table,
    2: _2d_isotropic_two_input_table,
    # 3 : _2d_isotropic_three_input_table,
    # 4 : _2d_isotropic_four_input_table,
}

#: SO(2): number of inputs (0 or 1 only) -> input-type table.
_2d_hemitropic_scalar_input_table = {
    0: _2d_zero_input_table,
    1: _2d_hemitropic_single_input_table,
    # 2 : _2d_isotropic_two_input_table,
    # 3 : _2d_isotropic_three_input_table,
    # 4 : _2d_isotropic_four_input_table,
}

#: SO(3): number of inputs (0 to 3) -> input-type table.
_3d_hemitropic_scalar_input_table = {
    0: _3d_zero_input_table,
    1: _3d_hemitropic_single_input_table,
    2: _3d_hemitropic_two_input_table,
    # defined elsewhere in this module
    3: _3d_hemitropic_three_input_table,
    # 4 : _3d_hemitropic_four_input_table,
}

#: O(3): number of inputs (0 to 3) -> input-type table.
_3d_isotropic_scalar_input_table = {
    0: _3d_zero_input_table,
    1: _3d_isotropic_single_input_table,
    2: _3d_isotropic_two_input_table,
    # defined elsewhere in this module
    3: _3d_isotropic_three_input_table,
    # 4 : _3d_isotropic_four_input_table,
}

# ---------------------------------------------------------------------------
# Group tables: for each supported symmetry group, map the group name to the
# number-of-inputs table for that group.
# ---------------------------------------------------------------------------

#: 3-D: symmetry group name ("O(3)" isotropic, "SO(3)" hemitropic) ->
#: number-of-inputs table.
_3d_invariant_group_table = {
    "O(3)": _3d_isotropic_scalar_input_table,
    "SO(3)": _3d_hemitropic_scalar_input_table,
}

#: 2-D: symmetry group name ("O(2)" isotropic, "SO(2)" hemitropic) ->
#: number-of-inputs table.
_2d_invariant_group_table = {
    "O(2)": _2d_isotropic_scalar_input_table,
    "SO(2)": _2d_hemitropic_scalar_input_table,
}

#: Outermost lookup table: number of spatial dimensions -> group table.
#: Key path: ``[spatial_dims][group_name][num_inputs][input_type_tuple]``
#: -> ``(scalar_invariant_function, output_form_invariant_table)``.
_scalar_invariant_function_table = {
    2: _2d_invariant_group_table,
    3: _3d_invariant_group_table,
}