from pyadjoint import Tape as pyadjoint_tape
from pyadjoint import get_working_tape, set_working_tape, Block
from contextlib import contextmanager
@contextmanager
def push_tape():
    """Temporarily install a fresh :class:`Tape` as the working tape.

    On entry the current working tape is remembered and a new, empty
    ``Tape`` is made the working tape. On exit the previous working tape is
    reinstated, including when the body of the ``with`` statement raises.

    Yields:
        Tape: the newly created tape that is active inside the ``with`` body.
    """
    orig_tape = get_working_tape()
    new_tape = Tape()
    set_working_tape(new_tape)
    try:
        yield new_tape
    finally:
        # Always put the caller's tape back; otherwise an exception inside
        # the scope would leave the temporary tape installed globally.
        set_working_tape(orig_tape)
def _find_relevant_nodes(tape, controls_or_block_variables):
    """Collect every node that is downstream of the given seeds on ``tape``.

    Args:
        tape: object exposing ``get_blocks()``; the blocks are visited in the
            order that call returns them.
        controls_or_block_variables: sized collection of seeds. If the first
            element has a ``block_variable`` attribute, that attribute of each
            element is used as the seed; otherwise the elements themselves are.

    Returns:
        set: the seeds plus the outputs of every block that has at least one
        dependency already in the set when the block is visited.
    """
    # Original author's note: this is a stripped down optimize_for_controls.
    blocks = tape.get_blocks()
    seeds = controls_or_block_variables
    holds_controls = len(seeds) > 0 and hasattr(next(iter(seeds)), "block_variable")
    if holds_controls:
        reachable = {seed.block_variable for seed in seeds}
    else:
        reachable = set(seeds)

    # Single forward sweep: a block's outputs become reachable as soon as one
    # of its dependencies is reachable.
    for blk in blocks:
        if any(dep in reachable for dep in blk.get_dependencies()):
            reachable.update(blk.get_outputs())
    return reachable
def _block_get_tf_blocks(block):
    """Expand ``block`` into the structure used to build the TensorFlow graph.

    Returns:
        * ``block`` itself when it has no ``tf_get_blocks`` attribute;
        * the result of ``block.tf_get_blocks()`` when that is a ``Block``;
        * otherwise ``[head, children]``, where ``tf_get_blocks()`` returned the
          pair ``(head, children)`` and every child has been expanded
          recursively by this function.
    """
    if not hasattr(block, "tf_get_blocks"):
        return block

    expanded = block.tf_get_blocks()
    if isinstance(expanded, Block):
        return expanded

    head, children = expanded
    return [head, [_block_get_tf_blocks(child) for child in children]]
class Tape(pyadjoint_tape):
    """pyadjoint ``Tape`` subclass that can evaluate sub-graphs of the tape.

    The ``evaluate_*`` methods and :meth:`recompute` take optional ``inputs``
    and ``outputs`` and only visit the blocks selected by
    :meth:`find_relevant_nodes`. The ``_tf_*`` methods build the TensorFlow
    graph used by :meth:`visualise`, including blocks that expose nested
    sub-blocks through a ``tf_get_blocks`` method.
    """

    # _tf_output_block_lookup maps id(block output) -> tensor of the block
    # that produced it, so outputs can be connected when first consumed.
    __slots__ = ["_tf_output_block_lookup"]

    def __init__(self, *args, **kwargs):
        """Create the tape; all arguments are forwarded to the parent class."""
        self._tf_output_block_lookup = {}
        super().__init__(*args, **kwargs)

    # ------------------------------------------------------------------
    # Evaluation over a sub-graph
    # ------------------------------------------------------------------
    def evaluate_adj(self, inputs=None, outputs=None, markings=False):
        """Call ``evaluate_adj`` on the relevant blocks, last block first.

        Args:
            inputs: forwarded to :meth:`find_relevant_nodes`.
            outputs: forwarded to :meth:`find_relevant_nodes`.
            markings: forwarded to each block's ``evaluate_adj``.
        """
        _, blocks = self.find_relevant_nodes(inputs, outputs)
        for block in reversed(blocks):
            block.evaluate_adj(markings=markings)

    def evaluate_tlm(self, inputs=None, outputs=None, markings=False):
        """Call ``evaluate_tlm`` on the relevant blocks in tape order.

        The relevant nodes are marked through :meth:`marked_nodes` for the
        duration of the sweep and every block is called with
        ``markings=True``; the ``markings`` argument itself is not used.
        """
        nodes, blocks = self.find_relevant_nodes(inputs, outputs)
        with self.marked_nodes(nodes):
            for block in blocks:
                block.evaluate_tlm(markings=True)

    def evaluate_tlm_matrix(self, inputs=None, outputs=None, markings=False):
        """Call ``evaluate_tlm_matrix`` on the relevant blocks in tape order.

        Exactly the nodes returned by :meth:`find_relevant_nodes` are marked
        (``find_outputs=False``) and every block is called with
        ``markings=True``; the ``markings`` argument itself is not used.
        """
        nodes, blocks = self.find_relevant_nodes(inputs, outputs)
        with self.marked_nodes(nodes, find_outputs=False):
            for block in blocks:
                block.evaluate_tlm_matrix(markings=True)

    def evaluate_hessian(self, inputs=None, outputs=None, markings=False):
        """Call ``evaluate_hessian`` on the relevant blocks, last block first.

        Args:
            inputs: forwarded to :meth:`find_relevant_nodes`.
            outputs: forwarded to :meth:`find_relevant_nodes`.
            markings: forwarded to each block's ``evaluate_hessian``.
        """
        _, blocks = self.find_relevant_nodes(inputs, outputs)
        for block in reversed(blocks):
            block.evaluate_hessian(markings=markings)

    def recompute(self, inputs=None, outputs=None):
        """Call ``recompute`` on the relevant blocks in tape order."""
        _, blocks = self.find_relevant_nodes(inputs, outputs)
        for block in blocks:
            block.recompute()

    def reset_tlm_matrix_values(self):
        """Call ``reset_variables(types=("tlm_matrix",))`` on every block, last first."""
        # The previous ``("tlm_matrix")`` was a plain string, not a tuple, so a
        # receiver testing e.g. ``"tlm" in types`` would match by substring.
        # The trailing comma makes this a genuine one-element tuple.
        for block in reversed(self._blocks):
            block.reset_variables(types=("tlm_matrix",))

    # ------------------------------------------------------------------
    # Sub-graph discovery
    # ------------------------------------------------------------------
    @staticmethod
    def _as_node_set(items):
        """Return ``items`` as a set, unwrapping ``.block_variable`` when present.

        Only the first element is inspected to decide whether to unwrap.
        """
        if len(items) > 0 and hasattr(next(iter(items)), "block_variable"):
            return {item.block_variable for item in items}
        return set(items)

    @staticmethod
    def _select_blocks(blocks, mask):
        """Return the blocks whose entry in the parallel ``mask`` is truthy."""
        return [block for block, keep in zip(blocks, mask) if keep]

    def find_relevant_dependencies(self, outputs):
        """Walk the tape backwards to find what ``outputs`` depend on.

        Args:
            outputs: sized collection of block variables, or of objects with a
                ``block_variable`` attribute.

        Returns:
            tuple: ``(nodes, mask)`` where ``nodes`` is the set holding the
            requested outputs plus every dependency of a block that produces
            a node in the set, and ``mask`` is a list of booleans parallel to
            ``self.get_blocks()`` flagging those blocks.
        """
        blocks = self.get_blocks()
        nodes = self._as_node_set(outputs)
        mask = [False] * len(blocks)
        for i in reversed(range(len(blocks))):
            block = blocks[i]
            if any(out in nodes for out in block.get_outputs()):
                mask[i] = True
                nodes.update(block.get_dependencies())
        return nodes, mask

    def find_relevant_outputs(self, inputs):
        """Walk the tape forwards to find what depends on ``inputs``.

        Args:
            inputs: sized collection of block variables, or of objects with a
                ``block_variable`` attribute.

        Returns:
            tuple: ``(nodes, mask)`` where ``nodes`` is the set holding the
            inputs plus every output of a block that has a dependency in the
            set, and ``mask`` is a list of booleans parallel to
            ``self.get_blocks()`` flagging those blocks.
        """
        blocks = self.get_blocks()
        nodes = self._as_node_set(inputs)
        mask = []
        for block in blocks:
            depends_on_input = any(dep in nodes for dep in block.get_dependencies())
            mask.append(depends_on_input)
            if depends_on_input:
                nodes.update(block.get_outputs())
        return nodes, mask

    def find_relevant_nodes(self, inputs=None, outputs=None):
        """Select the nodes and blocks lying between ``inputs`` and ``outputs``.

        * both ``None``: every dependency and output of every block, and the
          object returned by ``get_blocks()`` itself;
        * only ``outputs``: result of :meth:`find_relevant_dependencies`;
        * only ``inputs``: result of :meth:`find_relevant_outputs`;
        * both: the intersection of the two node sets and the blocks flagged
          by both sweeps.

        Returns:
            tuple: ``(nodes, blocks)`` with ``nodes`` a set and ``blocks`` in
            tape order.
        """
        # TODO: double check that inputs and outputs are block variables or controls.
        blocks = self.get_blocks()
        if inputs is None and outputs is None:
            nodes = set()
            for block in blocks:
                nodes.update(block.get_dependencies())
                nodes.update(block.get_outputs())
            return nodes, blocks

        if inputs is None:
            nodes, mask = self.find_relevant_dependencies(outputs)
        elif outputs is None:
            nodes, mask = self.find_relevant_outputs(inputs)
        else:
            i_nodes, i_mask = self.find_relevant_outputs(inputs)
            o_nodes, o_mask = self.find_relevant_dependencies(outputs)
            nodes = i_nodes.intersection(o_nodes)
            mask = [i and o for i, o in zip(i_mask, o_mask)]
        return nodes, self._select_blocks(blocks, mask)

    def find_absolute_dependencies_outputs(self):
        """Split the tape's nodes into external dependencies and outputs.

        Returns:
            tuple: ``(dependencies, outputs)``. ``outputs`` holds every block
            output on the tape; ``dependencies`` holds every dependency that
            had not been produced by an earlier-or-same-position block output
            at the time it was encountered.
        """
        dependencies = set()
        outputs = set()
        for block in self._blocks:
            for dep in block.get_dependencies():
                if dep not in outputs:
                    dependencies.add(dep)
            outputs.update(block.get_outputs())
        return dependencies, outputs

    # ------------------------------------------------------------------
    # Context managers that temporarily modify node state
    # ------------------------------------------------------------------
    # This function was modified to work directly with block variables.
    @contextmanager
    def marked_nodes(self, controls_or_block_variables, find_outputs=True):
        """Temporarily set ``marked_in_path = True`` on a group of nodes.

        Args:
            controls_or_block_variables: the seeds. With ``find_outputs=True``
                they are expanded with ``_find_relevant_nodes``; otherwise
                they are used as given and must carry ``marked_in_path``.
            find_outputs: whether to expand the seeds as described above.

        The previous ``marked_in_path`` values are restored on exit, including
        when the body of the ``with`` statement raises.
        """
        if find_outputs:
            nodes = _find_relevant_nodes(self, controls_or_block_variables)
        else:
            nodes = controls_or_block_variables
        # Materialise once so the save and restore passes see the same nodes
        # in the same order, even if a one-shot iterable was supplied.
        nodes = list(nodes)
        old_values = [node.marked_in_path for node in nodes]
        for node in nodes:
            node.marked_in_path = True
        try:
            yield
        finally:
            for node, old_value in zip(nodes, old_values):
                node.marked_in_path = old_value

    @contextmanager
    def save_adj_values(self):
        """Temporarily clear ``adj_value`` on every node of the tape.

        Each node's ``adj_value`` is set to ``None`` on entry and its previous
        value is put back on exit, including when the body raises.
        """
        nodes, _ = self.find_relevant_nodes()
        nodes = list(nodes)
        old_values = [node.adj_value for node in nodes]
        for node in nodes:
            node.adj_value = None
        try:
            yield
        finally:
            for node, old_value in zip(nodes, old_values):
                node.adj_value = old_value

    # ------------------------------------------------------------------
    # TensorFlow graph construction (visualisation)
    # ------------------------------------------------------------------
    def _get_tf_scope_name(self, node):
        """Return a TensorFlow scope name based on the node's class name or an attribute 'tf_name'."""
        # For a BlockVariable, name the scope after the object it wraps.
        if node.__class__.__name__ == "BlockVariable":
            node = node.output
        if hasattr(node, "tf_name"):
            name = node.tf_name
        else:
            name = node.__class__.__name__
        return self._valid_tf_scope_name(name)

    def _tf_register_blocks(self, name=None):
        """Register all not-yet-registered tape blocks under scope ``name``.

        Appends ``[name, entry, entry, ...]`` to ``_tf_registered_blocks``,
        where each entry is ``_block_get_tf_blocks(block)``, and records each
        such block in ``_tf_added_blocks``.
        """
        lst = [name]
        for block in self.get_blocks():
            if block in self._tf_added_blocks:
                continue
            self._tf_added_blocks.append(block)
            lst.append(_block_get_tf_blocks(block))
        self._tf_registered_blocks.append(lst)

    def _tf_rebuild_registered_blocks(self):
        """Remove blocks that no longer exist on the tape from registered blocks."""
        # Fetch the tape's blocks once instead of once per membership test.
        tape_blocks = self.get_blocks()
        new_registered_blocks = []
        new_added_blocks = []
        for scope in self._tf_registered_blocks:
            lst = scope[:1]  # keep the scope name
            for entry in scope[1:]:
                # NOTE(review): the nesting of the ``new_added_blocks`` update
                # was lost with the file's indentation. A block is tracked here
                # exactly when it is still on the tape -- confirm against the
                # upstream source.
                if isinstance(entry, Block):
                    if entry in tape_blocks:
                        lst.append(entry)
                        new_added_blocks.append(entry)
                else:
                    head, _ = entry
                    if head in tape_blocks:
                        # Re-expand so the sub-block structure is refreshed.
                        lst.append(_block_get_tf_blocks(head))
                        new_added_blocks.append(head)
            if len(lst) > 1:
                new_registered_blocks.append(lst)
        self._tf_registered_blocks = new_registered_blocks
        self._tf_added_blocks = new_added_blocks

    def _tf_add_blocks_scoped(self, blocks):
        """Add the given blocks (with possible sub-blocks) to the TensorFlow graph.

        Raises:
            ValueError: if an output of a top-level block has neither a tensor
                in ``_tf_tensors`` nor an entry in ``_tf_output_block_lookup``.
        """
        import tensorflow.compat.v1 as tf

        for block in blocks:
            self._tf_add_block(block)

        # Make sure every output of the top-level blocks has a tensor, even
        # if nothing consumed it as an input.
        for block in blocks:
            if not isinstance(block, Block):
                block = block[0]
            for out in block.get_outputs():
                if id(out) in self._tf_tensors:
                    continue
                t = self._tf_output_block_lookup.get(id(out), None)
                if t is None:
                    # This block output wasn't created in the tensorflow
                    # graph, and there is no way to create it because the
                    # block that generated the corresponding tensor is unknown.
                    raise ValueError(
                        "This output is expected to be created in a sub-block: {}".format(
                            str(out)
                        )
                    )
                with tf.name_scope(self._get_tf_scope_name(out)):
                    tout = tf.py_function(
                        lambda: None,
                        [t],
                        [tf.float64],
                        name=self._valid_tf_scope_name(str(out)),
                    )
                    self._tf_tensors[id(out)] = tout

    def visualise(self, output="log", *args, **kwargs):
        """This resets the TensorFlow data, which allows a user to call tape.visualise() twice without error.

        The reset is skipped when ``output`` ends with ``".dot"``. All
        arguments are forwarded to the parent ``visualise``, whose return
        value is returned.
        """
        if not output.endswith(".dot"):
            self._tf_tensors = {}
            self._tf_added_blocks = []
            self._tf_registered_blocks = []
        return super().visualise(output, *args, **kwargs)

    def _tf_add_blocks(self):
        """Add new blocks to the TensorFlow graph while supporting the name_scope() method."""
        import tensorflow as tf

        self._tf_register_blocks()
        self._tf_rebuild_registered_blocks()

        for scope in self._tf_registered_blocks:
            scope_name = scope[0]
            if scope_name is None:
                self._tf_add_blocks_scoped(scope[1:])
            else:
                with tf.name_scope(scope_name):
                    self._tf_add_blocks_scoped(scope[1:])

    def _tf_add_block(self, block, sub_blocks=None):
        """Add a block to the TensorFlow graph, and recursively add its sub-blocks.

        Args:
            block: a ``Block``, or a ``(block, sub_blocks)`` pair as produced
                by ``_block_get_tf_blocks``.
            sub_blocks: sub-blocks of ``block``; must be ``None`` when
                ``block`` is given as a pair.

        Raises:
            ValueError: if ``block`` is not a ``Block`` and ``sub_blocks`` is
                also supplied.
        """
        import tensorflow as tf

        if not isinstance(block, Block):
            if sub_blocks is not None:
                raise ValueError(
                    "If the input is not a Block, then it should be a tuple (block, sub_blocks) and the sub_blocks kwarg must be None"
                )
            block, sub_blocks = block
        is_leaf = sub_blocks is None

        # Block dependencies
        in_tensors = []
        for dep in block.get_dependencies():
            if id(dep) in self._tf_tensors:
                in_tensors.append(self._tf_tensors[id(dep)])
                continue
            # Look up the block variable in _tf_output_block_lookup to
            # connect it to a previously recorded block if necessary.
            inputs = self._tf_output_block_lookup.get(id(dep), [])
            with tf.name_scope(self._get_tf_scope_name(dep)):
                tin = tf.numpy_function(
                    lambda: None,
                    inputs,
                    [tf.float64],
                    name=self._valid_tf_scope_name(str(dep)),
                )
                in_tensors.append(tin)
                self._tf_tensors[id(dep)] = tin

        # Block node
        with tf.name_scope(self._get_tf_scope_name(block)):
            if is_leaf:

                def tf_np_f(*args):
                    return None

                tensor = tf.numpy_function(
                    tf_np_f,
                    in_tensors,
                    [tf.float64],
                    name=self._valid_tf_scope_name(str(block)),
                )
                self._tf_tensors[id(block)] = tensor
            else:
                for sub_block in sub_blocks:
                    self._tf_add_block(sub_block)

            # NOTE(review): the original nesting of this hook was lost with
            # the file's indentation; it is placed inside the block's name
            # scope and run for leaf and composite blocks alike -- confirm.
            if hasattr(block, "tf_add_extra_to_graph"):
                block.tf_add_extra_to_graph(self._tf_tensors)

        # Block outputs.
        # To avoid incorrectly scoping the block outputs, these will be added
        # once they are used as input. The output tensor of the current block must
        # be saved to be used as input when the block outputs are created.
        if is_leaf:
            for out in block.get_outputs():
                self._tf_output_block_lookup[id(out)] = tensor