# from jax import jit
# import jax.numpy as np
from functools import partial
from itertools import chain, combinations, permutations

import numpy as onp
from pyadjoint_utils import get_default_backend
# Module-wide backend handle, initialised once at import time from
# pyadjoint_utils.get_default_backend(); set_backend()/get_backend() rebind
# and read this name.
backend = get_default_backend()
def set_backend(new_backend):
    """Rebind the module-wide ``backend`` name to ``new_backend``.

    :param new_backend: the object to store as the module-level backend
    :return: None
    """
    global backend
    backend = new_backend
def get_backend():
    """Return whatever is currently bound to the module-wide ``backend`` name.

    :return: the module-level backend object
    """
    # Reading a module global needs no ``global`` declaration.
    return backend
def symm(x, backend):
    """Add an array to its own transpose.

    :param x: a 2-d array to symmetrize
    :type x: Union[np.ndarray,onp.ndarray]
    :param backend: array backend; not used by this function
    :return: ``x + x.T``, i.e. a symmetric array equal to twice the
        symmetric part of ``x``
    :rtype: Union[np.ndarray,onp.ndarray]
    """
    transposed = x.T
    return x + transposed
def antisymm(x, backend):
    """Subtract an array's transpose from the array itself.

    :param x: a 2-d array to antisymmetrize
    :type x: Union[np.ndarray,onp.ndarray]
    :param backend: array backend; not used by this function
    :return: ``x - x.T``, i.e. an antisymmetric array equal to twice the
        antisymmetric part of ``x``
    :rtype: Union[np.ndarray,onp.ndarray]
    """
    transposed = x.T
    return x - transposed
def commutator_action(A, B, v, backend):
    """Apply the commutator ``A B - B A`` to ``v`` without forming it.

    Only matrix-vector style products are evaluated:
    ``A @ (B @ v) - B @ (A @ v)``.

    :param A: left operand of the commutator
    :param B: right operand of the commutator
    :param v: the array the commutator acts on
    :param backend: array backend; not used by this function
    :return: ``A @ (B @ v) - B @ (A @ v)``
    """
    return A @ (B @ v) - B @ (A @ v)
def anticommutator_action(A, B, v, backend):
    """Apply the anticommutator ``A B + B A`` to ``v`` without forming it.

    Only matrix-vector style products are evaluated:
    ``A @ (B @ v) + B @ (A @ v)``.

    :param A: left operand of the anticommutator
    :param B: right operand of the anticommutator
    :param v: the array the anticommutator acts on
    :param backend: array backend; not used by this function
    :return: ``A @ (B @ v) + B @ (A @ v)``
    """
    return A @ (B @ v) + B @ (A @ v)
def scalar_triple_prod(u, v, w, backend):
    """Return ``v · (u × w)``.

    Note the argument order: the value is ``dot(v, cross(u, w))``, which is
    the negative of ``u · (v × w)``.

    :param u: first factor of the cross product
    :param v: vector dotted with the cross product
    :param w: second factor of the cross product
    :param backend: array backend providing ``dot`` and ``cross``
    :return: ``backend.dot(v, backend.cross(u, w))``
    """
    u_cross_w = backend.cross(u, w)
    return backend.dot(v, u_cross_w)
# Copied from https://docs.python.org/3/library/itertools.html
def powerset(iterable, exclude_empty_set=True):
    """Lazily yield every subset of ``iterable`` as a tuple, smallest first.

    ``powerset([1,2,3], False) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)``

    :param iterable: the items to draw subsets from
    :param exclude_empty_set: when true (the default) the empty tuple is
        skipped
    :return: an iterator over tuples, ordered by subset size
    """
    items = tuple(iterable)
    smallest_size = int(exclude_empty_set)
    subsets_by_size = (
        combinations(items, size) for size in range(smallest_size, len(items) + 1)
    )
    return chain.from_iterable(subsets_by_size)
def levi_civita(n, backend):
    """Returns the Levi-Civita pseudotensor in ``n`` dimensions.

    Only ``n == 2`` and ``n == 3`` are implemented.

    :param n: the number of dimensions
    :type n: int
    :param backend: array backend providing ``array``
    :returns: The Levi-Civita pseudotensor in ``n`` spatial dimensions
    :rtype: np.ndarray
    :raises NotImplementedError: if ``n`` is neither 2 nor 3
    """
    if n == 2:
        return backend.array([[0, 1], [-1, 0]])
    if n != 3:
        raise NotImplementedError

    # Fill a plain NumPy buffer first, then hand it to the backend in one go.
    eps = onp.zeros((3, 3, 3))
    signed_triples = (
        (1, ((0, 1, 2), (1, 2, 0), (2, 0, 1))),  # even permutations
        (-1, ((2, 1, 0), (0, 2, 1), (1, 0, 2))),  # odd permutations
    )
    for sign, triples in signed_triples:
        for triple in triples:
            eps[triple] = sign
    return backend.array(eps)
def _eps_vec_action(v, backend):
    """Contract the 3-d Levi-Civita tensor with ``v`` over its last index.

    :param v: the array contracted against the last index
    :param backend: array backend providing ``einsum`` and ``array``
    :return: the array ``M`` with ``M[i, j] = sum_k eps[i, j, k] * v[k]``
    """
    eps = levi_civita(3, backend)
    return backend.einsum("ijk,k -> ij", eps, v)
def axial_vector(G, backend):
    """Contract the 3-d Levi-Civita tensor with ``G`` over its last two indices.

    :param G: the 2-d array contracted against the last two indices
    :param backend: array backend providing ``einsum`` and ``array``
    :return: the array ``a`` with ``a[i] = sum_jk eps[i, j, k] * G[j, k]``
    """
    eps = levi_civita(3, backend)
    return backend.einsum("ijk,jk->i", eps, G, optimize=True)
def _tprod(x, y, backend):
    """Return the tensor (outer) product of ``x`` and ``y``.

    :param x: left factor
    :param y: right factor
    :param backend: array backend providing ``tensordot``
    :return: ``backend.tensordot(x, y, axes=0)``
    """
    outer = backend.tensordot(x, y, axes=0)
    return outer
def symm_q4(T, backend):
    """Sum a fourth-order tensor over every permutation of its four indices.

    ``S[i, j, k, l]`` is the sum of ``T`` evaluated at all 24 orderings of
    ``(i, j, k, l)``. The result is fully symmetric and is not normalised
    (there is no division by 24).

    :param T: a 4-d array
    :param backend: array backend providing ``zeros_like`` and ``transpose``
    :return: the unnormalised full symmetrisation of ``T``, same shape and
        dtype as ``T``
    """
    # Summing the 24 axis-permuted copies of T is the same as accumulating
    # T[pi(idx)] entry by entry, but it needs only whole-array operations
    # and does not hard-code the size of each dimension.
    S = backend.zeros_like(T)
    for axes in permutations(range(4)):
        S = S + backend.transpose(T, axes)
    return S
def symm_q3(T, backend):
    """Sum a third-order tensor over the cyclic rotations of its indices.

    The result satisfies ``S[a, b, c] = T[a, b, c] + T[b, c, a] + T[c, a, b]``.

    :param T: a 3-d array
    :param backend: array backend providing ``moveaxis``
    :return: the (unnormalised) cyclic sum of ``T``
    """
    source_axes = [0, 1, 2]
    total = T
    for destination_axes in ([1, 2, 0], [2, 0, 1]):
        total = total + backend.moveaxis(T, source_axes, destination_axes)
    return total
def tA(T, A, backend):
    """Contract the last two indices of ``T`` with the 2-d array ``A``.

    :param T: a 3-d array
    :param A: a 2-d array
    :param backend: array backend providing ``einsum``
    :return: the array ``r`` with ``r[i] = sum_jk T[i, j, k] * A[j, k]``
    """
    contracted = backend.einsum("ijk,jk -> i", T, A, optimize=True)
    return contracted
def tv(T, v, backend):
    """Contract the last two indices of ``T`` with the outer product of ``v`` with itself.

    :param T: a 3-d array
    :param v: a 1-d array
    :param backend: array backend, forwarded to ``_tprod`` and ``tA``
    :return: the array ``r`` with ``r[i] = sum_jk T[i, j, k] * v[j] * v[k]``
    """
    # Both helpers take the backend as a required positional argument.
    return tA(T, _tprod(v, v, backend), backend)
def Tv(T, v, backend):
    """Contract the last index of ``T`` with the 1-d array ``v``.

    :param T: a 3-d array
    :param v: a 1-d array
    :param backend: array backend providing ``einsum``
    :return: the array ``M`` with ``M[i, j] = sum_k T[i, j, k] * v[k]``
    """
    contracted = backend.einsum("ijk,k -> ij", T, v, optimize=True)
    return contracted
def TinnerS(T, S, backend):
    """Contract the last two indices of ``T`` with the first two of ``S``.

    :param T: a 3-d array
    :param S: a 3-d array
    :param backend: array backend providing ``einsum``
    :return: the array ``M`` with
        ``M[i, l] = sum_jk T[i, j, k] * S[j, k, l]``
    """
    contracted = backend.einsum("ijk,jkl -> il", T, S, optimize=True)
    return contracted
def TcontrS(T, S, backend):
    """Fully contract two third-order tensors index by index.

    :param T: a 3-d array
    :param S: a 3-d array
    :param backend: array backend providing ``einsum``
    :return: the scalar ``sum_ijk T[i, j, k] * S[i, j, k]``
    """
    full_contraction = backend.einsum("ijk,ijk", T, S, optimize=True)
    return full_contraction
def TW(T, W, backend):
    """Contract the last index of ``T`` with the first index of ``W``.

    :param T: a 3-d array
    :param W: a 2-d array
    :param backend: array backend providing ``einsum``
    :return: the array ``R`` with
        ``R[i, j, k] = sum_m T[i, j, m] * W[m, k]``
    """
    contracted = backend.einsum("ijm,mk -> ijk", T, W, optimize=True)
    return contracted
def tbrace(T, backend):
    """Return ``T`` and the two cyclic rearrangements of its last three axes.

    With ``first, second, third = tbrace(T, backend)``:
    ``first`` is ``T`` itself, ``second[a, b, c, d] == T[a, c, d, b]`` and
    ``third[a, b, c, d] == T[a, d, b, c]``.

    :param T: a 4-d array
    :param backend: array backend providing ``moveaxis``
    :return: a list of three arrays
    """
    # eq 2.34 in Zheng '94
    trailing_axes = [1, 2, 3]
    rotated = [
        backend.moveaxis(T, trailing_axes, destination)
        for destination in ([2, 3, 1], [3, 1, 2])
    ]
    return [T] + rotated
def near(y, x, tol, backend):
    """Elementwise absolute-tolerance comparison ``|x - y| <= tol``.

    NOTE: this module later defines another ``near`` with a different
    signature; once the module has finished importing, that later definition
    is the one bound to the name.

    :param y: first value or array
    :param x: second value or array
    :param tol: absolute tolerance
    :param backend: array backend providing ``abs``
    :return: the (elementwise) result of ``backend.abs(x - y) <= tol``
    """
    gap = backend.abs(x - y)
    return gap <= tol
def _is_third_order_irreducible(T, backend):
    """Check ``T`` against a fixed two-parameter third-order pattern.

    ``T`` passes when it has three axes and, using ``t1 = T[1, 1, 1]`` and
    ``t2 = T[2, 2, 2]``:

    * the entries at ``(1, 2, 2)``, ``(2, 1, 2)``, ``(2, 2, 1)`` are near ``-t1``,
    * the entries at ``(2, 1, 1)``, ``(1, 2, 1)``, ``(1, 1, 2)`` are near ``-t2``,
    * every other entry except ``(1, 1, 1)`` and ``(2, 2, 2)`` is near ``0.0``.

    Each comparison is delegated to ``near`` with a tolerance argument of
    ``1.0e-10``.

    NOTE(review): every entry with an index equal to 0 is required to
    vanish; confirm that the 0-based indices 1 and 2 are what is intended.

    :param T: the array to test
    :param backend: array backend, forwarded to ``near``
    :return: ``True`` if ``T`` matches the pattern, ``False`` otherwise
    :rtype: bool
    """
    if len(T.shape) != 3:
        return False

    t1 = T[1, 1, 1]
    t2 = T[2, 2, 2]
    free_entries = {(1, 1, 1), (2, 2, 2)}
    minus_t1_entries = {(1, 2, 2), (2, 1, 2), (2, 2, 1)}
    minus_t2_entries = {(2, 1, 1), (1, 2, 1), (1, 1, 2)}

    # Loop bounds come from the argument ``T`` itself.
    for i in range(T.shape[0]):
        for j in range(T.shape[1]):
            for k in range(T.shape[2]):
                idx = (i, j, k)
                if idx in free_entries:
                    continue
                if idx in minus_t1_entries:
                    expected = -t1
                elif idx in minus_t2_entries:
                    expected = -t2
                else:
                    expected = 0.0
                if not near(T[idx], expected, 1.0e-10, backend):
                    return False
    return True
def spectral_decomp(A, hermitian=False, backend=None):
    """Eigendecompose ``A`` and return the eigenvectors with a diagonal eigenvalue matrix.

    :param A: the square array to decompose
    :param hermitian: if true use ``backend.linalg.eigh``, otherwise
        ``backend.linalg.eig``
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: a pair ``(V, D)`` where ``V`` is the eigenvector array returned
        by the solver and ``D = backend.diag(w)`` for its eigenvalues ``w``
    """
    if backend is None:
        backend = get_default_backend()
    solver = backend.linalg.eigh if hermitian else backend.linalg.eig
    eigenvalues, eigenvectors = solver(A)
    return eigenvectors, backend.diag(eigenvalues)
def factorial(n: int, backend):
    """Return ``n!`` computed as a product over ``1..n`` on the backend.

    :param n: the integer whose factorial is taken
    :param backend: array backend providing ``prod`` and ``arange``; not
        touched when ``n == 0``
    :return: the Python int ``1`` when ``n == 0``, otherwise
        ``backend.prod(backend.arange(1, n + 1))``
    """
    if n != 0:
        return backend.prod(backend.arange(1, n + 1))
    return 1
# @partial(jit,static_argnums=(2,))
def matpows(A, ns, hermitian=False, backend=None):
    """Return the matrix powers ``A**i`` for each ``i`` in ``ns``.

    The powers are built from one eigendecomposition ``A = V diag(w) V^-1``
    as ``V diag(w**i) V^-1``, so ``A`` must be diagonalisable. For a
    non-hermitian ``A`` the result may be complex-valued.

    :param A: the square array to exponentiate
    :param ns: iterable of exponents
    :param hermitian: if true use ``backend.linalg.eigh``, otherwise
        ``backend.linalg.eig``
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: a list with one array per entry of ``ns``
    """
    # TODO: use a better algorithm for this, maybe exponentiation by squaring?
    if backend is None:
        backend = get_default_backend()
    eigensolver = backend.linalg.eigh if hermitian else backend.linalg.eig
    # Keep the eigenvalues as a vector: raising a dense diagonal matrix to a
    # power elementwise turns its off-diagonal zeros into ones for power 0.
    w, V = eigensolver(A)
    Vinv = backend.linalg.inv(V)
    # The solver's eigenvectors are the columns of V, so A = V diag(w) V^-1;
    # ``V * w**i`` scales column j by w[j]**i, i.e. equals V @ diag(w**i).
    return [(V * backend.power(w, i)) @ Vinv for i in ns]
def matexp(A, hermitian=False, backend=None):
    """Approximate the matrix exponential by a truncated power series.

    The value is ``sum(A**i / i! for i in range(10))``, with the powers
    taken from ``matpows`` and the factorials from ``factorial``.

    :param A: the square array to exponentiate
    :param hermitian: forwarded to ``matpows``
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: the ten-term series approximation of ``exp(A)``
    """
    # Resolve the default once so both helpers receive a usable backend.
    if backend is None:
        backend = get_default_backend()
    # TODO: find actually good choices of these parameters
    N = range(10)
    powers = matpows(A, N, hermitian, backend)
    return sum(A_i / factorial(i, backend) for A_i, i in zip(powers, N))
def _3d_rotation_matrix(axis, theta, backend=None):
    """Build the 3x3 axis-angle rotation matrix for ``axis`` and ``theta``.

    ``axis`` is used as given; no normalisation is performed here, so the
    result is a proper rotation only for a unit-length axis.

    :param axis: indexable with three components ``(ux, uy, uz)``
    :param theta: rotation angle, passed to ``backend.cos``/``backend.sin``
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: ``backend.array`` of shape (3, 3)
    """
    if backend is None:
        backend = get_default_backend()
    ct, st = backend.cos(theta), backend.sin(theta)
    ux, uy, uz = axis[0], axis[1], axis[2]
    one_minus_ct = 1 - ct

    row_x = [
        ct + ux * ux * one_minus_ct,
        ux * uy * one_minus_ct - uz * st,
        ux * uz * one_minus_ct + uy * st,
    ]
    row_y = [
        uy * ux * one_minus_ct + uz * st,
        ct + uy * uy * one_minus_ct,
        uy * uz * one_minus_ct - ux * st,
    ]
    row_z = [
        uz * ux * one_minus_ct - uy * st,
        uz * uy * one_minus_ct + ux * st,
        ct + uz * uz * one_minus_ct,
    ]
    return backend.array([row_x, row_y, row_z])
def householder_matrix(v, backend):
    """Return the identity minus the outer product of ``v`` with itself.

    NOTE(review): the textbook Householder reflector for a unit vector is
    ``I - 2 v⊗v``; this returns ``I - v⊗v`` as written. Confirm whether
    callers pre-scale ``v`` or whether a factor of 2 is missing.

    :param v: a 1-d array; ``v.size`` fixes the size of the identity
    :param backend: array backend providing ``eye``; forwarded to ``_tprod``
    :return: ``backend.eye(v.size) - _tprod(v, v, backend)``
    """
    identity = backend.eye(v.size)
    outer = _tprod(v, v, backend)
    return identity - outer
def near(val, to, rtol=1.0e-5, backend=None):
    """Returns True if ``val`` and ``to`` are within relative tolerance ``rtol``
    and False otherwise

    The comparison is ``backend.allclose(val, to, rtol=rtol)``; only ``rtol``
    is passed, so any other tolerance keeps the backend's default.

    :param val: a value
    :type val: backend.ndarray
    :param to: is ``val`` close to this?
    :type to: np.ndarray
    :param rtol: Relative tolerance, defaults to 1.0e-5
    :type rtol: float, optional
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: are ``val`` and ``to`` within ``rtol``?
    :rtype: bool
    """
    xp = get_default_backend() if backend is None else backend
    all_close = xp.allclose(val, to, rtol=rtol)
    # effectively a JAX DeviceArray( ,dtype=bool) to bool conversion
    return bool(all_close)
def is_symm(X, rtol=1.0e-5, backend=None):
    """Report whether ``X`` is close to its own symmetric part.

    :param X: a 2-d array
    :param rtol: tolerance forwarded to ``near``, defaults to 1.0e-5
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: the result of ``near(X, 0.5 * symm(X, backend), rtol, backend)``
    """
    if backend is None:
        backend = get_default_backend()
    symmetric_part = 0.5 * symm(X, backend)
    return near(X, symmetric_part, rtol, backend)
def is_antisymm(X, rtol=1.0e-5, backend=None):
    """Report whether ``X`` is close to its own antisymmetric part.

    :param X: a 2-d array
    :param rtol: tolerance forwarded to ``near``, defaults to 1.0e-5
    :param backend: array backend; when ``None`` the result of
        ``get_default_backend()`` is used
    :return: the result of
        ``near(X, 0.5 * antisymm(X, backend), rtol, backend)``
    """
    # Resolve the default up front, as is_symm does, so the helpers never
    # receive ``None`` as their backend.
    if backend is None:
        backend = get_default_backend()
    antisymmetric_part = 0.5 * antisymm(X, backend)
    return near(X, antisymmetric_part, rtol, backend)