from .block_variable import BlockVariable
from .tape import get_working_tape
# Module-level registry: maps an original (backend) class to the overloaded
# class registered for it. Written by `register_overloaded_type`, read by
# `get_overloaded_class` and `create_overloaded_object`.
_overloaded_types = {}
def get_overloaded_class(backend_class):
    """Look up the overloaded class registered for `backend_class`.

    Args:
        backend_class (type): The key to look up in the module registry
            `_overloaded_types`.

    Returns:
        The value stored in the registry for `backend_class`.

    Raises:
        KeyError: If nothing has been registered for `backend_class`.
    """
    registered = _overloaded_types[backend_class]
    return registered
def create_overloaded_object(obj, suppress_warning=False):
    """Return the overloaded counterpart of `obj`, if one is registered.

    An object that already is an OverloadedType is handed back untouched.
    Otherwise the exact type of `obj` (no subclass matching) is looked up
    among the classes registered through `register_overloaded_type`, and the
    registered class builds the result via its `_ad_init_object` classmethod.

    When no class is registered for that type, `obj` itself is returned and
    a warning is issued through `warnings.warn` (no explicit category is
    passed) unless `suppress_warning` is set.

    Args:
        obj (object): The object to create an overloaded instance from.
        suppress_warning (bool, optional): When set to True, suppresses the
            warning message when a suitable overloaded class is not found.
            Default False.

    Returns:
        OverloadedType: The overloaded instance, or `obj` unchanged when no
        overloaded class is registered for its type.
    """
    if isinstance(obj, OverloadedType):
        return obj

    backend_type = type(obj)
    if backend_type not in _overloaded_types:
        if not suppress_warning:
            import warnings
            message = "Could not find overloaded class of type '{}'.".format(backend_type)
            # stacklevel=2 attributes the warning to the caller of this function.
            warnings.warn(message, stacklevel=2)
        return obj

    return _overloaded_types[backend_type]._ad_init_object(obj)
def register_overloaded_type(overloaded_type, classes=None):
    """Register an overloaded type for use in `create_overloaded_object`.

    `create_overloaded_object` builds instances by calling the classmethod
    `_ad_init_object` on the registered type, so registered types should
    provide a suitable implementation of it.

    For usage as a decorator, OverloadedType should be the first base of
    `overloaded_type`, and the class being overloaded the second base: when
    `classes` is omitted, `overloaded_type.__bases__[1]` is used as the key.

    Args:
        overloaded_type (type): The OverloadedType subclass to register.
        classes (type, tuple, list, optional): The original class/classes
            that this OverloadedType subclass overloads. A tuple or list is
            registered element by element (nested sequences included).

    Returns:
        type: returns only `overloaded_type` such that it can be used as a
        decorator.
    """
    if isinstance(classes, (tuple, list)):
        # Register every entry separately; recursion covers nested sequences.
        for backend_class in classes:
            register_overloaded_type(overloaded_type, classes=backend_class)
        return overloaded_type

    backend_class = overloaded_type.__bases__[1] if classes is None else classes
    _overloaded_types[backend_class] = overloaded_type
    return overloaded_type
class OverloadedType(object):
    """Base class for OverloadedType types.

    The purpose of each OverloadedType is to extend a type such that
    it can be referenced by blocks as well as overload basic mathematical
    operations such as __mul__, __add__, where they are needed.

    Every instance carries a `block_variable` attribute, created in
    `__init__` through `create_block_variable`. Most `_ad_*` hooks defined
    here raise NotImplementedError and are meant to be overridden.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the block variable of this instance.

        Positional and keyword arguments are accepted but not used here.
        """
        self.block_variable = None
        self.create_block_variable()

    @classmethod
    def _ad_init_object(cls, obj):
        """This method will often need to be overridden.

        The method should implement a way to reconstruct a new overloaded
        instance from a (possibly) not-overloaded instance. The default
        simply calls the class with `obj` as its only argument.

        Args:
            obj: An instance of the original type

        Returns:
            OverloadedType: An overloaded instance which is considered the
            same as `obj`.
        """
        return cls(obj)

    def create_block_variable(self):
        """Create a new BlockVariable for `self` and store it.

        Returns:
            BlockVariable: The newly created block variable, which is also
            assigned to `self.block_variable`.
        """
        self.block_variable = BlockVariable(self)
        return self.block_variable

    def _ad_convert_type(self, value, options=None):
        """This method must be overridden.

        Should implement a way to convert the result of an adjoint
        computation, `value`, into the same type as `self`.

        Args:
            value (Any): The value to convert. Should be a result of an
                adjoint computation.
            options (dict, optional): A dictionary with options that may be
                supplied by the user. If the convert type functionality
                offers some options on how to convert, this is the
                dictionary that should be used. Defaults to None (a shared
                mutable `{}` default is deliberately avoided).
                For an example see fenics_adjoint.types.Function

        Returns:
            OverloadedType: An instance of the same type as `self`.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"OverloadedType._ad_convert_type not defined for class {type(self)}.")

    def _ad_create_checkpoint(self):
        """This method must be overridden.

        Should implement a way to create a checkpoint for the overloaded
        object. The checkpoint should be returned and possible to restore
        from in the corresponding _ad_restore_at_checkpoint method.

        Returns:
            :obj:`object`: A checkpoint. Could be of any type, but must be
            possible to restore an object from that point.
        """
        raise NotImplementedError

    def _ad_restore_at_checkpoint(self, checkpoint):
        """This method must be overridden.

        Should implement a way to restore the object at supplied checkpoint.
        The checkpoint is created from the _ad_create_checkpoint method.

        Args:
            checkpoint: A checkpoint as returned by `_ad_create_checkpoint`.

        Returns:
            :obj:`OverloadedType`: The object with same state as at the
            supplied checkpoint.
        """
        raise NotImplementedError

    def _ad_mul(self, other):
        """This method must be overridden.

        The method should implement a routine for multiplying the overloaded
        object with another object, and return an object of the same type as
        `self`.

        Args:
            other (:obj:`object`): The object to be multiplied with this.
                Should at the very least accept :obj:`float` and
                :obj:`integer` objects.

        Returns:
            :obj:`OverloadedType`: The product of the two objects represented
            as an instance of the same subclass of :class:`OverloadedType`
            as the type of `self`.
        """
        raise NotImplementedError

    def _ad_imul(self, other):
        """In-place multiplies `self` with `other`.

        This method should be overridden if the default behaviour is not
        compatible with this OverloadedType.

        Args:
            other (object): The object to multiply `self` with.
                Should at the very least accept `float` objects.

        Returns:
            None
        """
        # The augmented assignment only rebinds the local name `self`; the
        # caller observes the change only when `__imul__` mutates the object
        # in place. Types without such an `__imul__` must override this.
        self *= other

    def _ad_add(self, other):
        """This method must be overridden.

        The method should implement a routine for adding the overloaded
        object with another object, and return an object of the same type as
        `self`.

        Args:
            other (:obj:`object`): The object to be added with this.
                Should at the very least accept objects of the same type as
                `self`.

        Returns:
            :obj:`OverloadedType`: The sum of the two objects represented as
            an instance of the same subclass of :class:`OverloadedType` as
            the type of `self`.
        """
        raise NotImplementedError

    def _ad_iadd(self, other):
        """In-place adds `other` to `self`.

        This method should be overridden if the default behaviour is not
        compatible with this OverloadedType.

        Args:
            other (object): The object to add to `self`.
                Should at the very least accept objects of the same type as
                `self`.

        Returns:
            None
        """
        # Same caveat as `_ad_imul`: this relies on `__iadd__` mutating the
        # object in place, because only the local name `self` is rebound.
        self += other

    def _ad_dot(self, other):
        """This method must be overridden.

        The method should implement a routine for computing the dot product
        of the overloaded object with another object of the same type, and
        return a :obj:`float`.

        Args:
            other (:obj:`OverloadedType`): The object to compute the dot
                product with. Should be of the same type as `self`.

        Returns:
            :obj:`float`: The dot product of the two objects.
        """
        raise NotImplementedError

    def _ad_will_add_as_dependency(self):
        """Method called when the object is added as a Block dependency.

        Asks the block variable to save its output without overwriting
        (`save_output(overwrite=False)`).
        """
        self.block_variable.save_output(overwrite=False)

    def _ad_will_add_as_output(self):
        """Method called when the object is added as a Block output.

        Returns:
            bool: True if the saved checkpoint should be overwritten.
        """
        return True

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """This method must be overridden.

        The method should implement a routine for assigning the values from
        a numpy array `src` to the checkpoint `dst`. `dst` should be an
        instance of the implementing class.

        Args:
            dst (obj): The object which should be assigned new values.
                The type will most likely be an OverloadedType or similar.
            src (numpy.ndarray): The numpy array to use as a source for the
                assignment. `src` should have the same underlying dimensions
                as `dst`.
            offset (int): Start reading `dst` from `offset`.

        Returns:
            tuple:
                obj: The `dst` object. If `dst` is mutable it is preferred to
                    be the same instance as supplied to the function call.
                    Otherwise a new instance must be initialized and returned
                    with the correct `src` values.
                int: The new offset.
        """
        raise NotImplementedError

    @staticmethod
    def _ad_to_list(m):
        """This method must be overridden.

        The method should implement a routine for converting `m` into a
        list type. `m` should be an instance of the same type as the class
        this method is implemented in. Although maybe the backend version
        of this class, meaning it is not necessarily an OverloadedType.

        Args:
            m (obj): The object to be converted into a list.

        Returns:
            list: A list representation of the data structure of `m`.
        """
        raise NotImplementedError

    def _ad_copy(self):
        """This method must be overridden.

        The method should implement a routine for copying itself.

        Returns:
            OverloadedType: A (deep) copy of `self`.
        """
        raise NotImplementedError

    def _ad_dim(self):
        """This method must be overridden.

        The method should implement a routine for computing the number of
        components of `self`.

        Returns:
            int: The number of components of `self`.
        """
        raise NotImplementedError
class FloatingType(OverloadedType):
    """OverloadedType whose tape annotation is deferred ("floating").

    The constructor only records a block class (and an output block class)
    together with the arguments to build them. While `_ad_floating_active`
    is true, the recorded block is created and added to the working tape
    when the object is added as a block dependency, and the recorded output
    block when the object is added as a block output.
    """

    def __init__(self, *args, **kwargs):
        """Store the floating-annotation settings and initialise the base.

        The following keyword arguments are consumed here; everything else
        is forwarded to `OverloadedType.__init__`.

        Keyword Args:
            block_class: Class instantiated by `_ad_annotate_block`.
                Default None (no block is annotated).
            _ad_args (list): Positional arguments for `block_class`.
            _ad_kwargs (dict): Keyword arguments for `block_class`.
            ad_block_tag: Value assigned to the `tag` attribute of the
                created block. Default None.
            _ad_floating_active (bool): Whether deferred annotation is
                active. Default False.
            _ad_output_args (list): Extra positional arguments for
                `output_block_class` (passed after `self`).
            _ad_output_kwargs (dict): Keyword arguments for
                `output_block_class`.
            output_block_class: Class instantiated by
                `_ad_annotate_output_block`. Default None.
            _ad_outputs (list): Objects that receive a fresh block variable
                and are added as outputs of the output block.
        """
        self.block_class = kwargs.pop("block_class", None)
        self._ad_args = kwargs.pop("_ad_args", [])
        self._ad_kwargs = kwargs.pop("_ad_kwargs", {})
        self.ad_block_tag = kwargs.pop("ad_block_tag", None)
        self._ad_floating_active = kwargs.pop("_ad_floating_active", False)
        self.block = None

        self._ad_output_args = kwargs.pop("_ad_output_args", [])
        self._ad_output_kwargs = kwargs.pop("_ad_output_kwargs", {})
        self.output_block_class = kwargs.pop("output_block_class", None)
        self._ad_outputs = kwargs.pop("_ad_outputs", [])
        # Initialised like `block`, so the attribute exists before
        # `_ad_annotate_output_block` has run.
        self.output_block = None

        # Called last: the base constructor invokes `create_block_variable`.
        OverloadedType.__init__(self, *args, **kwargs)

    def create_block_variable(self):
        """Create the block variable and flag it as a floating type.

        Returns:
            BlockVariable: The new block variable, with `floating_type` set
            to True.
        """
        block_variable = OverloadedType.create_block_variable(self)
        block_variable.floating_type = True
        return block_variable

    def _ad_will_add_as_dependency(self):
        """Method called when the object is added as a Block dependency.

        If floating is active, the recorded block is annotated first, with
        floating switched off for the duration of the annotation. The
        block variable's output is then saved with `overwrite=True`.
        """
        if self._ad_floating_active:
            with FloatingType.stop_floating(self):
                self._ad_annotate_block()
        self.block_variable.save_output(overwrite=True)

    def _ad_will_add_as_output(self):
        """Method called when the object is added as a Block output.

        If floating is active, the recorded output block is annotated, with
        floating switched off for the duration of the annotation.

        Returns:
            bool: Always True.
        """
        if self._ad_floating_active:
            with FloatingType.stop_floating(self):
                self._ad_annotate_output_block()
        return True

    def _ad_annotate_block(self):
        """Create the recorded block and add it to the working tape.

        Does nothing when `block_class` is None. Otherwise the block is
        built from `_ad_args` / `_ad_kwargs`, tagged with `ad_block_tag`,
        stored on `self.block`, added to the tape and given a newly created
        block variable of `self` as output.
        """
        if self.block_class is None:
            return
        tape = get_working_tape()
        block = self.block_class(*self._ad_args, **self._ad_kwargs)
        block.tag = self.ad_block_tag
        self.block = block
        tape.add_block(block)
        block.add_output(self.create_block_variable())

    def _ad_annotate_output_block(self):
        """Create the recorded output block and add it to the working tape.

        Does nothing when `output_block_class` is None. Otherwise the block
        is built from `self`, `_ad_output_args` and `_ad_output_kwargs`,
        stored on `self.output_block`, added to the tape, and every object
        in `_ad_outputs` gets a new block variable that is added as an
        output of the block.
        """
        if self.output_block_class is None:
            return
        tape = get_working_tape()
        block = self.output_block_class(self, *self._ad_output_args, **self._ad_output_kwargs)
        self.output_block = block
        tape.add_block(block)
        for output in self._ad_outputs:
            block.add_output(output.create_block_variable())

    class stop_floating(object):
        """Context manager that switches floating off on `obj` temporarily.

        On entry `obj._ad_floating_active` is set to False; on exit the
        value seen at entry is restored, so an object that was inactive
        before the `with` block is not activated by it.
        """

        def __init__(self, obj):
            self.obj = obj
            self._was_active = True

        def __enter__(self):
            # Remember the current state. An object without the flag ends up
            # active on exit, as before.
            self._was_active = getattr(self.obj, "_ad_floating_active", True)
            self.obj._ad_floating_active = False

        def __exit__(self, *args):
            # Returns None, so exceptions raised in the body propagate.
            self.obj._ad_floating_active = self._was_active