# Source code for crikit.cr.torch_utils

from crikit.covering import Covering, register_covering
from crikit.covering.ufl import get_numpy_shape
from crikit.cr.space_builders import DirectSum
from crikit.cr.types import PointMap, Space
from crikit.cr.ufl import UFLExprSpace, UFLFunctionSpace
from crikit.cr.quadrature import get_quadrature_params, make_quadrature_spaces
from crikit.fe import dx
from crikit.fe_adjoint import Function
from crikit.projection import project
from pyadjoint_utils.fenics_adjoint import function_get_local, function_set_local
from pyadjoint_utils import ReducedFunctionNumPy, ReducedFunction
from pyadjoint.tape import (
    no_annotations,
    annotate_tape,
    get_working_tape,
    stop_annotating,
)
from pyadjoint_utils.torch_adjoint import (
    Tensor,
    tensor,
    overload_torch,
    get_default_dtype,
    to_numpy,
    to_cpu_numpy,
    to_torch,
)
from pyadjoint.enlisting import Enlist
from pyadjoint_utils.convert import make_convert_block
from pyadjoint.overloaded_function import overload_function
import torch
from functools import partial


class TorchTensor(Space):
    """Space whose points are torch tensors of a given shape.

    Any negative entry in ``shape`` marks an axis of indefinite size: such an
    axis is ignored when checking whether a tensor belongs to the space.

    :param shape: iterable of ints giving the size of each axis; negative
        values mark indefinite axes.
    :param dtype: dtype associated with the space. When ``None``, the value
        returned by ``get_default_dtype()`` is stored instead.
    """

    def __init__(self, shape, dtype=None):
        self._shape = tuple(shape)
        self._dtype = dtype if dtype is not None else get_default_dtype()
        # An explicit integer dtype keeps the (empty) scalar-shape case
        # consistent with the comparison performed in is_point().
        shape_tensor = torch.tensor(self._shape, dtype=torch.long)
        self._indefinite_axes = shape_tensor < 0
        self._definite_axes = torch.logical_not(self._indefinite_axes)
        self._definite_shape = shape_tensor[self._definite_axes]

    def shape(self):
        """Return the shape tuple this space was constructed with."""
        return self._shape

    def is_point(self, x):
        """Return True if ``x`` is a tensor whose definite axes match.

        ``x`` must be a ``Tensor`` or ``torch.Tensor`` with the same number of
        axes as this space; only the axes with non-negative declared size are
        compared. The dtype is not checked.
        """
        if not isinstance(x, (Tensor, torch.Tensor)):
            return False
        if len(x.shape) != len(self._shape):
            return False
        actual = torch.tensor(tuple(x.shape), dtype=torch.long)
        return torch.equal(actual[self._definite_axes], self._definite_shape)

    def point(self, **kwargs):
        """Return a tensor of ones belonging to this space.

        :param indefinite_ax_size: (keyword, default 10) size substituted for
            every indefinite (negative) axis.
        """
        indefinite_ax_size = kwargs.get("indefinite_ax_size", 10)
        # Treat every negative size as indefinite, matching __init__, and pass
        # the sizes as one sequence so the scalar shape () is accepted too.
        sizes = [s if s >= 0 else indefinite_ax_size for s in self._shape]
        return torch.ones(sizes)

    def __eq__(self, other):
        return (
            isinstance(other, TorchTensor)
            and self._shape == other._shape
            and self._dtype == other._dtype
        )

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same fields that __eq__ compares.
        return hash((self._shape, self._dtype))

    def __repr__(self):
        return f"TorchTensor({self._shape}, dtype={self._dtype})"
class Torch_To_UFLFunctionSpace(PointMap):
    """Point map taking a torch tensor of quadrature values to a Function.

    :param source: torch-side space; when ``None`` a ``TorchTensor`` with the
        shape reported by ``get_numpy_shape(quad_space)`` is used.
    :param target: UFL-side space. If it is a ``UFLFunctionSpace`` the result
        is projected onto its underlying function space; otherwise the
        quadrature-space function is returned directly.
    :param quad_space: quadrature function space; built from ``target`` with
        ``make_quadrature_spaces`` when ``None``.
    :param quad_params: quadrature parameters, used as measure metadata.
    :param make_block: accepted for interface compatibility; not used here.
    :param gpu_fenics: selects ``to_numpy`` (True) or ``to_cpu_numpy`` (False)
        for converting the input tensor.
    """

    def __init__(
        self,
        source,
        target,
        quad_space=None,
        quad_params=None,
        make_block=True,
        gpu_fenics=False,
    ):
        if quad_space is None:
            quad_space, quad_params = make_quadrature_spaces(
                target, quad_params=quad_params
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)
        if isinstance(target, UFLFunctionSpace):
            self._target_space = target._functionspace
        else:
            self._target_space = quad_space
        if source is None:
            source = TorchTensor(get_numpy_shape(self._quad_space))
        self._tonumpy = to_numpy if gpu_fenics else to_cpu_numpy
        super().__init__(source, target)

    def __call__(self, x, **kwargs):
        """Load ``x`` into a quadrature-space Function, projecting if needed."""
        q = function_set_local(Function(self._quad_space), self._tonumpy(x))
        if self._target_space == self._quad_space:
            return q
        return project(q, self._target_space, dx=self._dx)


class TorchFunction(PointMap):
    """Point map wrapping a torch function via ``overload_torch``.

    :param f: the function to wrap.
    :param in_shapes: a single shape (sequence of ints) or a sequence of
        shapes, one per input tensor.
    :param out_shapes: same as ``in_shapes`` but for the outputs.
    :param in_dtypes: ``None``, a single dtype applied to every input, or a
        list/tuple with one dtype per input.
    :param out_dtypes: same as ``in_dtypes`` but for the outputs.
    :param kwargs: forwarded to ``overload_torch``.
    """

    def __init__(
        self, f, in_shapes, out_shapes, in_dtypes=None, out_dtypes=None, **kwargs
    ):
        self._in_shapes, self._n_in = self._get_shapes(in_shapes)
        self._out_shapes, self._n_out = self._get_shapes(out_shapes)
        super().__init__(
            self._build_space(self._in_shapes, self._get_dtypes(self._n_in, in_dtypes)),
            self._build_space(
                self._out_shapes, self._get_dtypes(self._n_out, out_dtypes)
            ),
        )
        self._f = overload_torch(f, **kwargs)

    def __call__(self, x):
        """Apply the wrapped function; multiple inputs are unpacked from ``x``."""
        if self._n_in <= 1:
            return self._f(x)
        return self._f(*x)

    def _build_space(self, shapes, dtypes):
        """Build one ``TorchTensor`` per (shape, dtype) pair.

        A single space is returned as-is; otherwise the spaces are combined in
        a ``DirectSum``.
        """
        spaces = [TorchTensor(s, dtype=d) for s, d in zip(shapes, dtypes)]
        if len(spaces) == 1:
            return spaces[0]
        return DirectSum(spaces)

    def _get_dtypes(self, n, dtypes):
        """Normalise ``dtypes`` into a sequence with one entry per tensor.

        :raises ValueError: if a list/tuple of dtypes does not have length ``n``.
        """
        if dtypes is None:
            return [None] * n
        if isinstance(dtypes, (list, tuple)):
            if len(dtypes) != n:
                raise ValueError(
                    "Must supply either a single dtype for "
                    " all inputs/outputs or one for each "
                    "input/output tensor"
                )
            return dtypes
        # A single dtype applies to every tensor. Repeating it avoids
        # zip() in _build_space silently dropping all but the first shape.
        # max(n, 1) keeps the one-element result previously returned for n <= 1.
        return [dtypes] * max(n, 1)

    def _get_shapes(self, shp):
        """Return ``(list_of_shapes, count)`` for a shape or list of shapes.

        An empty ``shp`` yields ``([()], 0)``; a flat sequence of ints is
        treated as one shape.
        """
        if len(shp) == 0:
            return [()], 0
        if isinstance(shp[0], int):
            return [shp], 1
        return shp, len(shp)


class UFLExprSpace_To_Torch(PointMap):
    """Point map taking a UFL expression to a torch tensor of quadrature values.

    :param source: UFL-side space the expressions come from.
    :param target: torch-side space; when ``None`` a ``TorchTensor`` with the
        shape reported by ``get_numpy_shape(quad_space)`` is used.
    :param quad_space: quadrature function space; built from ``source`` with
        ``make_quadrature_spaces`` when ``None``.
    :param quad_params: quadrature parameters, used as measure metadata.
    :param make_block: accepted for interface compatibility; not used here.
    :param gpu_fenics: when False the conversion is called with ``to_cpu=True``.
    """

    def __init__(
        self,
        source,
        target=None,
        quad_space=None,
        quad_params=None,
        make_block=True,
        gpu_fenics=False,
    ):
        if quad_space is None:
            # The UFL space is the *source* of this map; ``target`` is the
            # torch space and may legitimately be None at this point.
            quad_space, quad_params = make_quadrature_spaces(
                source, quad_params=quad_params
            )
        self._quad_space = quad_space
        self._dx = dx(metadata=quad_params)
        # NOTE(review): _target_space is not read by __call__ in this class;
        # it is kept so the attribute remains available to external code.
        if isinstance(target, UFLFunctionSpace):
            self._target_space = target._functionspace
        else:
            self._target_space = quad_space
        if target is None:
            target = TorchTensor(get_numpy_shape(self._quad_space))
        self._to_torch = partial(to_torch, to_cpu=(not gpu_fenics))
        super().__init__(source, target)

    def __call__(self, expr, **kwargs):
        """Return the local values of ``expr`` on the quadrature space.

        ``expr`` is used directly when it is already a ``Function`` on the
        quadrature space; otherwise it is projected onto that space first.

        :raises ValueError: if ``expr.ufl_shape`` differs from the value shape
            of the quadrature space's element.
        """
        if expr.ufl_shape != self._quad_space.ufl_element().value_shape():
            raise ValueError(
                f"Expression shape {expr.ufl_shape} does not match the target shape of {self._quad_space.ufl_element().value_shape()}"
            )
        if isinstance(expr, Function) and self._quad_space == expr.function_space():
            func = expr
        else:
            func = project(expr, self._quad_space, dx=self._dx)
        return self._to_torch(function_get_local(func))


@register_covering(UFLExprSpace, TorchTensor)
@register_covering(UFLFunctionSpace, TorchTensor)
class Torch_UFLFunctionSpace_Covering(Covering):
    """Covering of a UFL space by a ``TorchTensor`` space of quadrature values.

    :param base_space: the UFL-side space being covered.
    :param covering_space: torch-side space; when ``None`` a ``TorchTensor``
        with the shape reported by ``get_numpy_shape(quad_space)`` is used.
    :param domain: forwarded to ``make_quadrature_spaces``.
    :param quad_params: forwarded to ``make_quadrature_spaces``.
    :param kwargs: forwarded to ``Covering.__init__``.
    """

    def __init__(
        self, base_space, covering_space=None, domain=None, quad_params=None, **kwargs
    ):
        quad_space, quad_params = make_quadrature_spaces(
            base_space, quad_params=quad_params, domain=domain
        )
        self._quad_space = quad_space
        self._quad_params = quad_params
        if covering_space is None:
            covering_space = TorchTensor(get_numpy_shape(self._quad_space))
        super().__init__(base_space, covering_space, **kwargs)

    def covering_map(self):
        """Return the map from the covering (torch) space to the base space."""
        return Torch_To_UFLFunctionSpace(
            self._covering_space,
            self._base_space,
            quad_space=self._quad_space,
            quad_params=self._quad_params,
        )

    def section_map(self):
        """Return the map from the base space to the covering (torch) space."""
        return UFLExprSpace_To_Torch(
            self._base_space,
            self._covering_space,
            quad_space=self._quad_space,
            quad_params=self._quad_params,
        )


class TorchFunctionJITTracer:
    """Helper that traces a function with ``torch.jit.trace`` and overloads it.

    :param f: the function to trace.
    :param strict: default ``strict`` flag passed to ``torch.jit.trace``.
    :param name: value for ``__name__``; falls back to ``f.__name__``.
    """

    class StaticArgnumHandler:
        """Callable that re-traces the function when its arguments change.

        :param tracer: the owning ``TorchFunctionJITTracer``.
        :param static_argnums: static argument indices; only its length is
            used here, to size the initial id/hash caches.
        :param argnums: forwarded as ``argnums`` to ``trace_and_overload``.
        """

        def __init__(self, tracer, static_argnums, argnums=None):
            self.tracer = tracer
            self.argnums = static_argnums
            self.static_arg_ids = [None] * len(self.argnums)
            self.static_arg_hashes = [None] * len(self.argnums)
            self.traced = None
            self.overload_kwargs = {"argnums": argnums}
            self.__name__ = self.tracer.f.__name__

        def _decide_retrace(self, args):
            """Return True when the ids or hashes of ``args`` have changed.

            The cached ids/hashes are updated as a side effect.
            NOTE(review): all positional arguments are compared, not only
            those listed in ``static_argnums`` -- confirm this is intended.
            """
            new_arg_ids = [id(arg) for arg in args]
            retrace = new_arg_ids != self.static_arg_ids
            if retrace:
                self.static_arg_ids = new_arg_ids
                self.static_arg_hashes = list(map(hash, args))
            else:
                new_arg_hashes = list(map(hash, args))
                if new_arg_hashes != self.static_arg_hashes:
                    retrace = True
                    self.static_arg_hashes = new_arg_hashes
            return retrace

        def __call__(self, *args, **kwargs):
            """Call the traced function, (re)tracing first when required."""
            # _decide_retrace must run on every call so its caches stay fresh.
            if self._decide_retrace(args) or (self.traced is None):
                self.overload_kwargs.pop("static_argnums", None)
                # trace_and_overload takes the argument tuple as ONE
                # positional parameter; unpacking it here would spill the
                # extra arguments into ``jit``/``argnums``.
                self.traced = self.tracer.trace_and_overload(
                    args, argnums=self.overload_kwargs["argnums"]
                )
            return self.traced(*args, **kwargs)

    def __init__(self, f, strict=False, name=None):
        self.f = f
        self.strict = strict
        self.__name__ = name or self.f.__name__

    def trace(self, *args, strict=None, check=False, **kwargs):
        """Trace ``self.f`` on the example ``args``.

        If ``static_argnums`` is given as a keyword, tracing is deferred and a
        ``StaticArgnumHandler`` is returned instead. Otherwise the result of
        ``torch.jit.trace`` is returned. A non-``None`` ``strict`` replaces the
        tracer's stored ``strict`` flag.
        """
        if "static_argnums" in kwargs:
            static_argnums = kwargs.pop("static_argnums")
            return self.StaticArgnumHandler(
                self,
                static_argnums,
                argnums=kwargs.get("argnums", list(range(len(args)))),
            )
        if strict is not None:
            self.strict = strict
        return torch.jit.trace(self.f, args, strict=self.strict, check_trace=check)

    def trace_and_overload(
        self,
        args,
        jit=True,
        argnums=None,
        pointwise=False,
        out_pointwise=False,
        strict=False,
        check=False,
        **kwargs,
    ):
        """Trace ``self.f`` (when ``jit`` is true) and wrap it with ``overload_torch``.

        :param args: tuple of example inputs used for tracing.
        :param jit: when False, ``self.f`` is overloaded without tracing.
        :param argnums, pointwise, out_pointwise: forwarded to ``overload_torch``.
        :param strict, check: forwarded to :meth:`trace`.
        :param kwargs: accepted and ignored.
        """
        if jit:
            fn = self.trace(*args, strict=strict, check=check, argnums=argnums)
        else:
            fn = self.f
        return overload_torch(
            fn,
            jit=False,
            argnums=argnums,
            pointwise=pointwise,
            out_pointwise=out_pointwise,
        )