"""Manages a graph of Trackable objects.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import weakref from tensorflow.core.protobuf import trackable_object_graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training.saving import saveable_object as saveable_object_lib from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import tracking from tensorflow.python.util import object_identity _ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. # Keyword for identifying that the next bit of a checkpoint variable name is a # slot name. Checkpoint names for slot variables look like: # # /<_OPTIMIZER_SLOTS_NAME>// # # Where is a full path from the checkpoint root to the # variable being slotted for. _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" # Keyword for separating the path to an object from the name of an # attribute in checkpoint names. Used like: # /<_OBJECT_ATTRIBUTES_NAME>/ _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" def _escape_local_name(name): # We need to support slashes in local names for compatibility, since this # naming scheme is being patched in to things like Layer.add_variable where # slashes were previously accepted. We also want to use slashes to indicate # edges traversed to reach the variable, so we escape forward slashes in # names. return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) .replace(r"/", _ESCAPE_CHAR + "S")) def _object_prefix_from_path(path_to_root): return "/".join( (_escape_local_name(trackable.name) for trackable in path_to_root)) def _slot_variable_naming_for_optimizer(optimizer_path): """Make a function for naming slot variables in an optimizer.""" # Name slot variables: # # /<_OPTIMIZER_SLOTS_NAME>// # # where is exactly the checkpoint name used for the original # variable, including the path from the checkpoint root and the local name in # the object which owns it. Note that we only save slot variables if the # variable it's slotting for is also being saved. optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) def _name_slot_variable(variable_path, slot_name): """With an optimizer specified, name a slot variable.""" return (variable_path + optimizer_identifier + _escape_local_name(slot_name)) return _name_slot_variable def _serialize_slot_variables(trackable_objects, node_ids, object_names): """Gather and name slot variables.""" non_slot_objects = list(trackable_objects) slot_variables = object_identity.ObjectIdentityDictionary() for trackable in non_slot_objects: if (isinstance(trackable, optimizer_v1.Optimizer) # TODO(b/110718070): Fix Keras imports. # Note: dir() is used rather than hasattr() here to avoid triggering # custom __getattr__ code, see b/152031870 for context. or "_create_or_restore_slot_variable" in dir(trackable)): naming_scheme = _slot_variable_naming_for_optimizer( optimizer_path=object_names[trackable]) slot_names = trackable.get_slot_names() for slot_name in slot_names: for original_variable_node_id, original_variable in enumerate( non_slot_objects): try: slot_variable = trackable.get_slot( original_variable, slot_name) except (AttributeError, KeyError): slot_variable = None if slot_variable is None: continue slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access # TODO(allenl): Gather dependencies of slot variables. raise NotImplementedError( "Currently only variables with no dependencies can be saved as " "slot variables. File a feature request if this limitation " "bothers you.") if slot_variable in node_ids: raise NotImplementedError( "A slot variable was re-used as a dependency of a " "Trackable object. This is not currently allowed. File a " "feature request if this limitation bothers you.") checkpoint_name = naming_scheme( variable_path=object_names[original_variable], slot_name=slot_name) object_names[slot_variable] = checkpoint_name slot_variable_node_id = len(trackable_objects) node_ids[slot_variable] = slot_variable_node_id trackable_objects.append(slot_variable) slot_variable_proto = ( trackable_object_graph_pb2.TrackableObjectGraph .TrackableObject.SlotVariableReference( slot_name=slot_name, original_variable_node_id=original_variable_node_id, slot_variable_node_id=slot_variable_node_id)) slot_variables.setdefault(trackable, []).append( slot_variable_proto) return slot_variables class ObjectGraphView(object): """Gathers and serializes an object graph.""" def __init__(self, root, saveables_cache=None, attached_dependencies=None): """Configure the graph view. Args: root: A `Trackable` object whose variables (including the variables of dependencies, recursively) should be saved. May be a weak reference. saveables_cache: A dictionary mapping `Trackable` objects -> attribute names -> SaveableObjects, used to avoid re-creating SaveableObjects when graph building. attached_dependencies: Dependencies to attach to the root object. Used when saving a Checkpoint with a defined root object. """ self._root_ref = root self._saveables_cache = saveables_cache self._attached_dependencies = attached_dependencies def list_dependencies(self, obj): # pylint: disable=protected-access obj._maybe_initialize_trackable() dependencies = obj._checkpoint_dependencies # pylint: enable=protected-access if obj is self.root and self._attached_dependencies: dependencies = dependencies.copy() dependencies.extend(self._attached_dependencies) return dependencies @property def saveables_cache(self): """Maps Trackable objects -> attribute names -> list(SaveableObjects). Used to avoid re-creating SaveableObjects when graph building. None when executing eagerly. Returns: The cache (an object-identity dictionary), or None if caching is disabled. """ return self._saveables_cache @property def attached_dependencies(self): """Returns list of dependencies that should be saved in the checkpoint. These dependencies are not tracked by root, but are in the the checkpoint. This is defined when the user creates a Checkpoint with both root and kwargs set. Returns: A list of TrackableReferences. """ return self._attached_dependencies @property def root(self): if isinstance(self._root_ref, weakref.ref): derefed = self._root_ref() assert derefed is not None return derefed else: return self._root_ref def _breadth_first_traversal(self): """Find shortest paths to all dependencies of self.root.""" bfs_sorted = [] to_visit = collections.deque([self.root]) path_to_root = object_identity.ObjectIdentityDictionary() path_to_root[self.root] = () while to_visit: current_trackable = to_visit.popleft() if isinstance(current_trackable, tracking.NotTrackable): raise NotImplementedError( ("The object %s does not support object-based saving. File a " "feature request if this limitation bothers you. In the meantime, " "you can remove the dependency on this object and save everything " "else.") % (current_trackable,)) bfs_sorted.append(current_trackable) for name, dependency in self.list_dependencies(current_trackable): if dependency not in path_to_root: path_to_root[dependency] = ( path_to_root[current_trackable] + ( base.TrackableReference(name, dependency),)) to_visit.append(dependency) return bfs_sorted, path_to_root def _add_attributes_to_object_graph( self, trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures): """Create SaveableObjects and corresponding SerializedTensor protos.""" named_saveable_objects = [] if self._saveables_cache is None: # No SaveableObject caching. Either we're executing eagerly, or building a # static save which is specialized to the current Python state. feed_additions = None else: # If we are caching SaveableObjects, we need to build up a feed_dict with # functions computing volatile Python state to be saved with the # checkpoint. feed_additions = {} for checkpoint_id, (trackable, object_proto) in enumerate( zip(trackable_objects, object_graph_proto.nodes)): assert node_ids[trackable] == checkpoint_id object_name = object_names[trackable] if object_map is None: object_to_save = trackable else: object_to_save = object_map.get(trackable, trackable) if self._saveables_cache is not None: cached_attributes = self._saveables_cache.setdefault(object_to_save, {}) else: cached_attributes = None for name, saveable_factory in ( object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) if cached_attributes is None: saveables = None else: saveables = cached_attributes.get(name, None) if saveables is not None: for saveable in saveables: if attribute.checkpoint_key not in saveable.name: # The checkpoint key for this SaveableObject is different. We # need to re-create it. saveables = None del cached_attributes[name] break if saveables is None: if callable(saveable_factory): maybe_saveable = saveable_object_util.create_saveable_object( saveable_factory, attribute.checkpoint_key, call_with_mapped_captures) else: maybe_saveable = saveable_factory if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): saveables = (maybe_saveable,) else: # Figure out the name-based Saver's name for this variable. If it's # already a SaveableObject we'd just get the checkpoint key back, so # we leave full_name blank. saver_dict = saveable_object_util.op_list_to_dict( [maybe_saveable], convert_variable_to_tensor=False) full_name, = saver_dict.keys() saveables = tuple(saveable_object_util.saveable_objects_for_op( op=maybe_saveable, name=attribute.checkpoint_key)) for saveable in saveables: saveable.full_name = full_name for saveable in saveables: if attribute.checkpoint_key not in saveable.name: raise AssertionError( ("The object %s produced a SaveableObject with name '%s' for " "attribute '%s'. Expected a name containing '%s'.") % (trackable, name, saveable.name, attribute.checkpoint_key)) if cached_attributes is not None: cached_attributes[name] = saveables optional_restore = None for saveable in saveables: if optional_restore is None: optional_restore = saveable.optional_restore else: optional_restore = optional_restore and saveable.optional_restore if hasattr(saveable, "full_name"): attribute.full_name = saveable.full_name if isinstance(saveable, base.PythonStateSaveable): if feed_additions is None: assert self._saveables_cache is None # If we're not caching saveables, then we're either executing # eagerly or building a static save/restore (e.g. for a # SavedModel). In either case, we should embed the current Python # state in the graph rather than relying on a feed dict. saveable = saveable.freeze() else: saveable_feed_dict = saveable.feed_dict_additions() for new_feed_key in saveable_feed_dict.keys(): if new_feed_key in feed_additions: raise AssertionError( ("The object %s tried to feed a value for the Tensor %s " "when saving, but another object is already feeding a " "value.") % (trackable, new_feed_key)) feed_additions.update(saveable_feed_dict) named_saveable_objects.append(saveable) if optional_restore is None: optional_restore = False attribute.optional_restore = optional_restore return named_saveable_objects, feed_additions def _fill_object_graph_proto(self, trackable_objects, node_ids, slot_variables, object_graph_proto=None): """Name non-slot `Trackable`s and add them to `object_graph_proto`.""" if object_graph_proto is None: object_graph_proto = ( trackable_object_graph_pb2.TrackableObjectGraph()) for checkpoint_id, trackable in enumerate(trackable_objects): assert node_ids[trackable] == checkpoint_id object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(trackable, ())) for child in self.list_dependencies(trackable): child_proto = object_proto.children.add() child_proto.node_id = node_ids[child.ref] child_proto.local_name = child.name return object_graph_proto def _serialize_gathered_objects(self, trackable_objects, path_to_root, object_map=None, call_with_mapped_captures=None): """Create SaveableObjects and protos for gathered objects.""" object_names = object_identity.ObjectIdentityDictionary() for obj, path in path_to_root.items(): object_names[obj] = _object_prefix_from_path(path) node_ids = object_identity.ObjectIdentityDictionary() for node_id, node in enumerate(trackable_objects): node_ids[node] = node_id slot_variables = _serialize_slot_variables( trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) object_graph_proto = self._fill_object_graph_proto( trackable_objects=trackable_objects, node_ids=node_ids, slot_variables=slot_variables) named_saveable_objects, feed_additions = ( self._add_attributes_to_object_graph( trackable_objects=trackable_objects, object_graph_proto=object_graph_proto, node_ids=node_ids, object_names=object_names, object_map=object_map, call_with_mapped_captures=call_with_mapped_captures)) return named_saveable_objects, object_graph_proto, feed_additions def serialize_object_graph(self): """Determine checkpoint keys for variables and build a serialized graph. Non-slot variables are keyed based on a shortest path from the root saveable to the object which owns the variable (i.e. the one which called `Trackable._add_variable` to create it). Slot variables are keyed based on a shortest path to the variable being slotted for, a shortest path to their optimizer, and the slot name. Returns: A tuple of (named_variables, object_graph_proto, feed_additions): named_variables: A dictionary mapping names to variable objects. object_graph_proto: A TrackableObjectGraph protocol buffer containing the serialized object graph and variable references. feed_additions: A dictionary mapping from Tensors to values which should be fed when saving. Raises: ValueError: If there are invalid characters in an optimizer's slot names. """ trackable_objects, path_to_root = self._breadth_first_traversal() return self._serialize_gathered_objects( trackable_objects, path_to_root) def frozen_saveable_objects(self, object_map=None, to_graph=None, call_with_mapped_captures=None): """Creates SaveableObjects with the current object graph frozen.""" trackable_objects, path_to_root = self._breadth_first_traversal() if to_graph: target_context = to_graph.as_default else: target_context = ops.NullContextmanager with target_context(): named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects( trackable_objects, path_to_root, object_map, call_with_mapped_captures) with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) named_saveable_objects.append( base.NoRestoreSaveable( tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) return named_saveable_objects def objects_ids_and_slot_variables(self): """Traverse the object graph and list all accessible objects. Looks for `Trackable` objects which are dependencies of `root_trackable`. Includes slot variables only if the variable they are slotting for and the optimizer are dependencies of `root_trackable` (i.e. if they would be saved with a checkpoint). Returns: A tuple of (trackable objects, object -> node id, slot variables) """ trackable_objects, path_to_root = self._breadth_first_traversal() object_names = object_identity.ObjectIdentityDictionary() for obj, path in path_to_root.items(): object_names[obj] = _object_prefix_from_path(path) node_ids = object_identity.ObjectIdentityDictionary() for node_id, node in enumerate(trackable_objects): node_ids[node] = node_id slot_variables = _serialize_slot_variables( trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) return trackable_objects, node_ids, slot_variables def list_objects(self): """Traverse the object graph and list all accessible objects.""" trackable_objects, _, _ = self.objects_ids_and_slot_variables() return trackable_objects