from crikit.covering import Covering, register_covering
from crikit.covering.ufl import get_numpy_shape
from crikit.cr.space_builders import DirectSum
from crikit.cr.types import PointMap, Space
from crikit.cr.ufl import UFLExprSpace, UFLFunctionSpace
from crikit.cr.quadrature import get_quadrature_params, make_quadrature_spaces
from crikit.fe import dx
from crikit.fe_adjoint import Function
from crikit.projection import project
from pyadjoint_utils.fenics_adjoint import function_get_local, function_set_local
from pyadjoint_utils import ReducedFunctionNumPy, ReducedFunction
from pyadjoint.tape import (
no_annotations,
annotate_tape,
get_working_tape,
stop_annotating,
)
from pyadjoint_utils.torch_adjoint import (
Tensor,
tensor,
overload_torch,
get_default_dtype,
to_numpy,
to_cpu_numpy,
to_torch,
)
from pyadjoint.enlisting import Enlist
from pyadjoint_utils.convert import make_convert_block
from pyadjoint.overloaded_function import overload_function
import torch
from functools import partial
class TorchTensor(Space):
    """A :class:`Space` of torch tensors with a fixed rank.

    Each entry of ``shape`` is either a fixed axis length or a negative
    number, which marks that axis as indefinite (any length is accepted).

    Args:
        shape: Iterable of axis lengths; negative entries are indefinite.
        dtype: Tensor dtype recorded for this space. Defaults to
            ``get_default_dtype()``.
    """

    def __init__(self, shape, dtype=None):
        self._shape = tuple(shape)
        self._dtype = dtype if dtype is not None else get_default_dtype()
        # Build from the stored tuple so that one-shot iterables (which
        # ``tuple()`` above would already have consumed) also work.
        shape_tensor = torch.tensor(self._shape)
        self._indefinite_axes = shape_tensor < 0
        self._definite_axes = torch.logical_not(self._indefinite_axes)
        self._definite_shape = shape_tensor[self._definite_axes]

    @property
    def shape(self):
        """The shape tuple; negative entries mark indefinite axes."""
        return self._shape

    def is_point(self, x):
        """Return True if ``x`` is a tensor whose shape fits this space.

        ``x`` must be a ``Tensor``/``torch.Tensor`` with the same rank, and
        its lengths along the definite axes must match exactly. The dtype of
        ``x`` is not checked.
        """
        if not isinstance(x, (Tensor, torch.Tensor)):
            return False
        if len(x.shape) != len(self._shape):
            return False
        return torch.equal(
            torch.tensor(x.shape)[self._definite_axes], self._definite_shape
        )

    def point(self, **kwargs):
        """Return a tensor of ones lying in this space.

        Keyword Args:
            indefinite_ax_size: Length used for every indefinite (negative)
                axis. Defaults to 10.
        """
        indefinite_ax_size = kwargs.get("indefinite_ax_size", 10)
        # Treat every negative entry as indefinite, consistent with
        # ``_indefinite_axes``. Passing the size as a tuple also lets a
        # rank-0 shape ``()`` produce a 0-d tensor.
        size = tuple(s if s >= 0 else indefinite_ax_size for s in self._shape)
        return torch.ones(size)

    def __eq__(self, other):
        return (
            isinstance(other, TorchTensor)
            and self._shape == other._shape
            and self._dtype == other._dtype
        )

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the same fields
        # that __eq__ compares so equal spaces hash equally.
        return hash((self._shape, self._dtype))

    def __repr__(self):
        return f"TorchTensor({self._shape}, dtype={self._dtype})"
class Torch_To_UFLFunctionSpace(PointMap):
    """Point map from a torch tensor of quadrature values to a UFL Function.

    The incoming tensor is converted to numpy and written into a Function on
    the quadrature space. When ``target`` is a :class:`UFLFunctionSpace`, that
    Function is then projected onto the target's function space. Otherwise
    the quadrature Function itself is returned.

    Args:
        source: Source space; if None, a :class:`TorchTensor` shaped like the
            quadrature space is used.
        target: The UFL space being mapped into.
        quad_space: Pre-built quadrature space. If None, it is built from
            ``target`` via ``make_quadrature_spaces``.
        quad_params: Quadrature metadata passed to ``dx``.
        make_block: Accepted for interface compatibility; unused here.
        gpu_fenics: If True, convert with ``to_numpy``; otherwise use
            ``to_cpu_numpy``.
    """

    def __init__(
        self,
        source,
        target,
        quad_space=None,
        quad_params=None,
        make_block=True,
        gpu_fenics=False,
    ):
        if quad_space is None:
            quad_space, quad_params = make_quadrature_spaces(
                target, quad_params=quad_params
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)

        # Only a genuine function space is worth projecting onto; every other
        # target is represented directly on the quadrature space.
        is_function_space = isinstance(target, UFLFunctionSpace)
        self._target_space = (
            target._functionspace if is_function_space else quad_space
        )

        self._tonumpy = to_numpy if gpu_fenics else to_cpu_numpy
        if source is None:
            source = TorchTensor(get_numpy_shape(quad_space))
        super().__init__(source, target)

    def __call__(self, x, **kwargs):
        """Map tensor ``x`` to a Function; extra ``kwargs`` are ignored."""
        quad_function = Function(self._quad_space)
        quad_function = function_set_local(quad_function, self._tonumpy(x))
        if self._target_space != self._quad_space:
            return project(quad_function, self._target_space, dx=self._dx)
        return quad_function
class TorchFunction(PointMap):
    """Wrap a torch-callable ``f`` as a :class:`PointMap` between tensor spaces.

    Args:
        f: The function to wrap. It is passed through ``overload_torch``
            together with any extra ``kwargs``.
        in_shapes: A single shape (sequence of ints) or a sequence of shapes,
            one per input tensor.
        out_shapes: Same as ``in_shapes``, for the outputs.
        in_dtypes: None, a single dtype applied to every input, or a
            list/tuple with one dtype per input.
        out_dtypes: Same as ``in_dtypes``, for the outputs.

    Raises:
        ValueError: If a list/tuple of dtypes has the wrong length.
    """

    def __init__(
        self, f, in_shapes, out_shapes, in_dtypes=None, out_dtypes=None, **kwargs
    ):
        self._in_shapes, self._n_in = self._get_shapes(in_shapes)
        self._out_shapes, self._n_out = self._get_shapes(out_shapes)
        # Size the dtype lists by the number of shapes rather than by
        # _n_in/_n_out. A scalar shape ``()`` gives one shape but n == 0, and
        # an empty dtype list would make zip() in _build_space drop it.
        in_space = self._build_space(
            self._in_shapes, self._get_dtypes(len(self._in_shapes), in_dtypes)
        )
        out_space = self._build_space(
            self._out_shapes, self._get_dtypes(len(self._out_shapes), out_dtypes)
        )
        super().__init__(in_space, out_space)
        self._f = overload_torch(f, **kwargs)

    def __call__(self, x):
        """Apply ``f``: a single input is passed as-is; multiple are unpacked."""
        if self._n_in <= 1:
            return self._f(x)
        return self._f(*x)

    def _build_space(self, shapes, dtypes):
        """Return one TorchTensor, or a DirectSum when there are several."""
        spaces = [TorchTensor(s, dtype=d) for s, d in zip(shapes, dtypes)]
        if len(spaces) == 1:
            return spaces[0]
        return DirectSum(spaces)

    def _get_dtypes(self, n, dtypes):
        """Expand ``dtypes`` into a list with one entry per tensor (``n``)."""
        if dtypes is None:
            return [None] * n
        if isinstance(dtypes, (list, tuple)):
            if len(dtypes) != n:
                raise ValueError(
                    "Must supply either a single dtype for "
                    " all inputs/outputs or one for each "
                    "input/output tensor"
                )
            return dtypes
        # A single dtype applies to every tensor. Returning a one-element
        # list would make zip() truncate multi-tensor spaces to one tensor.
        return [dtypes] * n

    def _get_shapes(self, shp):
        """Normalize ``shp`` to (list_of_shapes, n).

        An empty shape gives ``([()], 0)``, a single flat shape gives
        ``([shp], 1)``, and a sequence of shapes gives ``(shp, len(shp))``.
        """
        if len(shp) == 0:
            return [()], 0
        if isinstance(shp[0], int):
            return [shp], 1
        return shp, len(shp)
class UFLExprSpace_To_Torch(PointMap):
    """Point map from a UFL expression to a torch tensor of quadrature values.

    The expression is projected onto the quadrature space (skipped when it
    already is a Function on that space). The Function's local values are
    then converted with ``to_torch``.

    Args:
        source: The UFL space being mapped from.
        target: Target space; if None, a :class:`TorchTensor` shaped like the
            quadrature space is used.
        quad_space: Pre-built quadrature space. If None, it is built from
            ``source`` via ``make_quadrature_spaces``.
        quad_params: Quadrature metadata passed to ``dx``.
        make_block: Accepted for interface compatibility; unused here.
        gpu_fenics: If False, ``to_torch`` is called with ``to_cpu=True``.
    """

    def __init__(
        self,
        source,
        target=None,
        quad_space=None,
        quad_params=None,
        make_block=True,
        gpu_fenics=False,
    ):
        if quad_space is None:
            # The quadrature space comes from the UFL side of this map, which
            # is the *source*. ``target`` is a torch space (or None) and cannot
            # describe quadrature points.
            quad_space, quad_params = make_quadrature_spaces(
                source, quad_params=quad_params
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)
        # NOTE(review): _target_space is not read by __call__; it is kept
        # for backward compatibility with code that may inspect it.
        if isinstance(target, UFLFunctionSpace):
            self._target_space = target._functionspace
        else:
            self._target_space = quad_space
        if target is None:
            target = TorchTensor(get_numpy_shape(self._quad_space))
        self._to_torch = partial(to_torch, to_cpu=(not gpu_fenics))
        super().__init__(source, target)

    def __call__(self, expr, **kwargs):
        """Convert ``expr`` to a tensor; extra ``kwargs`` are ignored.

        Raises:
            ValueError: If ``expr``'s UFL shape differs from the value shape
                of the quadrature element.
        """
        if expr.ufl_shape != self._quad_space.ufl_element().value_shape():
            raise ValueError(
                f"Expression shape {expr.ufl_shape} does not match the target shape of {self._quad_space.ufl_element().value_shape()}"
            )
        # Avoid a redundant projection when expr already lives on quad_space.
        if isinstance(expr, Function) and self._quad_space == expr.function_space():
            func = expr
        else:
            func = project(expr, self._quad_space, dx=self._dx)
        return self._to_torch(function_get_local(func))
@register_covering(UFLExprSpace, TorchTensor)
@register_covering(UFLFunctionSpace, TorchTensor)
class Torch_UFLFunctionSpace_Covering(Covering):
    """Covering of UFL expression/function spaces by torch tensors.

    A quadrature space is built once from ``base_space``. Both the covering
    map (tensor -> UFL) and the section map (UFL -> tensor) share that
    quadrature space and its parameters.

    Args:
        base_space: The UFL space being covered.
        covering_space: Tensor space; if None, a :class:`TorchTensor` shaped
            like the quadrature space is used.
        domain: Forwarded to ``make_quadrature_spaces``.
        quad_params: Forwarded to ``make_quadrature_spaces``.
        **kwargs: Forwarded to :class:`Covering`.
    """

    def __init__(
        self, base_space, covering_space=None, domain=None, quad_params=None, **kwargs
    ):
        self._quad_space, self._quad_params = make_quadrature_spaces(
            base_space, quad_params=quad_params, domain=domain
        )
        if covering_space is None:
            covering_space = TorchTensor(get_numpy_shape(self._quad_space))
        super().__init__(base_space, covering_space, **kwargs)

    def _quadrature_kwargs(self):
        """Keyword arguments sharing this covering's quadrature setup."""
        return {"quad_space": self._quad_space, "quad_params": self._quad_params}

    def covering_map(self):
        """Return the tensor -> UFL map."""
        return Torch_To_UFLFunctionSpace(
            self._covering_space, self._base_space, **self._quadrature_kwargs()
        )

    def section_map(self):
        """Return the UFL -> tensor map."""
        return UFLExprSpace_To_Torch(
            self._base_space, self._covering_space, **self._quadrature_kwargs()
        )
class TorchFunctionJITTracer:
    """Trace a Python function with ``torch.jit.trace`` and overload it.

    Args:
        f: The function to trace.
        strict: Default ``strict`` flag for ``torch.jit.trace``.
        name: Name exposed as ``__name__``; defaults to ``f.__name__``.
    """

    class StaticArgnumHandler:
        """Callable that re-traces ``tracer.f`` when a static argument changes.

        Arguments at positions ``static_argnums`` are tracked by identity and
        by hash. A change in either triggers a fresh ``trace_and_overload``
        on the next call. Non-static arguments never force a re-trace.

        Args:
            tracer: The owning :class:`TorchFunctionJITTracer`.
            static_argnums: Sequence of positional indices treated as static.
            argnums: Forwarded as ``argnums`` to ``trace_and_overload``.
        """

        def __init__(self, tracer, static_argnums, argnums=None):
            self.tracer = tracer
            self.argnums = static_argnums
            self.static_arg_ids = [None] * len(self.argnums)
            self.static_arg_hashes = [None] * len(self.argnums)
            self.traced = None
            self.overload_kwargs = {"argnums": argnums}
            # Honour a custom ``name`` given to the tracer.
            self.__name__ = self.tracer.__name__

        def _decide_retrace(self, args):
            """Return True if the static args differ from the last call.

            Updates the stored ids/hashes as a side effect.
            """
            # Only the static positions matter. Tracked (tensor) inputs are
            # new objects on almost every call and must not force a re-trace.
            static_args = [args[i] for i in self.argnums]
            new_ids = [id(a) for a in static_args]
            if new_ids != self.static_arg_ids:
                self.static_arg_ids = new_ids
                self.static_arg_hashes = [hash(a) for a in static_args]
                return True
            new_hashes = [hash(a) for a in static_args]
            if new_hashes != self.static_arg_hashes:
                self.static_arg_hashes = new_hashes
                return True
            return False

        def __call__(self, *args, **kwargs):
            # _decide_retrace is evaluated first so its bookkeeping always runs.
            if self._decide_retrace(args) or self.traced is None:
                # trace_and_overload takes the example args as ONE sequence.
                # Unpacking them would bind them to ``jit``, ``argnums``, ...
                # NOTE(review): static (non-tensor) args are still passed to
                # torch.jit.trace as example inputs; confirm f handles that.
                self.traced = self.tracer.trace_and_overload(
                    args, argnums=self.overload_kwargs["argnums"]
                )
            return self.traced(*args, **kwargs)

    def __init__(self, f, strict=False, name=None):
        self.f = f
        self.strict = strict
        self.__name__ = name or self.f.__name__

    def trace(self, *args, strict=None, check=False, **kwargs):
        """Trace ``self.f`` on the example ``args``.

        If ``static_argnums`` is passed, returns a :class:`StaticArgnumHandler`
        that traces lazily on call. Otherwise returns the result of
        ``torch.jit.trace``. A non-None ``strict`` also becomes this tracer's
        stored default.
        """
        if "static_argnums" in kwargs:
            static_argnums = kwargs.pop("static_argnums")
            return self.StaticArgnumHandler(
                self,
                static_argnums,
                argnums=kwargs.get("argnums", list(range(len(args)))),
            )
        if strict is not None:
            self.strict = strict
        return torch.jit.trace(self.f, args, strict=self.strict, check_trace=check)

    def trace_and_overload(
        self,
        args,
        jit=True,
        argnums=None,
        pointwise=False,
        out_pointwise=False,
        strict=None,
        check=False,
        **kwargs,
    ):
        """Trace ``self.f`` on ``args`` (if ``jit``) and wrap with overload_torch.

        Args:
            args: Sequence of example arguments.
            jit: If False, ``self.f`` is overloaded without tracing.
            strict: ``torch.jit.trace`` strictness. None (the default) uses
                ``self.strict`` rather than overwriting it.
        """
        fn = self.trace(*args, strict=strict, check=check, argnums=argnums) if jit else self.f
        return overload_torch(
            fn,
            jit=False,
            argnums=argnums,
            pointwise=pointwise,
            out_pointwise=out_pointwise,
        )