# Source code for pyadjoint_utils.reduced_function

from pyadjoint.enlisting import Enlist
from pyadjoint.tape import stop_annotating, get_working_tape

from .drivers import (
    compute_gradient,
    compute_jacobian_action,
    compute_jacobian_matrix,
    compute_hessian_action,
)

# This is copied and extended from work by Sebastian Mitusch in pyadjoint
# [https://bitbucket.org/dolfin-adjoint/pyadjoint/commits/4b67f3e07579501ab95b1fad143d98253a15f7ae?at=reduced-function]


class ReducedFunction(object):
    """Treat taped outputs as a function of a set of controls.

    Calling an instance updates the controls and recomputes the tape. The
    derivative methods delegate to the driver functions imported by this
    module (``compute_jacobian_action``, ``compute_gradient``,
    ``compute_jacobian_matrix`` and ``compute_hessian_action``).

    Arguments and results are passed through ``Enlist``/``delist``, so each
    may be given either as a single object or as a list.
    """

    def __init__(
        self,
        outputs,
        controls,
        tape=None,
        eval_cb_pre=None,
        eval_cb_post=None,
        jac_action_cb_pre=None,
        jac_action_cb_post=None,
        adj_jac_action_cb_pre=None,
        adj_jac_action_cb_post=None,
        hess_action_cb_pre=None,
        hess_action_cb_post=None,
    ):
        """Store the outputs, controls, tape and optional callbacks.

        Args:
            outputs: One output or a list of outputs; each must expose a
                ``block_variable`` attribute.
            controls: One control or a list of controls.
            tape: Tape to operate on. When ``None``, the result of
                ``get_working_tape()`` is used.
            eval_cb_pre: Called by ``__call__`` with the new control values.
            eval_cb_post: Called by ``__call__`` with
                ``(outputs, control values)``.
            jac_action_cb_pre: Called by ``jac_action`` with
                ``(control values, inputs)``.
            jac_action_cb_post: Called by ``jac_action`` with
                ``(saved outputs, derivatives, control values)``.
            adj_jac_action_cb_pre: Called by ``adj_jac_action`` with the
                control values.
            adj_jac_action_cb_post: Called by ``adj_jac_action`` with
                ``(saved outputs, derivatives, control values)``.
            hess_action_cb_pre: Called by ``hess_action`` with the control
                values.
            hess_action_cb_post: Called by ``hess_action`` with
                ``(saved outputs, hessian action, control values)``.

        Any callback left as ``None`` is replaced by a function that accepts
        arbitrary positional arguments and returns ``None``.
        """
        self.output_vals = Enlist(outputs)
        block_variables = [val.block_variable for val in self.output_vals]
        self.outputs = Enlist(block_variables)
        # Carry over the ``listed`` flag of the user-supplied outputs
        # (presumably so delist() mirrors how they were passed -- confirm
        # against Enlist).
        self.outputs.listed = self.output_vals.listed
        self.controls = Enlist(controls)

        if tape is None:
            tape = get_working_tape()
        self.tape = tape

        # One shared no-op stands in for every callback that was not given.
        noop = lambda *args: None
        callbacks = {
            "eval_cb_pre": eval_cb_pre,
            "eval_cb_post": eval_cb_post,
            "jac_action_cb_pre": jac_action_cb_pre,
            "jac_action_cb_post": jac_action_cb_post,
            "adj_jac_action_cb_pre": adj_jac_action_cb_pre,
            "adj_jac_action_cb_post": adj_jac_action_cb_post,
            "hess_action_cb_pre": hess_action_cb_pre,
            "hess_action_cb_post": hess_action_cb_post,
        }
        for attr_name, callback in callbacks.items():
            setattr(self, attr_name, noop if callback is None else callback)

    def jac_action(self, inputs, options=None):
        """Return the result of ``compute_jacobian_action`` for ``inputs``.

        Args:
            inputs: One value per control (a single value or a list).
            options: Forwarded unchanged to ``compute_jacobian_action``.

        Returns:
            The derivatives, delisted according to the outputs.

        Raises:
            TypeError: If the number of inputs differs from the number of
                controls.
        """
        directions = Enlist(inputs)
        if len(directions) != len(self.controls):
            raise TypeError(
                "The length of inputs must match the length of function controls."
            )

        control_values = [control.data() for control in self.controls]
        self.jac_action_cb_pre(
            self.controls.delist(control_values),
            self.controls.delist(directions),
        )

        result = compute_jacobian_action(
            self.output_vals,
            self.controls,
            directions,
            options=options,
            tape=self.tape,
        )

        saved_outputs = [block_var.saved_output for block_var in self.outputs]
        self.jac_action_cb_post(
            self.outputs.delist(saved_outputs),
            self.outputs.delist(result),
            self.controls.delist(control_values),
        )
        return self.outputs.delist(result)

    def adj_jac_action(self, inputs, options=None):
        """Return the result of ``compute_gradient`` with ``inputs`` as adjoint value.

        Args:
            inputs: One value per output (a single value or a list); passed
                to ``compute_gradient`` as ``adj_value``.
            options: Forwarded unchanged to ``compute_gradient``.

        Returns:
            The derivatives, delisted according to the controls.

        Raises:
            TypeError: If the number of inputs differs from the number of
                outputs.
        """
        adjoint_values = Enlist(inputs)
        if len(adjoint_values) != len(self.outputs):
            raise TypeError(
                "The length of inputs must match the length of function outputs."
            )

        control_values = [control.data() for control in self.controls]
        self.adj_jac_action_cb_pre(self.controls.delist(control_values))

        result = compute_gradient(
            self.output_vals,
            self.controls,
            options=options,
            tape=self.tape,
            adj_value=adjoint_values,
        )

        saved_outputs = [block_var.saved_output for block_var in self.outputs]
        self.adj_jac_action_cb_post(
            self.outputs.delist(saved_outputs),
            self.controls.delist(result),
            self.controls.delist(control_values),
        )
        return self.controls.delist(result)

    def jac_matrix(self, m_jac=None):
        """Return the result of ``compute_jacobian_matrix`` for the controls.

        Args:
            m_jac: Optional. When given, it must hold one entry per control,
                and each entry (after enlisting) must itself hold one item
                per control. Passed on to ``compute_jacobian_matrix``.

        Returns:
            The Jacobian with every non-``None`` row delisted according to
            the controls, and the whole result delisted according to the
            outputs.

        Raises:
            TypeError: If ``m_jac`` or any of its entries has the wrong
                length.
        """
        if m_jac is not None:
            n_controls = len(self.controls)
            m_jac = Enlist(m_jac)
            if len(m_jac) != n_controls:
                raise TypeError(
                    "The length of m_jac must match the length of function controls."
                )
            for position, entry in enumerate(m_jac):
                enlisted = Enlist(entry)
                m_jac[position] = enlisted
                if len(enlisted) != n_controls:
                    raise TypeError(
                        "The length of each identity must match the length of function controls."
                    )

        jacobian = compute_jacobian_matrix(
            self.output_vals, self.controls, m_jac, tape=self.tape
        )

        # Rows are rewritten in place; rows that are None are left untouched.
        for row_index, row in enumerate(jacobian):
            if row is None:
                continue
            jacobian[row_index] = self.controls.delist(row)

        return self.outputs.delist(jacobian)

    def hess_action(self, m_dot, adj_input, options=None):
        """Return the result of ``compute_hessian_action`` for ``m_dot``.

        Args:
            m_dot: One value per control (a single value or a list).
            adj_input: One value per output (a single value or a list);
                passed to ``compute_gradient`` as ``adj_value``.
            options: Forwarded unchanged to both driver calls.

        Returns:
            The Hessian action, delisted according to the controls.

        Raises:
            TypeError: If ``m_dot`` does not match the number of controls or
                ``adj_input`` does not match the number of outputs.
        """
        m_dot = Enlist(m_dot)
        if len(m_dot) != len(self.controls):
            raise TypeError(
                "The length of m_dot must match the length of function controls."
            )

        adj_input = Enlist(adj_input)
        if len(adj_input) != len(self.outputs):
            raise TypeError(
                "The length of adj_input must match the length of function outputs."
            )

        control_values = [control.data() for control in self.controls]
        self.hess_action_cb_pre(self.controls.delist(control_values))

        # The returned gradient is not used here. NOTE(review): presumably
        # this call is needed for its effect on the tape before the Hessian
        # pass -- confirm against the drivers.
        compute_gradient(
            self.output_vals,
            self.controls,
            options=options,
            tape=self.tape,
            adj_value=adj_input,
        )

        # TODO: there should be a better way of generating hessian_input.
        zero_hessian_input = [0 * value for value in adj_input]
        hessian = compute_hessian_action(
            self.output_vals,
            self.controls,
            m_dot,
            options=options,
            tape=self.tape,
            hessian_value=zero_hessian_input,
        )

        saved_outputs = [block_var.saved_output for block_var in self.outputs]
        self.hess_action_cb_post(
            self.outputs.delist(saved_outputs),
            self.controls.delist(hessian),
            self.controls.delist(control_values),
        )
        return self.controls.delist(hessian)

    def __call__(self, inputs):
        """Update the controls with ``inputs``, recompute the tape, return outputs.

        Args:
            inputs: One new value per control (a single value or a list).

        Returns:
            The ``checkpoint`` of every output block variable after the
            recomputation, delisted according to the outputs.

        Raises:
            TypeError: If the number of inputs differs from the number of
                controls.
        """
        new_values = Enlist(inputs)
        if len(new_values) != len(self.controls):
            raise TypeError("The length of inputs must match the length of controls.")

        self.eval_cb_pre(self.controls.delist(new_values))

        # Lengths were checked above, so zip pairs every control with a value.
        for control, value in zip(self.controls, new_values):
            control.update(value)

        # self.tape.reset_blocks()  (left disabled, as in the original code)
        with self.marked_controls(), stop_annotating():
            self.tape.recompute()

        checkpoints = [block_var.checkpoint for block_var in self.outputs]
        output_vals = self.outputs.delist(checkpoints)

        self.eval_cb_post(output_vals, self.controls.delist(new_values))
        return output_vals

    def marked_controls(self):
        """Return a ``marked_controls`` context manager bound to this object."""
        context = marked_controls(self)
        return context
class marked_controls(object):
    """Context manager that marks the controls of ``rf`` while it is active.

    On entry, ``mark_as_control()`` is called on every item of
    ``rf.controls``; on exit, ``unmark_as_control()`` is called on every
    item. Nothing is bound by ``with ... as``, and exceptions raised inside
    the ``with`` body are not suppressed.
    """

    def __init__(self, rf):
        """Remember the object whose ``controls`` will be marked."""
        self.rf = rf

    def __enter__(self):
        """Mark every control, in order."""
        controls = self.rf.controls
        for ctrl in controls:
            ctrl.mark_as_control()

    def __exit__(self, *exc_info):
        """Unmark every control, in order; the exception info is ignored."""
        controls = self.rf.controls
        for ctrl in controls:
            ctrl.unmark_as_control()