Source code for pyadjoint_utils.fenics_adjoint.assembly

from ufl import Form
from fenics_adjoint import backend, Function
from fenics_adjoint.assembly import AssembleBlock as AssembleBlockOrig
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape
from pyadjoint.overloaded_type import OverloadedType, create_overloaded_object

import backend
from fenics_adjoint.compat import compat

# Replace the imported ``compat`` callable with the result of calling it on
# ``backend``; the rest of this module reads attributes of that result
# (e.g. ``compat.extract_mesh_from_form``, ``compat.assemble_adjoint_value``).
compat = compat(backend)


class OverloadedUFLArgument:
    """Marker type for overloaded arguments.

    Adjoint values recorded on a tape for instances of this type should be
    returned as rank-1 Forms. The class carries no state or behaviour of its
    own; other code only tests for it with ``isinstance``.
    """


def assemble(*args, **kwargs):
    """Assemble a form, accepting a Function for the ``tensor`` keyword.

    This follows the firedrake calling convention in which ``tensor`` may be
    a Function. When it is, the form is assembled into the Function's
    vector and the Function itself becomes the result.

    A result that is a ``float`` or a backend Function is converted to an
    overloaded object and, if annotation is enabled, an ``AssembleBlock`` is
    added to the working tape with that object as its output. Any other
    result (an assembled matrix) just gets the form attached as ``.form``.

    Args:
        *args: Positional arguments forwarded to ``backend.assemble``; the
            first one is taken to be the form.
        **kwargs: Keyword arguments forwarded to ``backend.assemble`` after
            ``annotate_tape`` has inspected them.

    Returns:
        The overloaded scalar/Function, or the assembled tensor with a
        ``form`` attribute.
    """
    should_annotate = annotate_tape(kwargs)
    target = kwargs.get("tensor")

    with stop_annotating():
        if isinstance(target, backend.Function):
            # The backend expects a raw tensor, so hand it the Function's
            # vector and keep the Function as the result.
            backend.assemble(*args, **dict(kwargs, tensor=target.vector()))
            result = target
        else:
            result = backend.assemble(*args, **kwargs)

    form = args[0]

    if not isinstance(result, (float, backend.Function)):
        # Assembled a matrix
        result.form = form
        return result

    result = create_overloaded_object(result)
    if should_annotate:
        block = AssembleBlock(form)
        get_working_tape().add_block(block)
        block.add_output(result.block_variable)
    return result
class AssembleBlock(AssembleBlockOrig):
    """Tape block for ``assemble`` that also supports non-scalar outputs.

    Compared with the parent block, this one handles forms that still have
    arguments (so the assembled output is a Function rather than a scalar)
    and controls that are ``OverloadedUFLArgument`` instances, for which
    unassembled derivative forms are returned.
    """

    def __init__(self, form):
        """Record ``form`` and register its mesh and coefficients as dependencies.

        Args:
            form: The UFL form that is being assembled.
        """
        # NOTE(review): ``super(AssembleBlockOrig, self)`` starts the MRO
        # lookup *after* AssembleBlockOrig, so the parent's own ``__init__``
        # is skipped. Presumably deliberate, since the dependencies are
        # registered below - confirm against the pyadjoint version in use.
        super(AssembleBlockOrig, self).__init__(form)
        self.form = form
        if backend.__name__ != "firedrake":
            mesh = self.form.ufl_domain().ufl_cargo()
        else:
            mesh = self.form.ufl_domain()
        self.add_dependency(mesh)
        for c in self.form.coefficients():
            if isinstance(c, OverloadedType):
                self.add_dependency(c, no_duplicates=True)

    @staticmethod
    def _as_function(value, argument):
        """Return ``value`` as a backend Function.

        A value that already is a Function is returned unchanged; anything
        else is copied entry-wise into a new Function on ``argument``'s
        function space.
        """
        if isinstance(value, backend.Function):
            return value
        func = backend.Function(argument.function_space())
        func.vector()[:] = value[:]
        return func

    def _derivative_space(self, control):
        """Return the function space in which to differentiate w.r.t. ``control``."""
        if isinstance(control, backend.Function):
            return control.function_space()
        mesh = compat.extract_mesh_from_form(self.form)
        return control._ad_function_space(mesh)

    def evaluate_adj_component(
        self, inputs, adj_inputs, block_variable, idx, prepared=None
    ):
        """Compute the adjoint contribution for one dependency.

        Handles the cases where the adjoint input is a Function (or vector)
        or where the result should be a Form; otherwise defers to the parent.

        Args:
            inputs: Passed through to the parent implementation.
            adj_inputs: Adjoint inputs; only the first entry is used here.
            block_variable: Block variable of the dependency being
                differentiated.
            idx: Passed through to the parent implementation.
            prepared: The form to differentiate.

        Returns:
            * an unassembled derivative form when the control is an
              ``OverloadedUFLArgument``;
            * the parent's result when the form has no arguments;
            * ``[[output, V]]`` for ``compat.ExpressionType`` controls;
            * otherwise the value of ``compat.assemble_adjoint_value`` on
              the derivative form.
        """
        form = prepared
        adj_input = adj_inputs[0]
        c = block_variable.output
        c_rep = block_variable.saved_output

        args = form.arguments()
        n_args = len(args)

        if isinstance(c, OverloadedUFLArgument):
            if n_args == 0:
                form = adj_input * form
            elif n_args == 1:
                form = form * adj_input
            dc = backend.TestFunction(self._derivative_space(c))
            # Deliberately left unassembled for this kind of control.
            return backend.derivative(form, c_rep, dc)

        if n_args == 0:
            return super(AssembleBlock, self).evaluate_adj_component(
                inputs, adj_inputs, block_variable, idx, prepared=form
            )

        # In this case, adj_input may be a vector. It needs to be made into
        # a function to put it into the form.
        adj_input_func = self._as_function(adj_input, args[0])

        # If the output of the original is an assembled function, then it
        # should have exactly one argument, so we can apply it to adj_input.
        form = form * adj_input_func

        if isinstance(c, compat.ExpressionType):
            # Create a FunctionSpace from self.form and the Expression, and
            # then make a TestFunction from this space.
            mesh = compat.extract_mesh_from_form(self.form)
            V = c._ad_function_space(mesh)
            dc = backend.TestFunction(V)
            dform = backend.derivative(form, c_rep, dc)
            output = compat.assemble_adjoint_value(dform)
            return [[output, V]]

        if isinstance(c, compat.MeshType):
            X = backend.SpatialCoordinate(c_rep)
            dform = backend.derivative(form, X)
            return compat.assemble_adjoint_value(dform)

        dc = backend.TestFunction(self._derivative_space(c))
        dform = backend.derivative(form, c_rep, dc)
        return compat.assemble_adjoint_value(dform)

    def evaluate_tlm_component(
        self, inputs, tlm_inputs, block_variable, idx, prepared=None
    ):
        """Compute the tangent-linear value for one output.

        The parent computes the value; if the saved output at ``idx`` is a
        Function, the result is wrapped in a Function on the same space.
        """
        output = super(AssembleBlock, self).evaluate_tlm_component(
            inputs, tlm_inputs, block_variable, idx, prepared=prepared
        )
        saved_output = self.get_outputs()[idx].saved_output
        if isinstance(saved_output, backend.Function):
            output = backend.Function(saved_output.function_space(), output)
        return output

    def evaluate_hessian_component(
        self,
        inputs,
        hessian_inputs,
        adj_inputs,
        block_variable,
        idx,
        relevant_dependencies,
        prepared=None,
    ):
        """Compute the Hessian contribution for one dependency.

        When the form has arguments, or the control is an
        ``OverloadedUFLArgument``, the contribution is the first derivative
        of the form weighted by the Hessian input, plus the second
        derivatives weighted by the adjoint input in the direction of each
        dependency's TLM value. All other cases defer to the parent.

        Args:
            inputs: Passed through to the parent implementation.
            hessian_inputs: Hessian inputs; only the first entry is used.
            adj_inputs: Adjoint inputs; only the first entry is used.
            block_variable: Block variable of the dependency being
                differentiated.
            idx: Passed through to the parent implementation.
            relevant_dependencies: Iterable of ``(index, block_variable)``
                pairs; entries whose ``tlm_value`` is ``None`` are skipped.
            prepared: The form to differentiate.

        Returns:
            The unassembled form for ``OverloadedUFLArgument`` controls,
            otherwise the value of ``compat.assemble_adjoint_value``, or the
            parent's result for the default case.
        """
        form = prepared
        hessian_input = hessian_inputs[0]
        adj_input = adj_inputs[0]

        c1 = block_variable.output
        c1_rep = block_variable.saved_output

        args = form.arguments()
        n_args = len(args)
        returns_form = isinstance(c1, OverloadedUFLArgument)

        if n_args == 0 and not returns_form:
            return super(AssembleBlock, self).evaluate_hessian_component(
                inputs,
                hessian_inputs,
                adj_inputs,
                block_variable,
                idx,
                relevant_dependencies,
                prepared,
            )

        if n_args == 0:
            # No argument to take a function space from, so the inputs are
            # used as they are. ``hform`` carries the Hessian input and
            # ``aform`` the adjoint input, exactly as in the branch below.
            hform = hessian_input * form
            aform = adj_input * form
        else:
            hform = form * self._as_function(hessian_input, args[0])
            aform = form * self._as_function(adj_input, args[0])

        dc = backend.TestFunction(self._derivative_space(c1))
        dform = backend.derivative(hform, c1_rep, dc)
        dforma = backend.derivative(aform, c1_rep, dc)

        for _, bv in relevant_dependencies:
            c2_rep = bv.saved_output
            tlm_input = bv.tlm_value

            if tlm_input is None:
                continue

            if isinstance(c2_rep, compat.MeshType):
                X = backend.SpatialCoordinate(c2_rep)
                ddform = backend.derivative(dforma, X, tlm_input)
            else:
                ddform = backend.derivative(dforma, c2_rep, tlm_input)
            dform += ddform

        if returns_form:
            return dform
        return compat.assemble_adjoint_value(dform)

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Recompute the output at ``idx`` from the prepared form.

        If the saved output is a Function, the form is assembled into a new
        Function on the same space and returned as an overloaded object;
        otherwise the parent implementation is used.
        """
        saved_output = self.get_outputs()[idx].saved_output
        if not isinstance(saved_output, backend.Function):
            return super(AssembleBlock, self).recompute_component(
                inputs, block_variable, idx, prepared
            )

        output = backend.Function(saved_output.function_space())
        backend.assemble(prepared, tensor=output.vector())
        return create_overloaded_object(output)