# Source code for crikit.invariants.utils

# from jax import jit
# import jax.numpy as np
import numpy as onp
from functools import partial
from itertools import chain, combinations
from pyadjoint_utils import get_default_backend

backend = get_default_backend()


def set_backend(new_backend):
    """Rebind the module-level ``backend`` variable.

    :param new_backend: the object to store as this module's ``backend``
    """
    global backend
    backend = new_backend


def get_backend():
    """Return whatever is currently bound to the module-level ``backend``.

    :return: the current value of this module's ``backend`` variable
    """
    # Reading a module global needs no ``global`` declaration.
    return backend


def symm(x, backend):
    """Symmetrizes the input

    :param x: a 2-d array to symmetrize
    :type x: Union[np.ndarray,onp.ndarray]
    :param backend: not used by this function; accepted so the call
        signature matches the other helpers in this module
    :return: A symmetric (and doubled) version of ``x``, i.e. ``x + x.T``
        (no factor of 1/2 is applied)
    :rtype: Union[np.ndarray,onp.ndarray]
    """
    return x + x.T
def antisymm(x, backend):
    """Antisymmetrizes the input

    :param x: a 2-d array to antisymmetrize
    :type x: Union[np.ndarray,onp.ndarray]
    :param backend: not used by this function; accepted so the call
        signature matches the other helpers in this module
    :return: An antisymmetric (and doubled) version of ``x``, i.e.
        ``x - x.T`` (no factor of 1/2 is applied)
    :rtype: Union[np.ndarray,onp.ndarray]
    """
    return x - x.T
def commutator_action(A, B, v, backend):
    """Apply the commutator ``A @ B - B @ A`` to ``v``.

    Only matrix-vector products are formed; the matrix product ``A @ B`` is
    never built.

    :param A: left operand of the commutator
    :param B: right operand of the commutator
    :param v: the vector the commutator acts on
    :param backend: not used by this function
    :return: ``A @ (B @ v) - B @ (A @ v)``
    """
    bv = B @ v
    av = A @ v
    return A @ bv - B @ av


def anticommutator_action(A, B, v, backend):
    """Apply the anticommutator ``A @ B + B @ A`` to ``v``.

    Only matrix-vector products are formed; the matrix product ``A @ B`` is
    never built.

    :param A: first operand of the anticommutator
    :param B: second operand of the anticommutator
    :param v: the vector the anticommutator acts on
    :param backend: not used by this function
    :return: ``A @ (B @ v) + B @ (A @ v)``
    """
    Av = A @ v
    Bv = B @ v
    return A @ Bv + B @ Av


def scalar_triple_prod(u, v, w, backend):
    """Return ``v . (u x w)``.

    Note the argument order: this equals ``-u . (v x w)``, i.e. the negative
    of the triple product written with the arguments in the order given.

    :param u: first vector (left factor of the cross product)
    :param v: second vector (dotted with the cross product)
    :param w: third vector (right factor of the cross product)
    :param backend: array module providing ``dot`` and ``cross``
    :return: ``backend.dot(v, backend.cross(u, w))``
    """
    return backend.dot(v, backend.cross(u, w))


# Copied from https://docs.python.org/3/library/itertools.html
def powerset(iterable, exclude_empty_set=True):
    """powerset([1,2,3], False) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    :param iterable: the items to form subsets of
    :param exclude_empty_set: if true (the default), the empty tuple is not
        yielded
    :return: a lazy iterator over tuples, ordered by increasing subset size
    """
    s = tuple(iterable)
    # int(True) == 1 skips the size-0 subset; int(False) == 0 keeps it.
    return chain.from_iterable(
        combinations(s, r) for r in range(int(exclude_empty_set), len(s) + 1)
    )
def levi_civita(n, backend):
    """Returns the Levi-Civita pseudotensor in ``n`` dimensions.

    :param n: the number of dimensions (only 2 and 3 are implemented)
    :type n: int
    :param backend: array module whose ``array`` function builds the result
    :returns: The Levi-Civita pseudotensor in ``n`` spatial dimensions
    :rtype: np.ndarray
    :raises NotImplementedError: if ``n`` is neither 2 nor 3
    """
    if n == 2:
        return backend.array([[0, 1], [-1, 0]])
    if n == 3:
        # Filled in with plain NumPy (in-place writes), then handed to the
        # backend in one go.
        eps = onp.zeros((3, 3, 3))
        # Even permutations of (0, 1, 2) get +1 ...
        for ix, iy, iz in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
            eps[ix, iy, iz] = 1
        # ... and odd permutations get -1; everything else stays 0.
        for ix, iy, iz in ((2, 1, 0), (0, 2, 1), (1, 0, 2)):
            eps[ix, iy, iz] = -1
        return backend.array(eps)
    raise NotImplementedError(
        "levi_civita is only implemented for n = 2 or n = 3, got n = {!r}".format(n)
    )
def _eps_vec_action(v, backend):
    """Contract the last index of the 3-d Levi-Civita tensor with ``v``.

    :param v: a length-3 vector
    :param backend: array module providing ``einsum``
    :return: the 2-d array ``eps[i, j, k] * v[k]`` (summed over ``k``)
    """
    return backend.einsum("ijk,k -> ij", levi_civita(3, backend), v)


def axial_vector(G, backend):
    """Contract the last two indices of the 3-d Levi-Civita tensor with ``G``.

    :param G: a 3x3 array
    :param backend: array module providing ``einsum``
    :return: the vector ``eps[i, j, k] * G[j, k]`` (summed over ``j``, ``k``)
    """
    return backend.einsum("ijk,jk->i", levi_civita(3, backend), G, optimize=True)


def _tprod(x, y, backend):
    """Return the tensor (outer) product of ``x`` and ``y``."""
    return backend.tensordot(x, y, axes=0)


def symm_q4(T, backend):
    """Sum a 4th-order tensor over all 24 permutations of its indices.

    ``S[i, j, k, l]`` is the sum of ``T`` evaluated at every permutation of
    ``(i, j, k, l)``; no normalisation by 24 is applied.

    :param T: a 4-d array whose axes all have the same length
    :param backend: array module providing ``zeros_like`` and ``transpose``
    :return: the index-symmetrized tensor, same shape as ``T``
    """
    # Local import: the module header only imports ``chain`` and
    # ``combinations`` from itertools.
    from itertools import permutations

    S = backend.zeros_like(T)
    # Summing every axis permutation of T is the same as summing T over all
    # permutations of each index tuple, without per-element updates.
    for axes in permutations(range(4)):
        S = S + backend.transpose(T, axes)
    return S


def symm_q3(T, backend):
    """Sum a 3rd-order tensor with its two cyclic axis permutations.

    :param T: a 3-d array
    :param backend: array module providing ``moveaxis``
    :return: ``T`` plus both cyclic rotations of its axes
    """
    return (
        T
        + backend.moveaxis(T, [0, 1, 2], [1, 2, 0])
        + backend.moveaxis(T, [0, 1, 2], [2, 0, 1])
    )


def tA(T, A, backend):
    """Contract the last two indices of ``T`` with the 2-d array ``A``.

    :return: the vector ``T[i, j, k] * A[j, k]`` (summed over ``j``, ``k``)
    """
    return backend.einsum("ijk,jk -> i", T, A, optimize=True)


def tv(T, v, backend):
    """Contract the last two indices of ``T`` with ``v`` twice.

    :return: the vector ``T[i, j, k] * v[j] * v[k]`` (summed over ``j``, ``k``)
    """
    # Both helpers require the backend to be passed explicitly.
    return tA(T, _tprod(v, v, backend), backend)


def Tv(T, v, backend):
    """Contract the last index of ``T`` with the vector ``v``.

    :return: the 2-d array ``T[i, j, k] * v[k]`` (summed over ``k``)
    """
    return backend.einsum("ijk,k -> ij", T, v, optimize=True)


def TinnerS(T, S, backend):
    """Contract the last two indices of ``T`` with the first two of ``S``.

    :return: the 2-d array ``T[i, j, k] * S[j, k, l]`` (summed over ``j``, ``k``)
    """
    return backend.einsum("ijk,jkl -> il", T, S, optimize=True)


def TcontrS(T, S, backend):
    """Fully contract two 3rd-order tensors.

    :return: the scalar ``T[i, j, k] * S[i, j, k]`` (summed over all indices)
    """
    return backend.einsum("ijk,ijk", T, S, optimize=True)


def TW(T, W, backend):
    """Contract the last index of ``T`` with the first index of ``W``.

    :return: the 3-d array ``T[i, j, m] * W[m, k]`` (summed over ``m``)
    """
    return backend.einsum("ijm,mk -> ijk", T, W, optimize=True)


def tbrace(T, backend):
    """Return ``T`` and its two cyclic rotations of axes 1, 2 and 3.

    :param T: a 4-d array
    :param backend: array module providing ``moveaxis``
    :return: a list of three arrays
    """
    # eq 2.34 in Zheng '94
    return [
        T,
        backend.moveaxis(T, [1, 2, 3], [2, 3, 1]),
        backend.moveaxis(T, [1, 2, 3], [3, 1, 2]),
    ]


# NOTE: this module defines ``near`` a second time further down (with a
# relative tolerance and an optional backend); once the module has finished
# loading, that later definition is the one bound to the name ``near``.
def near(y, x, tol, backend):
    """Elementwise test that ``x`` and ``y`` differ by at most ``tol``.

    :return: ``backend.abs(x - y) <= tol``
    """
    return backend.abs(x - y) <= tol


def _is_third_order_irreducible(T, backend):
    """Check the entry pattern of a 3rd-order tensor against a fixed template.

    Returns True only if ``T`` is 3-d and, to a tolerance of ``1.0e-10``:

    * the three permutations of index ``(1, 2, 2)`` equal ``-T[1, 1, 1]``,
    * the three permutations of index ``(2, 1, 1)`` equal ``-T[2, 2, 2]``,
    * every entry other than those and ``T[1, 1, 1]``, ``T[2, 2, 2]`` is zero.

    :param T: the array to test
    :param backend: array module forwarded to ``near``
    :return: True if the pattern holds, False otherwise
    :rtype: bool
    """
    if len(T.shape) != 3:
        return False

    t1 = T[1, 1, 1]
    t2 = T[2, 2, 2]
    t1id = (1, 1, 1)
    t2id = (2, 2, 2)
    t1idx = {(1, 2, 2), (2, 1, 2), (2, 2, 1)}
    t2idx = {(2, 1, 1), (1, 2, 1), (1, 1, 2)}
    for i in range(T.shape[0]):
        for j in range(T.shape[1]):
            for k in range(T.shape[2]):
                idx = (i, j, k)
                if idx == t1id or idx == t2id:
                    # The two free entries are unconstrained.
                    continue
                if idx in t1idx:
                    if not near(T[idx], -t1, 1.0e-10, backend):
                        return False
                elif idx in t2idx:
                    if not near(T[idx], -t2, 1.0e-10, backend):
                        return False
                else:
                    if not near(T[idx], 0.0, 1.0e-10, backend):
                        return False
    return True


def spectral_decomp(A, hermitian=False, backend=None):
    """Eigendecompose ``A``.

    :param A: a square matrix
    :param hermitian: if true use ``linalg.eigh``, otherwise ``linalg.eig``
    :param backend: array module; defaults to ``get_default_backend()``
    :return: ``(V, D)`` where the columns of ``V`` are the eigenvectors and
        ``D`` is the diagonal matrix of eigenvalues
    """
    if backend is None:
        backend = get_default_backend()
    w, V = backend.linalg.eigh(A) if hermitian else backend.linalg.eig(A)
    return V, backend.diag(w)


def factorial(n: int, backend):
    """Return ``n!``.

    :param n: a non-negative integer
    :param backend: array module providing ``prod`` and ``arange``
    :return: the Python int 1 when ``n == 0``, otherwise the backend's product
        of ``1..n``
    """
    if n == 0:
        return 1
    return backend.prod(backend.arange(1, n + 1))


# @partial(jit,static_argnums=(2,))
def matpows(A, ns, hermitian=False, backend=None):
    """Compute several powers of ``A`` from one eigendecomposition.

    :param A: a square, diagonalizable matrix
    :param ns: iterable of exponents
    :param hermitian: forwarded to :func:`spectral_decomp`
    :param backend: array module; defaults to ``get_default_backend()``
    :return: a list with ``A`` raised to each exponent in ``ns``, in order
    """
    # TODO: use a better algorithm for this, maybe exponentiation by squaring?
    if backend is None:
        backend = get_default_backend()
    V, D = spectral_decomp(A, hermitian, backend)
    Vinv = backend.linalg.inv(V)
    # Raise the eigenvalues themselves to the power: an elementwise power of
    # the full diagonal matrix would turn the off-diagonal zeros into ones
    # for exponent 0.
    eigvals = backend.diag(D)
    # A = V D V^{-1} (eigenvectors are the columns of V), so A^i = V D^i V^{-1}.
    return [V @ backend.diag(backend.power(eigvals, i)) @ Vinv for i in ns]


def matexp(A, hermitian=False, backend=None):
    """Approximate the matrix exponential by a truncated power series.

    Sums ``A**i / i!`` for ``i = 0..9``.

    :param A: a square, diagonalizable matrix
    :param hermitian: forwarded to :func:`matpows`
    :param backend: array module; defaults to ``get_default_backend()``
    :return: the 10-term series approximation of ``exp(A)``
    """
    if backend is None:
        backend = get_default_backend()
    # TODO: find actually good choices of these parameters
    N = range(10)
    return sum(
        [
            A_i / factorial(i, backend)
            for A_i, i in zip(matpows(A, N, hermitian, backend), N)
        ]
    )


def _3d_rotation_matrix(axis, theta, backend=None):
    """Build a 3x3 rotation matrix from an axis and an angle.

    The entries follow the axis-angle (Rodrigues) form.
    NOTE(review): ``axis`` is used as given, so it is presumably expected to
    be a unit vector -- confirm against callers.

    :param axis: indexable with three components
    :param theta: rotation angle, passed to ``backend.cos`` / ``backend.sin``
    :param backend: array module; defaults to ``get_default_backend()``
    :return: a 3x3 array
    """
    if backend is None:
        backend = get_default_backend()
    ct = backend.cos(theta)
    st = backend.sin(theta)
    ux = axis[0]
    uy = axis[1]
    uz = axis[2]
    return backend.array(
        [
            [
                ct + ux * ux * (1 - ct),
                ux * uy * (1 - ct) - uz * st,
                ux * uz * (1 - ct) + uy * st,
            ],
            [
                uy * ux * (1 - ct) + uz * st,
                ct + uy * uy * (1 - ct),
                uy * uz * (1 - ct) - ux * st,
            ],
            [
                uz * ux * (1 - ct) - uy * st,
                uz * uy * (1 - ct) + ux * st,
                ct + uz * uz * (1 - ct),
            ],
        ]
    )


def householder_matrix(v, backend):
    """Return ``I - v v^T``.

    NOTE(review): no factor of 2 and no normalisation is applied, so this is
    a reflection only if ``v`` is scaled so that ``v . v == 2`` -- confirm
    that callers scale ``v`` accordingly.

    :param v: a 1-d array
    :param backend: array module providing ``eye`` and ``tensordot``
    :return: a ``v.size`` x ``v.size`` array
    """
    return backend.eye(v.size) - _tprod(v, v, backend)
def near(val, to, rtol=1.0e-5, backend=None):
    """Returns True if ``val`` and ``to`` are within relative tolerance ``rtol`` and False otherwise

    The comparison is delegated to ``backend.allclose``, so for array inputs
    every element must be close, and ``allclose``'s own default absolute
    tolerance applies on top of ``rtol``.

    :param val: a value
    :type val: backend.ndarray
    :param to: is ``val`` close to this?
    :type to: np.ndarray
    :param rtol: Relative tolerance, defaults to 1.0e-5
    :type rtol: float, optional
    :param backend: array module providing ``allclose``; defaults to
        ``get_default_backend()``
    :return: are ``val`` and ``to`` within ``rtol``?
    :rtype: bool
    """
    if backend is None:
        backend = get_default_backend()
    # effectively a JAX DeviceArray( ,dtype=bool) to bool conversion
    return bool(backend.allclose(val, to, rtol=rtol))
def is_symm(X, rtol=1.0e-5, backend=None):
    """Check whether ``X`` equals its symmetric part ``0.5 * (X + X.T)``.

    :param X: a 2-d array
    :param rtol: relative tolerance forwarded to :func:`near`
    :param backend: array module; defaults to ``get_default_backend()``
    :return: True if ``X`` is symmetric to within the tolerance
    :rtype: bool
    """
    if backend is None:
        backend = get_default_backend()
    return near(X, 0.5 * symm(X, backend), rtol, backend)


def is_antisymm(X, rtol=1.0e-5, backend=None):
    """Check whether ``X`` equals its antisymmetric part ``0.5 * (X - X.T)``.

    :param X: a 2-d array
    :param rtol: relative tolerance forwarded to :func:`near`
    :param backend: array module; defaults to ``get_default_backend()``
    :return: True if ``X`` is antisymmetric to within the tolerance
    :rtype: bool
    """
    # Resolve the default here as well, mirroring ``is_symm``.
    if backend is None:
        backend = get_default_backend()
    return near(X, 0.5 * antisymm(X, backend), rtol, backend)