# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as py_collections
import itertools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator

# Collections that a FuncGraph shares with its outer graph by reference (via
# `get_collection_ref`) when `FuncGraph.__init__` is called without an explicit
# `collections` dict. All other outer-graph collections are read through
# `get_collection` instead, i.e. read but not written back.
ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]

# `FuncGraph.capture` embeds an EagerTensor as a Const op (rather than a
# placeholder) when its dtype is in `dtypes.TF_VALUE_DTYPES` and it has at most
# this many elements.
_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled.

  Returned by `convert_structure_to_signature` in place of any leaf value it
  does not know how to encode.
  """
  pass


def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Leaves are encoded as follows:
  - `Tensor`s become `TensorSpec`s.
  - Composite tensors are replaced by their `_type_spec`.
  - Resource variables become `VariableSpec`s.
  - ints, floats, bools, `None`, `DType`s and `TypeSpec`s (including
    `TensorSpec`s) are passed through unchanged.
  - Anything else becomes an `UnknownArgument`.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of arguments that has equal number of elements as
      `structure` and is used for naming corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from
      `len(structure)`.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        # The producing op has no `_user_specified_name` attr; fall back to
        # the path-derived name below.
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      # Variables are always named from their flattened path.
      name = "/".join(str(p) for p in path)
      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)


class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well as
      captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are in
      this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another FuncGraph
      or the global default Graph.
    captures: Ordered view of (external tensor, internal tensor) pairs. The
      internal tensor is an input placeholder, or a Const op for small captured
      EagerTensors. The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    # Walk out through any enclosing function graphs to find the outermost
    # graph that is not building a function.
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    # Maps id(external tensor) -> (external tensor, internal tensor), in
    # capture order.
    self._captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of
      # any None op_seed for random_op in the function, in which case we end up
      # using function seed, which could be unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      # Non-allowlisted collections are read via `get_collection`; allowlisted
      # ones are shared by reference via `get_collection_ref`, so writes made
      # inside this graph are visible in the outer graph.
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope.
    # None outside `as_default()`; `as_default()` installs a fresh list.
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph.

    `v` is recorded as watched on this graph and on every enclosing FuncGraph
    reached through `outer_graph`. The walk stops at the first graph that is
    not a FuncGraph.
    """
    # `self` is rebound as the loop cursor while walking outwards.
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are often
    dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
       call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to capture_call_time_value
       with the same key in the same graph will return the same placeholder,
       and the first closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      TypeError: immediately, if `spec` (expanded through composites) contains
       a leaf that is not a `DenseSpec`.
      ValueError: at function call time, if the return value of closure() is
       not compatible with `spec`.
    """
    if key is None:
      # A fresh, unique key, so this call never shares a placeholder.
      key = object()
    if key not in self._deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
                               expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
     A context manager that specifies control dependencies for all
     operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      # (IndexedSlices and objects exposing both `_handle` and `op` are
      # replaced by their op.)
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      # Anything not owned by this graph (including objects with no `graph`
      # attribute) becomes a control capture instead of a local dependency.
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    """Returns a context manager that makes this graph the default graph.

    While active, this graph uses the following from the graph that was the
    default on entry:
    - its distribution strategy stack, variable creator stack and graph key;
    - a copy of its device function stack, but only when not executing eagerly
      and either that stack contains a callable or a distribution strategy is
      used in legacy graph mode.

    On exit, callbacks registered through `_add_scope_exit_callback` are run.
    The previous values of all of these attributes are then restored, even if
    a callback raises.
    """
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      # However, we need to preserve the outer device stack in the following
      # cases in non eager context:
      # 1. device stack is callable
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            # Restore saved state whether or not a callback raised.
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key
    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in (transitive).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.

    Raises:
      AssertionError: If any referenced variable has already been deleted.
    """
    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    # Only weak references are stored; see the getter for the consequences.
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Creates `op_type` outside this graph and captures its first output.

    Used by `_create_op_internal` for "ReadVariableOp" and "ResourceGather"
    when `capture_by_value` is set. The steps are:
    1. Inputs that are internal capture tensors are mapped back to the
       external tensors they capture.
    2. The op is created (or executed eagerly) under `ops.init_scope()`.
    3. The resulting value is captured into this graph.

    Returns:
      The op in this graph that produces the captured value: a placeholder op,
      or a Const op for small eager values.
    """
    # When capturing by value, do the read outside
    # Maps id(internal tensor) -> external tensor. `self.captures` yields
    # (external, internal) tuples.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        # NOTE(review): only the "dtype" attr is forwarded to the eager op;
        # confirm this is sufficient for "ResourceGather", whose other attrs
        # are not passed here.
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of the
        tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Small EagerTensors (dtype in `dtypes.TF_VALUE_DTYPES` and at most
    `_EAGER_CONST_THRESHOLD` elements) are embedded as Const ops instead of
    placeholders.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
      bypasses the mechanisms required for the data dependencies to be correctly
      wired.
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops
      return self._capture_helper(tensor, name, shape)

    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      # If this graph is an ancestor of the tensor's graph, the tensor lives in
      # a nested (inner) function and cannot be captured from out here.
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise errors.InaccessibleTensorError(
              "The tensor '%s' cannot be accessed here: it is defined"
              " in another function or code block. Use return values,"
              " explicit Python locals or TensorFlow collections to access"
              " it. Defined in: %s; accessed from: %s.\n"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor

  def _capture_helper(self, tensor, name, shape=None):
    """Returns the internal placeholder for `tensor`, creating it if needed.

    A new placeholder is registered with `add_capture`, which also appends it
    to `self.inputs`. An existing capture is reused. In both cases the capture
    is recorded with `tape.record_operation` using identity forward/backward
    functions.
    """
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      # Record the composite device as an attribute to the placeholder.
      # This attribute would be propagated into the arg_attr of the FunctionDef.
      # Currently, a packed eager tensor is always placed on a CompositeDevice.
      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_composite_device",
            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return placeholder

  @property
  def captures(self):
    """Ordered view of (external, internal) capture tuples, in capture order."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    `placeholder` is also appended to `self.inputs`.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replace already existing capture.

    Maps `tensor` to `placeholder`, overwriting any existing capture of
    `tensor`. Unlike `add_capture`, `self.inputs` is not modified.
    """
    self._captures[id(tensor)] = (tensor, placeholder)

  def reset_captures(self, capture_list):
    """Set the captures with the provided list of captures & placeholder.

    Existing captures are discarded. `self.inputs` is not modified.
    """
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Remove the capture and return the generated placeholder.

    Returns None if `tensor` was not captured. `self.inputs` is not modified.
    """
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    """Removes all regular and deferred captures."""
    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
    # Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Add given distributed variable to captures with given placeholder.

    Unlike `add_capture`, `placeholder` is not appended to `self.inputs`.
    """
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation("captured_value", [placeholder], [variable],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    """Captures an eager tensor as a Const op in this graph.

    Falls back to `_capture_helper` (a placeholder) when
    `tensor_util.constant_value` returns None for `tensor`. Repeated captures
    of the same tensor reuse the existing internal tensor.

    Args:
      tensor: The tensor to capture. `capture` only calls this with
        EagerTensors.
      name: Name for the created Const op (or placeholder, on fallback).

    Returns:
      The internal tensor (Const or placeholder) standing in for `tensor`.
    """
    capture = self._captures.get(id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at
      # a later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        constant_value = tensor_util.constant_value(tensor)
        if constant_value is None:
          # Some eager tensors, e.g. parallel tensors, are not convertible to a
          # single constant. We'll use a placeholder for this case.
          return self._capture_helper(tensor, name)
        graph_const = constant_op.constant(constant_value, dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Check if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of nest of placeholders which at call time will be fed."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
    """Map from id of a capture's internal tensor to the captured variable.

    Only variables in `self.variables` that are themselves keys of the capture
    map (i.e. captured by object identity, as `capture_distributed_variable`
    does) are included.
    """
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the specified
    message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope.

    Raises:
      TypeError: If `fn` is not callable.
      RuntimeError: If called while this graph is not inside `as_default()`.
    """
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph. Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)


def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not allowlisted, and both
      read and write to the outer graph's collections that are allowlisted.
      The current allowlisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`. The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`. The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections,
                           capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_recource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args, expand_composites=True)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to call the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
                                             expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs,
                                        expand_composites=True)

      check_mutation(func_args_before, func_args, original_func)
      check_mutation(func_kwargs_before, func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_recource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph


def maybe_captured(tensor):
  """Resolves a capture placeholder back to the value it was captured from.

  If `tensor` is a "Placeholder" op in a graph that is building a function, and
  it appears as the placeholder side of one of that graph's `captures` pairs,
  the corresponding captured input is resolved the same way (recursively, so
  chains of nested captures are followed to the outermost value). Any other
  tensor, including every `EagerTensor`, is returned unchanged.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  # pylint: enable=protected-access
  return tensor


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable.

  Args:
    device_stack: An object whose `peek_objs()` yields device specs exposing a
      `_device_name_or_function` attribute.

  Returns:
    True if any spec's `_device_name_or_function` is callable (a device
    function rather than a device name), False otherwise.
  """
  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
             for spec in device_stack.peek_objs())


def check_mutation(n1, n2, func):
  """Checks that a function call left its Python arguments untouched.

  `n1` is a snapshot of the argument structure taken before the call and `n2`
  is the same structure after it. The two must have the same nested structure,
  and every flattened leaf must be the *same object* (compared with `is`), so
  in-place modification of a list or dict argument is detected.

  Args:
    n1: The argument structure captured before calling `func`.
    n2: The argument structure after calling `func`.
    func: The called function; its `__name__` (or the object itself when it
      has none) is used in the error message.

  Raises:
    ValueError: If the structures differ or any leaf was replaced.
  """
  func_name = getattr(func, "__name__", func)

  errmsg = ("{}() should not modify its Python input arguments."
            " Check if it modifies any lists or dicts passed as"
            " arguments. Modifying a copy is allowed.".format(func_name))
  try:
    # TODO(mdan): Compare more robustly so that argument names can be reported.
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(errmsg)

  # Structure matched; now require leaf identity, not just equality.
  for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                        nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      raise ValueError(errmsg)


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A flat list in which every TensorArray leaf has been replaced by its
    `flow` tensor and every other leaf is returned as-is.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  This is the inverse of `flatten` above: positions that hold a TensorArray in
  `structure` are expected to hold a flow tensor in `flat_sequence`, and are
  rebuilt into a TensorArray via `tensor_array_ops.build_ta_with_new_flow`.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors (flow tensors at the
      positions where `structure` has TensorArrays).

  Returns:
    A nested structure.

  Raises:
    ValueError: If the number of elements in `flat_sequence` differs from the
      number of flattened leaves in `structure`. Further incompatibilities are
      reported by `nest.pack_sequence_as`.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it.

  Args:
    value: The tensor to stand in for; its handle data is copied onto the new
      placeholder.
    name: Optional name for the placeholder op.
    dtype: Optional dtype; falls back to `value.dtype` when falsy.
    shape: Optional shape; falls back to `value.shape` when None.

  Returns:
    The new placeholder tensor.
  """
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs.

  Thin wrapper around `_get_defun_inputs` that uses `args` itself as the
  structure to pack the results into.
  """
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Each argument is flattened (with `expand_composites=True`) and every leaf is
  mapped as follows:
    * `Tensor` / `TensorSpec`: replaced by a new graph placeholder.
    * `BaseResourceVariable` / `VariableSpec`: the variable's handle is
      captured into the current graph; the (possibly newly built) variable
      object is kept as the input.
    * anything else: passed through unchanged.

  Args:
    args: A list of user-specified arguments. Each element may itself be a
      nested structure; it is flattened and each leaf is mapped individually.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case placeholders for plain Tensors are
      requested with `name=None` (a spec's own non-empty name still wins).
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or instances
      of `TensorShape`. If provided, then length must match that of
      `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`. May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    # No overrides: yield an endless stream of None so the `next()` call
    # below works uniformly for every leaf.
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d). args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):
    # Replace any composite tensors with their TypeSpecs. This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)
    flattened = nest.flatten(arg_value, expand_composites=True)
    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not. For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          custom_gradient.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          # NOTE(review): this rebinds the outer loop variable `name`, so any
          # later leaves of the same `arg_value` see the spec's name instead of
          # the user-supplied argument name. Confirm this is intended.
          name = arg.name or name
          # Build a stand-in variable whose handle is a resource placeholder
          # created in the outer graph, so it can be captured below.
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        # NOTE(review): unlike the Tensor branch, this is not guarded by
        # `name is not None`; presumably callers always supply a name for
        # variable arguments -- TODO confirm `compat.as_bytes` is never given
        # None here.
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'. args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs.

  Keyword arguments are processed in sorted key order, and each key is used
  as the requested placeholder name for its value.

  Args:
    kwargs: A dict of keyword arguments (may be empty).
    flat_shapes: Optional flat list of shape overrides, forwarded to
      `_get_defun_inputs`.

  Returns:
    The mapped inputs, packed into the same structure as `kwargs`.
  """
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  """Replaces the name stack of `func_graph` with `name_scope`.

  Args:
    func_graph: The graph whose private `_name_stack` is overwritten.
    name_scope: The value assigned as the new name stack.
  """
  func_graph._name_stack = name_scope  # pylint: disable=protected-access