# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================== """A utility to trace tensor values on TPU.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import operator import os import os.path import sys import numpy as np import six from tensorflow.core.framework import summary_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import summary_ops_v2 as summary from tensorflow.python.ops import variable_scope from tensorflow.python.platform import analytics from tensorflow.python.platform import gfile from tensorflow.python.platform import remote_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary_iterator from tensorflow.python.tpu import tensor_tracer_flags from tensorflow.python.tpu import tensor_tracer_report from tensorflow.python.tpu import tpu from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.training import training_util _DEVICE_TYPE_TPU = 'tpu' _DEVICE_TYPE_CPU = 'cpu' _TRACE_MODE_PART_TENSOR_SIZE = 3 _REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' _REASON_UNSAFE_OP = 'not-traced-unsafe-op' _REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op' _REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op' _REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow' _REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' _REASON_SKIP_SCALAR = 'not-traced-scalar' _REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' _REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' _REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' _REASON_SCALAR_GET_TRACED = 'traced-scalar' _REASON_TENSOR_GET_TRACED = 'traced-tensor' _REASON_USER_INCLUDED = 'traced-user-included' _REASON_USER_EXCLUDED = 'not-traced-user-excluded' _REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' _REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op' _OUTPUT_STREAM_ESCAPE = 'file://' _TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' _TRACE_FILE_NAME = 'trace.all' _COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' 
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE

_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'
_TT_SUMMARY_MAX_QUEUE = 10


def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This sets up the parameters for tensor tracer. A call to tensor tracer as
  below is necessary to enable debugging on CPUs and GPUs. On TPUs the call
  below can be skipped, as it is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Below gives
      examples of these parameters: See tensor_tracer_report.py for all
      parameters.
        - enable: If set, tensor tracer will be enabled. Calling
          enable_tensor_tracer automatically adds this parameter.
        - trace_mode: The trace_mode to be used by tensor tracer. These
          include:
          - summary: Collects multiple statistics for traced tensors, and
            writes them to a summary file that can be visualized using
            tensorboard. This mode currently only works for TPUEstimator. It
            can also be used for other models, but outfeed must be handled by
            the user.
          - norm: Collects the norm of each traced tensor and writes it into a
            text file pointed to by the 'trace_dir' flag. (Default mode).
          - nan-inf: Checks the existence of NaNs and Infs in the tensor, and
            writes a boolean value to a text file pointed to by the
            'trace_dir' flag. Note that 'norm' mode can also capture this
            information with more numerical info.
          - max-abs: Collects the absolute max for each traced tensor and
            writes it into a text file pointed to by the 'trace_dir' flag.
          - full-tensor: Writes the full tensor content of the traced tensors
            into a text file pointed to by the 'trace_dir' flag.
          - part-tensor: Writes a part of the tensor content of the traced
            tensors into a text file pointed to by the 'trace_dir' flag.
          - full_tensor_summary: Writes the full tensors as binary event
            files. The outputs can be read using:
            trace =
              tensor_tracer.read_tensor_tracer_event_file(event_file_path)
        - report_file: Path to the metadata file that is written during graph
          construction. If not set, metadata will be printed to stdout during
          graph construction.
        - trace_dir: Path where the execution traces will be written during
          the graph execution. If not set, trace will be printed to stderr.
        - trace_level: Tensor tracer aims to trace everything it can. This
          introduces some overhead on graph execution and graph compilation
          times. Using the trace_level parameter, it is possible to trace
          operations based on their priorities. For example,
          - trace_level=7 is the highest trace_level, in which every op is
            traced.
          - trace_level=6 will skip constant operations such as tf.constant.
          - trace_level=5 will skip less important ops such as tf.identities.
          - The default trace_level=3, which will skip concat ops, or random
            number generators.
          - To reduce the graph compile time overhead, trace_level can be set
            to 0, which will skip additions, subtractions, and multiplications
            as well.
        - excluded_opnames: If set, any matching op name will not be traced.
          excluded_opnames can be set as a regular expression. E.g,
          excluded_opnames=.* will exclude everything.
        - excluded_optypes: If set, any matching op type will not be traced.
          excluded_optypes can be set as a regular expression. E.g,
          excluded_optypes=.* will exclude everything.
          excluded_optypes=MatMul will exclude all MatMul ops from tracing.
        - included_opnames: If set, any matching op name will be forced to be
          traced. included_opnames can be set as a regular expression. E.g,
          '--included_opnames=some_op --excluded_opname=.*' will only trace
          some_op.
        - included_optypes: If set, any matching op type will be forced to be
          traced. included_optypes can be set as a regular expression. E.g,
          '--included_optypes=some_op_type --excluded_optypes=.*' will trace
          only the ops with type 'some_op_type'
      Advanced Flags:
        - trace_scalar: Scalar values are not traced by default. If this flag
          is set, scalar values will also be traced.
        - op_range: In the form of '%d:%d' that limits the tracing to the ops
          within this limit. --op_range='5:10' will trace only the ops that
          have topological order between 5-10.
        - submode: 'brief' or 'detailed'. If the trace mode is not compact,
          brief mode will print only the id of each traced tensor to save
          some space. 'detailed' mode prints the full tensor name.
        - use_fingerprint_subdirectory: The trace directory will be chosen
          using the fingerprint of the trace metadata under the provided
          trace_dir.
  """
  flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE
  if tensor_tracer_params:
    for key, value in tensor_tracer_params.items():
      flags += ' --%s=%s' % (key, value)
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = flags
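
# The sketch below is an editor's illustration, not part of the original
# module: a minimal end-to-end use of set_parameters followed by trace_cpu,
# as described in the docstring above. The directory paths and the `loss`
# tensor are hypothetical placeholders.
def _example_trace_loss_on_cpu(loss):
  """Illustrative only: enables tracing and wraps a CPU loss tensor."""
  set_parameters({'trace_dir': '/tmp/trace_dir',
                  'trace_mode': 'norm',
                  'report_file': '/tmp/trace_dir/report.all'})
  tt = TensorTracer()
  # trace_cpu returns a copy of `loss` that carries the tracing dependencies.
  return tt.trace_cpu(ops.get_default_graph(), tensor_fetches=loss)
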
def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding to the priority of the op.
  """
  if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
                 'VariableShape', 'Fill', 'OneHot', 'ShapeN'):
    # Lowest priority ops, e.g., constant ops across different steps.
    # They will be traced only if trace_level>=7
    return 7

  if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
                 'PreventGradient', 'Squeeze'):
    # Operations without numerical effects.
    # They will be traced only if trace_level>=6
    return 6
  if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
                 'CollectivePermute', 'SplitV'):
    # Operations that merge or slice an input, will be traced if
    # trace_level>=5
    return 5
  if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'):
    # Operations less likely to provide useful information,
    # will be traced if trace_level>=4
    return 4
  if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'):
    # Add operations that are less likely to create any issues, will be
    # traced if trace_level>=3 (default=3)
    return 3
  if op_type in ('Neg', 'Sub'):
    # Sub operations that are less likely to create any issues, will be
    # traced if trace_level>=2
    return 2
  if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
                 'Maximum', 'Mean', 'Variance'):
    # Multiplication and some other operations, will be traced if
    # trace_level>=1
    return 1
  return 0


def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict = tensor_tracer.read_tensor_tracer_event_file(event_file_path)
    for step, tensor_dict in result_dict.items():
      for tensor_name, full_tensor_content in tensor_dict.items():
        logging.info(tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer
      events.
  Returns:
    An event dictionary in the form of
    {step_number: {tensor_name: tensor_content}}
  Raises:
    ValueError: If an unexpected trace is found.
  """
  event_dict = {}
  for trace_event in summary_iterator.summary_iterator(event_file):
    # First event is an event with file_version: "brain.Event:2"
    if not trace_event.HasField('summary'):
      continue
    step = trace_event.step
    if step not in event_dict:
      event_dict[step] = {}

    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag

    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content,
        dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()
        ).reshape(real_shape)
    event_dict[step][tensor_name] = tensor_content

  return event_dict


def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This
  function can be used to limit traced tensors. If this function is called
  for a subset of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)
  Args:
     tensor: the tensor object for which the tracing is requested.
     tracepoint_name: an optional tensor tracepoint name string. A tracepoint
       name is a Tensor Tracer internal name for the tensor. It is useful when
       comparing equivalent traces from different models that have different
       tensor namings. Equivalent tensors (with different names) can be mapped
       to each other by assigning a common tracepoint_name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION)
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor
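
# Editor's sketch, not part of the original API: a runnable variant of the
# trace_tensor example from the docstring above, using the graph-mode ops
# already imported by this module. The tensors `a` and `b` are hypothetical
# inputs.
def _example_trace_single_tensor(a, b):
  """Illustrative only: requests tracing for just the matmul result."""
  c = math_ops.matmul(a, b)
  trace_tensor(c)  # Only `c` is added to the trace collection.
  return math_ops.add(c, 1)
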
""" if tracepoint_name is None: tracepoint_name = tensor.name tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, (tensor, tracepoint_name)) return tensor def keras_layer_tracepoint(layer, checkpoint_name): """An interface for adding the tensor outputs of a keras layer. Encapsulates trace_tensor. Args: layer: A keras layer. checkpoint_name: a string name for the checkpoint. This name has to be a unique name if used within model comparison. The tensors that have the same checkpoint identifier is compared in model comparison. Returns: The provided layer. """ try: outputs = layer.output if tensor_util.is_tensor(outputs): trace_tensor(outputs, '%s' % (checkpoint_name)) else: idx = 0 for output_tensor in outputs: if tensor_util.is_tensor(outputs): trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx)) idx += 1 except AttributeError: pass except RuntimeError: pass return layer class TensorTracer(object): """A software construct for tracing tensor values in a TF graph. This utility is disabled by default. It is hooked into tpu.rewrite, so it can easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as below without a code change. export TENSOR_TRACER_FLAGS="--enable=1" Below is the use example to enable it on CPUs or GPUs, or for more advance use cases on TPUs. a = x + 1 b = a * 2 rs = tf.reduce_sum(b) tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir', 'report_file: 'path/to/report/file'}) tt = tensor_tracer.TensorTracer() if on_tpu: rs = tt.trace_tpu(tf.get_default_graph(), tensor_fetches=rs) else: rs = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=rs) session.run(rs) If it is enabled, it will trace the output tensor values of selected Ops in the graph. It has two outputs: (1) the traces and (2) a report. The traces are dumped to a specified directory during the graph execution, while the report is dumped during the graph construction. By passing options via the env variable, users can change: (1) the trace mode (e.g., detecting NaN/Inf, printing partial or full tensor values) (2) which Ops to be traced (via op.name or op.type) (3) output trace file path. """ # The set of graphs that are rewritten by tensor tracer. _traced_graphs = set() @staticmethod def is_enabled(): """Returns True if TensorTracer is enabled.""" return tensor_tracer_flags.TTParameters().is_enabled() @staticmethod def check_device_type(device_type): """Checks if the given device type is valid.""" if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU): raise ValueError('Invalid device_type "%s"'%device_type) @staticmethod def check_trace_mode(device_type, trace_mode): """Checks if the given trace mode work on the given device type. Args: device_type: Device type, TPU, GPU, CPU. trace_mode: Tensor tracer trace mode. Raises: ValueError: If the given trace mode is not supported for the device. """ if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY: if device_type != _DEVICE_TYPE_TPU: raise ValueError('Device_type "%s" is not yet supported for ' 'trace mode "%s"' % (device_type, trace_mode)) @staticmethod def loop_cond_op(op): return op.type in ('LoopCond', 'RefLoopCond') @staticmethod def while_loop_op(op): """Returns true if op is one of the special ops of in a while loop. Args: op: A tf.Operation. Returns: True if the given op is one of [Switch, Merge, Enter, Exit, NextIteration, LoopCond], which are all building blocks for TF while loops. 
""" return (control_flow_util.IsLoopSwitch(op) or control_flow_util.IsLoopMerge(op) or control_flow_util.IsLoopEnter(op) or control_flow_util.IsLoopExit(op) or TensorTracer.loop_cond_op(op) or op.type in ('RefNextIteration', 'NextIteration')) @staticmethod def control_flow_op(op): """Returns true if op is one of the special ops of in a while loop. Args: op: A tf.Operation. Returns: True if the given op is one of [Switch, Merge, Enter, Exit, NextIteration, LoopCond], which are all building blocks for TF while loops. """ return (control_flow_util.IsSwitch(op) or control_flow_util.IsMerge(op)) @staticmethod def unsafe_op(op): """Returns True if this op is not safe to be traced.""" # Reasons for not including following op types: # Assign: cause incorrect result with CPU tracing. if op.type == 'Assign': return True return False @staticmethod def device_mismatch(device_type, op): if device_type == _DEVICE_TYPE_TPU: # pylint: disable=protected-access return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr # pylint: enable=protected-access return False @staticmethod def unsafe_scalar_trace(op): """Return true if scalar output tensor from Op is not safe to be traced.""" # Tracing the following causes cycle in the graph on TPU. if op.type in ('LoopCond', 'Enter', 'Merge', 'Const', 'Switch', 'Less', 'ReadVariableOp'): return True # Tracing the following will cause casting-issue # with the norm tracing mode or other compilation issues on CPU. if op.type in ('VarHandleOp', 'IteratorToStringHandle', 'IteratorGetNext', 'OneShotIterator', 'IteratorV2', 'MakeIterator', 'BatchDatasetV2', 'MapDataset', 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'): return True return False def _is_interesting_op(self, op): """Returns True if the given op is not an interesting one to be traced.""" return op_priority(op.type) <= self._parameters.trace_level @staticmethod def reason(op_idx, details): """Returns reason why the Op at op_idx is traced or not.""" return '%d %s'%(op_idx, details) def __init__(self): """Initializes a TensorTracer. Sets the various member fields from the flags (if given) or the defaults. """ self._replica_id = None self._tt_config = tensor_tracer_report.TensorTracerConfig() self._parameters = None self._host_call_fn = {} self._cache_variables = {} self._traced_op_names = set() self._report_proto = None self._temp_cache_var = [] self._report_proto_path = '' self._outmost_context = None def report_proto(self): """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. Returns: A tensor_tracer.proto object. Raises: ValueError if called before tracing happens, or when trace mode is not summary or full_tensor_summary. """ if self._report_proto: return self._report_proto else: raise ValueError('Call to report_proto must be done after tracing.' 'Report proto only exists for ' 'trace_mode=[summary|full_tensor_summary]') def report_proto_path(self): """Getter for path where tensor_tracer.proto object should be written. Returns: A string path. """ return self._report_proto_path def _get_all_cache_variables(self): return self._cache_variables def _create_or_get_tensor_values_cache(self, cache_name, graph=None, shape=None, dtype=dtypes.float32): """Creates a variable as the cache to store intermediate tensor values. Args: cache_name: Name to be given to the cache (an instance of tf.variable). graph: Tensorflow graph. shape: A list of dimensions. dtype: Data type of created cache. 
Returns: A ref to newly created or existing cache with the given dimensions. Raises: ValueError: If missing a parameter to create the cache. """ def _escape_namescopes(variable_name): # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor" # and "foo_bar/mytensor". return variable_name.replace('/', '_').replace(':', '_') if cache_name not in self._cache_variables: if graph is None: raise ValueError('Graph must be provided at cache creation.') if shape is None: raise ValueError('shape must be provided at cache creation.') graph = graph or ops.get_default_graph() if dtype.is_integer: init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) else: init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE # Create in proper graph and base name_scope. with graph.as_default() as g, g.name_scope(None): self._cache_variables[cache_name] = variable_scope.get_variable( _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name), shape=shape, dtype=dtype, initializer=init_ops.constant_initializer(init_val), trainable=False, use_resource=True, collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES]) return self._cache_variables[cache_name] def _add_replica_id_to_graph(self): """Adds nodes for computing the replica ID to the graph.""" if self._tt_config.num_replicas: with ops.control_dependencies(None): # Uses None as dependency to run outside of TPU graph rewrites. self._replica_id = tpu_ops.tpu_replicated_input( list(range(self._tt_config.num_replicas)), name='tt_replica_id') else: self._replica_id = 'unknown' def _inside_op_range(self, idx): """Return True if the given index is inside the selected range.""" if idx < self._parameters.op_range[0]: return False return (self._parameters.op_range[1] < 0 or idx <= self._parameters.op_range[1]) def _is_user_included_op(self, op): """Checks whether the op is included in the tensor tracer flags. Args: op: tf Operation Returns: True, if the op is included. An op is included if: - Its op name is given in included_opnames - Its op type is given in included_optypes - The op is at most _trace_ops_before_included hops before an included op - The op is at most _trace_ops_after_included hops after an included op """ for opname_re in self._parameters.included_opname_re_list: if opname_re.match(op.name): return True for optype_re in self._parameters.included_optype_re_list: if optype_re.match(op.type): return True return False def _is_user_excluded_op(self, op): for opname_re in self._parameters.excluded_opname_re_list: if opname_re.match(op.name): return True for optype_re in self._parameters.excluded_optype_re_list: if optype_re.match(op.type): return True return False def _signature_types(self): """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.""" if self._parameters.trace_mode in set([ tensor_tracer_flags.TRACE_MODE_NAN_INF, tensor_tracer_flags.TRACE_MODE_NORM, tensor_tracer_flags.TRACE_MODE_MAX_ABS]): return {self._parameters.trace_mode: 0} if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: return self._parameters.summary_signatures return {} def _num_signature_dimensions(self): return len(self._signature_types()) def _use_temp_cache(self): """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable. Returns: A boolean, denoting whether to use a temporary cache or not. """ # If full tensors need to be stored tf.variables, then do not use temp # variables to store them. 
    if self._use_tensor_buffer():
      return False
    if self._use_tensor_values_cache():
      return self._parameters.use_temp_cache_var
    else:
      # Temporary caches only replace tf.Variable caches. If no cache is
      # used, return False.
      return False

  def _use_tensor_values_cache(self):
    """Returns True if intermediate tensors should first be saved to a cache."""
    return self._parameters.use_compact_trace

  def _use_tensor_buffer(self):
    """Returns true if the whole tensor needs to be cached/buffered in memory."""
    return (self._parameters.trace_mode ==
            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)

  def _merge_tensor_signatures(self, signatures):
    """Returns a tensor that merges the given signatures.

    Args:
      signatures: A dictionary of the signature updates from signature name to
        a tensor of dimension [1].
    Returns:
      A tensor that concats the signature values in a predefined order.
    """
    sorted_update = []
    if self._num_signature_dimensions() > 1:
      signature_indices = self._signature_types()
      for _, val in sorted(signatures.items(),
                           key=lambda item: signature_indices[item[0]]):
        sorted_update.append(val)
      updates = array_ops.stack(
          sorted_update, axis=0, name='merge_single_op_signatures')
    elif self._num_signature_dimensions() == 1:
      # Avoid stack operation if there is only a single signature.
      (_, val), = signatures.items()
      updates = val
    else:
      raise ValueError('Cannot merge 0 signatures.')
    return updates

  def _save_tensor_value_to_tmp_cache(self, cache_idx, updates):
    """Saves the given updates to the temporary cache entry at cache_idx.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates from signature name to
        a tensor of dimension [1].
    """
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [self._num_signature_dimensions()])
    self._temp_cache_var[cache_idx] = updates

  def _save_tensor_value_to_cache_op(self, cache_idx, updates):
    """Returns an op that will save the given updates to an entry in the cache.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates.
    Returns:
      Cache update operation.
    """
    # state_ops.scatter_update allows updates only along the first dimension.
    # Make a compact array by concatenating different signatures, and update
    # them all together.
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [1, self._num_signature_dimensions()])
    indices = constant_op.constant([cache_idx])
    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
    return state_ops.scatter_update(cache, indices, updates).op

  def _snapshot_tensor(self, tensor):
    """Creates a new tf.Variable and an assignment op that stores the
    tensor's value in that variable.

    Args:
      tensor: tensor whose values will be stored in a new tf.Variable.
    Returns:
      An assignment operation.
    """
    snapshot_variable = self._create_or_get_tensor_values_cache(
        tensor.name, tensor.op.graph,
        tensor.shape.as_list(), tensor.dtype)
    return state_ops.assign(snapshot_variable, tensor).op
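
  # Editor's note summarizing the cache layout used by the two save methods
  # above (illustrative, derived from the code): the compact cache is a
  # float32 matrix of shape [num_traced_tensors, num_signatures]; row
  # `cache_idx` holds the merged signature vector for one traced tensor, and
  # every entry starts out as _COMPACT_TRACE_ENTRY_INIT_VALUE until the
  # tracing ops overwrite it.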
""" def _detect_nan_inf(tensor): """Trace function for detecting any NaN/Inf in the tensor.""" if tensor.dtype.is_floating: mask = math_ops.reduce_any( gen_math_ops.logical_or( gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) output_tensor = control_flow_ops.cond( mask, lambda: constant_op.constant([1.0]), lambda: constant_op.constant([0.0])) else: output_tensor = constant_op.constant([0.0]) return output_tensor def _compute_signature(tensor, tf_op, cast_to_f32=True): if cast_to_f32: tensor = math_ops.cast(tensor, dtypes.float32) output_tensor = tf_op(tensor) # Return type should be scalar. Set it if it does not have the # information. if not output_tensor.get_shape().is_fully_defined(): output_tensor = array_ops.reshape(output_tensor, []) return output_tensor def _show_size(tensor): # In order to check the size of a tensor. # Not all sizes are known at the compile time, also, different replicas # sometimes get different sizes of tensors. # Collect it here to be used in merging replica data. tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False) # Cast to float32, so that it can be placed into same cache with other # signatures. return math_ops.cast(tsize, dtypes.float32) def _show_max(tensor, cast_to_f32=True): # returns -inf for empty tensor return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32) def _show_min(tensor, cast_to_f32=True): # returns inf for empty tensor return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32) def _show_norm(tensor, cast_to_f32=True): # returns 0 for empty tensor return _compute_signature(tensor, linalg_ops.norm, cast_to_f32) def _show_mean_and_variance(tensor, cast_to_f32=True): """Returns the mean and variance of the given tensor.""" if cast_to_f32: tensor = math_ops.cast(tensor, dtypes.float32) # returns nan for empty tensor mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0]) # The shape has to be 1. Set it if it does not have the information. if not mean.get_shape().is_fully_defined(): mean = array_ops.reshape(mean, []) if not var.get_shape().is_fully_defined(): var = array_ops.reshape(var, []) return mean, var def _show_max_abs(tensor): tensor = math_ops.cast(tensor, dtypes.float32) output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) zero = constant_op.constant(0, dtypes.float32) output_tensor = gen_math_ops.maximum(zero, output_tensor) # The shape has to be 1. Set it if it does not have the information. output_tensor = array_ops.reshape(output_tensor, [1]) return output_tensor if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: return {self._parameters.trace_mode: _detect_nan_inf(tensor)} if (self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_PART_TENSOR): return {self._parameters.trace_mode: tensor} if (self._parameters.trace_mode in ( tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)): return {self._parameters.trace_mode: tensor} if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM: return {self._parameters.trace_mode: array_ops.reshape( _show_norm(tensor), [1])} if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS: return {self._parameters.trace_mode: _show_max_abs(tensor)} if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: tensor = math_ops.cast(tensor, dtypes.float32) result_dict = {} # Call mean and variance computation here to avoid adding the same nodes # twice. 
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return {self._parameters.trace_mode: tensor}
    if (self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
      return {self._parameters.trace_mode: tensor}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
      return {self._parameters.trace_mode: _show_max_abs(tensor)}

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      tensor = math_ops.cast(tensor, dtypes.float32)
      result_dict = {}
      # Call mean and variance computation here to avoid adding the same
      # nodes twice.
      if (_TT_SUMMARY_MEAN in self._signature_types() or
          _TT_SUMMARY_VAR in self._signature_types()):
        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)
      for signature_name, _ in sorted(self._signature_types().items(),
                                      key=lambda x: x[1]):
        if signature_name == _TT_SUMMARY_NORM:
          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX:
          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MIN:
          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_SIZE:
          signature_result_tensor = _show_size(tensor)
        elif signature_name == _TT_SUMMARY_MEAN:
          signature_result_tensor = mean
        elif signature_name == _TT_SUMMARY_VAR:
          signature_result_tensor = variance
        else:
          raise ValueError('Unknown signature type: %s.' % signature_name)

        result_dict[signature_name] = signature_result_tensor
      return result_dict

    raise RuntimeError(
        'Tensor trace fun for %s is not yet implemented'
        % self._parameters.trace_mode)

  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
    """Makes the tensor tracing function called by outside compilation.

    Args:
      tensor_name: name of the tensor being traced.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id
        map.
    Returns:
      A function to be passed as the first argument to outside compilation.

    Raises:
      RuntimeError: If the trace mode is invalid.
    """

    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
      """Prints a tensor value to a file.

      Args:
        tensor_name: name of the tensor being traced.
        num_elements: number of elements to print (-1 means print all).
        tensor: the tensor that needs to be returned.
        output_tensor: the tensor that needs to be printed.

      Returns:
        The same tensor passed via the "tensor" argument.

      Raises:
        ValueError: If tensor_name is not already in
          tensor_trace_order.tensorname_to_cache_idx.
      """
      if self._parameters.is_brief_mode():
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          raise ValueError(
              'Tensor name %s is not in the tensorname_to_cache_idx' %
              tensor_name)
        msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
      else:
        msg = '"%s"' % tensor_name

      if self._parameters.trace_dir:
        output_path = os.path.join(
            self._parameters.trace_dir,
            _TRACE_FILE_NAME + self._get_outfile_suffix())
        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
      else:
        output_stream = sys.stderr
      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                  '@', self._replica_id,
                                  '\n', output_tensor, '\n',
                                  summarize=num_elements,
                                  output_stream=output_stream)

    def _show_part_tensor(tensor):
      """Trace function for printing part of the tensor."""
      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                           tensor, tensor)

    def _show_full_tensor(tensor):
      """Trace function for printing the entire tensor."""
      return _print_tensor(tensor_name, -1, tensor, tensor)

    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return _show_part_tensor
    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
    # performed within TPUs and only their results are transferred to CPU.
    # Simply print the full tensor for these trace modes.
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY
        ):
      return _show_full_tensor

    raise RuntimeError('Tensor trace fun for %s is not yet implemented'
                       % self._parameters.trace_mode)

  def _is_in_control_flow(self, op):
    """Returns true if the given op is inside a tf.cond.

    Args:
      op: A tensorflow op that should be checked whether in control flow or
        not.
    Returns:
      A boolean value whether the op is in control flow or not.
    """
    return control_flow_util.IsInCond(op)

  def _is_in_outmost_while_loop(self, op):
    """Returns true if the op is at the same level with the training loop.

    Returns false if the op is in an inner while loop or if it is outside of
    the training loop.

    Args:
      op: tf.Operation

    Returns:
      A boolean.
    """
    ctxt = self._get_op_control_flow_context(op)
    outer_while_context = control_flow_util.GetContainingWhileContext(ctxt)
    return outer_while_context == control_flow_util.GetContainingWhileContext(
        self._outmost_context)

  def _should_trace_in_control_flow(self):
    """Returns false in case it is not safe to trace ops inside tf.cond or
    tf.while_loop."""
    # Unlike the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY forces the
    # execution of the traced tensors. We should not trace the ops that may
    # not be executed due to control flow.
    if self._use_temp_cache():
      return False
    elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
      # On TPUs do not trace in control flow unless we use caches to store
      # intermediate values, as calling outside compilation within an inner
      # loop causes errors.
      return self._use_tensor_values_cache() or self._use_tensor_buffer()
    return True

  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
    """Returns True if we should not trace Op.

    Args:
      op_id: Topological index of the op.
      op: tf.Operation
      ops_in_exec_path: Set of operations that are in the execution path.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the op should not be traced, false otherwise.
    """
    if TensorTracer.while_loop_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
      return True
    if TensorTracer.control_flow_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
      return True
    if TensorTracer.unsafe_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
      return True
    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
      return True
    if op not in ops_in_exec_path:
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
      return True
    # TensorTracer will not trace the operations that are in an inner while
    # loop or tf.cond when a temporary cache is used. Temporary cache adds
    # direct data dependencies to traced operations, and needs a static number
    # of traced operations. For these cases:
    # - We do not know the number of slots required when there are inner
    #   while loops. TensorTracer can only trace the result of a while loop.
    # - We do not know ahead of time which branch of the tf.cond
    #   will be taken, so we avoid introducing data dependencies for the
    #   operations inside a tf.cond.
    # - We also cannot have a data dependency to an operation in a different
    #   while context.
if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op): if not self._should_trace_in_control_flow(): report_handler.instrument_op( op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW)) return True if self._is_user_included_op(op): report_handler.instrument_op( op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) return False if not self._inside_op_range(op_id): report_handler.instrument_op( op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE)) return True if not self._is_interesting_op(op): report_handler.instrument_op( op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP)) return True if self._is_user_excluded_op(op): report_handler.instrument_op( op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) return True return False def _skip_tensor(self, op_id, out_tensor, report_handler): """Returns True if we should not trace out_tensor. Args: op_id: Topological index of the op producing tensor. out_tensor: tf.Tensor report_handler: An instance of tensor_tracer_report.TTReportHandle. Returns: True if the tensor should not be traced, false otherwise. """ # Skips a tensor if the tensor has a non-numeric type. # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) # because it also excludes tensors with dtypes, bool, and # float32_ref, which we actually want to trace. non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, dtypes.string]) if out_tensor.dtype in non_numeric_tensor_types: report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR)) return True # Skip a tensor if it feeds a special while loop op. if [consumer for consumer in out_tensor.consumers() if TensorTracer.while_loop_op(consumer)]: report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP)) return True if self._is_user_included_op(out_tensor.op): report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) return False if self._is_user_excluded_op(out_tensor.op): report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) return True if not out_tensor.get_shape().is_fully_defined(): # If trace mode is nan-inf, norm or max, then the tensor will be reduced # to a scalar before the outside compilation call. if self._parameters.trace_mode in ( tensor_tracer_flags.TRACE_MODE_NAN_INF, tensor_tracer_flags.TRACE_MODE_NORM, tensor_tracer_flags.TRACE_MODE_MAX_ABS, tensor_tracer_flags.TRACE_MODE_SUMMARY ): report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) return False else: report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE)) return True rank = len(out_tensor.shape) if rank < 1: # scalar if self._parameters.trace_scalar_ops: if TensorTracer.unsafe_scalar_trace(out_tensor.op): report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR)) return True else: report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED)) return False else: report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR)) return True else: # tensor report_handler.instrument_tensor( out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) return False def _filter_execution_path_operations(self, operations, fetches): """Returns the set of ops in the execution path to compute given fetches.""" # If no fetch provided, then return all operations. 
    if fetches is None:
      return set(operations)
    # Convert to list, if a single element is provided.
    if not isinstance(fetches, (list, tuple)):
      fetches = [fetches]
    # If a tensor is given as fetch, convert it to op.
    op_fetches = []
    for fetch in fetches:
      if isinstance(fetch, ops.Operation):
        op_fetches.append(fetch)
      elif isinstance(fetch, ops.Tensor):
        op_fetches.append(fetch.op)
      else:
        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
                           % fetch)

    execution_path_operations = set(op_fetches)
    traverse_stack = list(op_fetches)
    while traverse_stack:
      head_op = traverse_stack.pop()
      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
      input_ops.extend(head_op.control_inputs)

      for input_op in input_ops:
        if input_op not in execution_path_operations:
          # Filter out loop condition operations, tracing them causes a
          # cycle. Trace only the loop-body.
          if TensorTracer.loop_cond_op(input_op):
            continue
          execution_path_operations.add(input_op)
          traverse_stack.append(input_op)
    return execution_path_operations

  def _determine_and_instrument_traced_tensors(self, graph_order,
                                               ops_in_exec_path,
                                               tensor_trace_points,
                                               report_handler):
    """Determines the tensors to trace and instruments the trace details.

    Args:
      graph_order: graph_order tuple containing graph (tf.graph), operations
        (list of operations), op_to_idx (op id mapping), (tensors) list of
        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
        there is a cycle in the graph), topological_order_or_cycle (list of
        ops in topological order or list of ops creating a cycle).
      ops_in_exec_path: Set of ops in the execution path.
      tensor_trace_points: Collection of programmatic tensor trace points.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      List of tensors to be traced.
    """

    traced_tensors = []
    checkpoint_operations = set([tensor.op
                                 for (tensor, _) in tensor_trace_points])
    for op_id, op in enumerate(graph_order.operations):
      if checkpoint_operations and op not in checkpoint_operations:
        continue
      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
        continue
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        if not self._skip_tensor(op_id, out_tensor, report_handler):
          traced_tensors.append(out_tensor)
    return traced_tensors

  def _check_trace_files(self):
    """Checks if any requirements for trace files are satisfied."""

    if not self._parameters.trace_dir:
      # traces will be written to stderr. No need to check trace files.
      return
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      # Output files are handled by tf.summary operations, no need to
      # precreate them.
      return
    if not gfile.Exists(self._parameters.trace_dir):
      file_io.recursive_create_dir(self._parameters.trace_dir)
      if not gfile.Exists(self._parameters.trace_dir):
        raise RuntimeError('Failed to create %s' % self._parameters.trace_dir)

  def _create_temp_cache(self, num_traced_tensors, num_signatures):
    """Creates a temporary cache with the given dimensions.

    Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops
    that have shape of [num_signatures].
    Args:
      num_traced_tensors: Int, denoting total number of traced tensors.
      num_signatures: Int, denoting the number of statistics collected per
        tensor.
    """
    init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                      dtype=dtypes.float32,
                                      shape=[num_signatures])
    self._temp_cache_var = [init_value for _ in range(num_traced_tensors)]

  def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
    """Performs the work needed prior to TPU or CPU tracing.
Args: graph: tf.graph ops_in_exec_path: Set of operations in the execution path. Returns: An instance of tensor_tracer_report.TensorTraceOrder, containing list of tensors to be traced with their topological order information. """ self._check_trace_files() graph_order = tensor_tracer_report.sort_tensors_and_ops(graph) tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION) report_handler = tensor_tracer_report.TTReportHandle() traced_tensors = self._determine_and_instrument_traced_tensors( graph_order, ops_in_exec_path, tensor_trace_points, report_handler) logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors)) tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order, traced_tensors) num_signatures = self._num_signature_dimensions() # Create a cache variable if compact_tracing is used. if num_signatures and self._use_tensor_values_cache(): if self._use_temp_cache(): self._create_temp_cache(len(traced_tensors), num_signatures) else: self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures]) if self._parameters.trace_mode in ( tensor_tracer_flags.TRACE_MODE_SUMMARY, tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY): self._report_proto = report_handler.create_report_proto( self._tt_config, self._parameters, tensor_trace_order, tensor_trace_points, self._signature_types()) if self._parameters.use_fingerprint_subdir: self._parameters.trace_dir = os.path.join( self._parameters.trace_dir, self._report_proto.fingerprint) logging.info('TensorTracer updating trace_dir to %s', self._parameters.trace_dir) self._report_proto_path = tensor_tracer_report.report_proto_path( self._parameters.trace_dir) if self._parameters.report_file_path != _SKIP_REPORT_FILE: report_handler.write_report_proto(self._report_proto, self._parameters) else: report_handler.create_report(self._tt_config, self._parameters, tensor_trace_order, tensor_trace_points) return tensor_trace_order def _create_host_call(self): return self._parameters.trace_mode in ( tensor_tracer_flags.TRACE_MODE_SUMMARY, tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream, tensor_trace_order): """Generates a print operation to print trace inspection. Args: cache: Tensor storing the trace results for the step. replica_id: Tensor storing the replica id of the running core. step_num: Step number. output_stream: Where to print the outputs, e.g., file path, or sys.stderr. tensor_trace_order: TensorTraceOrder object holding tensorname to id map. Returns: The Op to flush the cache to file. """ def _inspect_tensor(tensor): """Returns the text to be printed for inspection output.""" if (self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF): return control_flow_ops.cond( math_ops.greater(tensor, 0.0), lambda: 'has NaNs/Infs!', lambda: 'has no NaNs or Infs.') else: return tensor # Check if the cache includes any nan or inf if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: # Cache has 1s or 0s if the mode is NaN_INF step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0) else: # Cache has the actual numerics for other modes. step_has_nan_or_inf = math_ops.reduce_any( gen_math_ops.logical_or( gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache))) # Summarizing message for each step. 
    step_error_message = control_flow_ops.cond(
        step_has_nan_or_inf,
        lambda: 'NaNs or Infs in the step!',
        lambda: 'No numerical issues have been found for the step.')

    # No need to print core numbers if the cache is merged already.
    if self._parameters.collect_summary_per_core:
      stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
               step_error_message,
               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
    else:
      stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
               'Printing tensors for mode:%s...' % self._parameters.trace_mode]

    for tensor_name, cache_idx in sorted(
        tensor_trace_order.tensorname_to_cache_idx.items(),
        key=lambda item: item[1]):
      if self._parameters.collect_summary_per_core:
        stats.extend([
            '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
      else:
        stats.extend([
            '\n', 'step:', step_num, ',',
            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
    return logging_ops.print_v2(*stats, summarize=-1,
                                output_stream=output_stream)

  def _get_outfile_suffix(self):
    if remote_utils.is_remote_path(self._parameters.trace_dir):
      return remote_utils.get_appendable_file_encoding()
    else:
      return ''

  def _generate_flush_cache_op(self, num_replicas, on_tpu, tensor_trace_order):
    """Generates an Op that will flush the cache to file.

    Args:
      num_replicas: total number of replicas.
      on_tpu: if the graph is executed on TPU.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id
        map.

    Returns:
      The Op to flush the cache to file.
    """

    def _flush_fun(cache, replica_id, step_num):
      """Flushes the cache to a file corresponding to replica_id."""

      def _f(file_index):
        """Generates a func that flushes the cache to a file."""

        def _print_cache():
          """Flushes the cache to a file."""
          replica_str = ('%d' % file_index)
          if self._parameters.trace_dir:
            output_path = (os.path.join(self._parameters.trace_dir,
                                        _COMPACT_TRACE_FILE_PREFIX)
                           + replica_str + self._get_outfile_suffix())
            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
          else:
            output_stream = sys.stderr

          new_step_line = _REPLICA_ID_TAG + replica_str
          print_ops = []
          if self._parameters.inspect_trace:
            if self._num_signature_dimensions() > 1:
              raise ValueError('Inspecting multiple signatures is not '
                               'supported.')
            print_ops.append(self._inspect_summary_cache(
                cache=cache, replica_id=replica_id, step_num=step_num,
                output_stream=output_stream,
                tensor_trace_order=tensor_trace_order))
          else:
            for i in range(self._num_signature_dimensions()):
              print_ops.append(logging_ops.print_v2(
                  new_step_line, '\n',
                  cache[:, i], '\n',
                  summarize=-1,
                  output_stream=output_stream))
          with ops.control_dependencies(print_ops):
            return constant_op.constant(0).op

        return _print_cache

      def _eq(file_index):
        return math_ops.equal(replica_id, file_index)

      flush_op_cases = {}
      flush_op_cases[_eq(0)] = _f(0)
      for i in range(1, num_replicas):
        if on_tpu and not self._parameters.collect_summary_per_core:
          # If this is the case, the cache is already merged for all cores.
          # Only the first core flushes the cache.
          flush_op_cases[_eq(i)] = control_flow_ops.no_op
        else:
          flush_op_cases[_eq(i)] = _f(i)
      # Each replica needs to determine where to write its output.
      # To do this, we check if replica_id is 0, then 1, ..., and then
      # num_replicas - 1 statically; and return the corresponding static file
      # name. We cannot simply set the file name in python, as replica_id is
      # only known during tf runtime, and we cannot create dynamic filenames.
return control_flow_ops.case(flush_op_cases, exclusive=True) cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG) if self._use_temp_cache(): cache_val = cache else: cache_val = cache.value() if on_tpu: # If we do not need to collect traces for all cores, merge and aggregate # per core trace. if not self._parameters.collect_summary_per_core: cache_val = self.merge_caches_on_tpu(cache_val) cache_val = self.aggregate_global_cache(cache_val)[0] flush_op = tpu.outside_compilation( _flush_fun, cache_val, self._replica_id, array_ops.identity(training_util.get_or_create_global_step())) else: flush_op = _flush_fun(cache_val, self._replica_id, training_util.get_or_create_global_step()) if self._use_temp_cache(): with ops.control_dependencies([flush_op]): return constant_op.constant(0).op else: # Re-initialize the local cache variable. with ops.control_dependencies([flush_op]): reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, dtype=cache.dtype, shape=cache.shape) assign_op = state_ops.assign(cache, reset_value).op with ops.control_dependencies([assign_op]): return constant_op.constant(0).op def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu, tensor_trace_order): """Flushes the intermediate tensor values in the graph to the cache. Args: tensor_fetches: list of tensor results returned by the model_fn. op_fetches: list of ops that are returned by the model_fn, e.g., train_op. on_tpu: if the graph is executed on TPU. tensor_trace_order: TensorTraceOrder object holding tensorname to id map. Returns: An identical copy of tensor_fetches. """ # Add a dependency to op and tensor fetches to make sure that all tracing # ops are executed before flushing trace results. with ops.control_dependencies(op_fetches + [tensor.op for tensor in tensor_fetches]): flush_cache_op = self._generate_flush_cache_op( self._tt_config.num_replicas, on_tpu, tensor_trace_order) return control_flow_ops.tuple(tensor_fetches, control_inputs=[flush_cache_op]) def _process_tensor_fetches(self, tensor_fetches): """Check that tensor_fetches is not empty and have valid tensors.""" # If none or empty list. if tensor_fetches is None: raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 'None.') if not isinstance(tensor_fetches, (list, tuple)): tensor_fetches = [tensor_fetches] elif not tensor_fetches: raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 'empty list.') fetches = [] for fetch in tensor_fetches: if isinstance(fetch, ops.Tensor): fetches.append(fetch) else: raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) return fetches def _process_op_fetches(self, op_fetches): """Check that op_fetches have valid ops.""" if op_fetches is None: return [] if not isinstance(op_fetches, (list, tuple)): op_fetches = [op_fetches] fetches = [] for fetch in op_fetches: if isinstance(fetch, ops.Operation): fetches.append(fetch) elif isinstance(fetch, ops.Tensor): fetches.append(fetch.op) else: logging.warning('Ignoring the given op_fetch:%s, which is not an op.' 
                        % fetch)
    return fetches

  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
    """Changes current_fetches' format, so that it matches input_fetches."""
    if isinstance(input_fetches, ops.Tensor):
      if len(current_fetches) != 1:
        raise RuntimeError('Tensor tracer input/output fetches do not match.')
      return current_fetches[0]
    else:
      if len(current_fetches) != len(input_fetches):
        raise RuntimeError('Tensor tracer input/output fetches do not match.')
      elif isinstance(input_fetches, tuple):
        return tuple(current_fetches)
      else:
        return current_fetches

  def _get_op_control_flow_context(self, op):
    """Returns the control flow context of the given op.

    Args:
      op: tf.Operation for which the control flow context is requested.
    Returns:
      op_control_flow_context: the control flow context of the given op. If
        the operation type is LoopExit, returns the outer control flow
        context.
    """
    # pylint: disable=protected-access
    op_control_flow_context = op._control_flow_context
    # pylint: enable=protected-access
    if control_flow_util.IsLoopExit(op):
      op_control_flow_context = op_control_flow_context.outer_context
    return op_control_flow_context

  def merge_caches_on_tpu(self, local_tpu_cache_tensor):
    """Merges the given caches on tpu.

    Args:
      local_tpu_cache_tensor: A local tensor that needs to be merged by
        concatenating data from other tpu cores.
    Returns:
      A merged tf.Tensor.
    Raises:
      RuntimeError: if there is no aggregate function defined for a signature.
    """
    x = array_ops.broadcast_to(
        local_tpu_cache_tensor,
        shape=[self._tt_config.num_replicas] +
        local_tpu_cache_tensor.shape.as_list())
    return tpu_ops.all_to_all(
        x, concat_dimension=0, split_dimension=0,
        split_count=self._tt_config.num_replicas)

  def aggregate_global_cache(self, global_tt_summary_cache):
    """Aggregates the given global cache across tpu cores.

    Args:
      global_tt_summary_cache: The global tensor tracer summary cache tensor
        with shape (num_cores, num_traced_tensors, num_traced_signatures).
        First dimension corresponds to core_id, where
        global_tpu_cache_tensor[i] corresponds to the local cache from core-i.
    Returns:
      An aggregated tf.Tensor.
    Raises:
      RuntimeError: if there is no aggregate function defined for a signature.
    """

    # Merge only statistics tensors; any other tensor is simply concatenated.
    agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
    signature_idx_map = self._signature_types()
    aggregation_result = []
    for signature, idx in sorted(signature_idx_map.items(),
                                 key=operator.itemgetter(1)):
      if signature not in agg_fn_map:
        raise RuntimeError('No aggregation function is defined for '
                           'signature %s.' % signature)
      # The dimensions of the statistics tensor are
      # num_cores x num_traced_tensors x num_signatures.
      # value[:,:,idx] will return the portion of the tensor related
      # to the signature.
      signature_tensor = global_tt_summary_cache[:, :, idx]
      # Merge it along the first (core) axis.
      agg_fn = agg_fn_map[signature]
      agg_tensor = agg_fn(signature_tensor, axis=0)
      aggregation_result.append(agg_tensor)
    # Merge results corresponding to different signatures.
    merged_signatures = array_ops.stack(aggregation_result)
    # merged_signatures has dimensions
    # num_signatures x num_traced_tensors; transpose it so that it
    # will match the original structure
    # num_traced_tensors x num_signatures.
    transposed_signatures = array_ops.transpose(merged_signatures)
    # Expand 1 more dimension so that it will match the expected
    # structure num_cores x num_traced_tensors x num_signatures.
    return array_ops.expand_dims(transposed_signatures, axis=0)
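
  # Editor's note on shapes through aggregate_global_cache (illustrative,
  # derived from the code above):
  #   input  [num_cores, num_traced_tensors, num_signatures]
  #   -> per-signature aggregation over the core axis, then stacking, gives
  #      [num_signatures, num_traced_tensors],
  #   -> transpose gives [num_traced_tensors, num_signatures],
  #   -> expand_dims restores a leading core axis of size 1.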
  def _prepare_host_call_fn(self, processed_t_fetches, op_fetches):
    """Creates a host call function that will write the cache as tb summary.

    Args:
      processed_t_fetches: List of tensors provided to session.run.
      op_fetches: List of operations provided to session.run.

    Raises:
      ValueError: if trace_dir is not set.
    """
    if self._parameters.trace_dir is None:
      raise ValueError('Provide a trace_dir for tensor tracer in summary '
                       'mode. --trace_dir=/model/dir')

    def _write_cache(step, event_file_suffix=None, **kwargs):
      """Writes the given caches as tensor summary.

      Args:
        step: Step tensor with dimension [num_cores].
        event_file_suffix: Event filename suffix tensor.
        **kwargs: The dictionary of tensors that needs to be written as
          summaries. Key and value pairs within kwargs correspond to the tag
          name, and tensor content that will be written using summary.write.
          The trace_modes that use this function are:
            - summary: In summary mode, kwargs includes a single (tag, content)
              pair, which are _TT_SUMMARY_TAG and a tf.float32 signature_cache
              variable. The dimension of the signature_cache is:
                num_cores x num_traced_tensors x num_signatures.
            - full_tensor_summary: kwargs will include all traced tensors. Tag
              and content correspond to the name of the tensor, and its actual
              content.

      Returns:
        A tf.Operation that needs to be executed for the host call
        dependencies.

      Raises:
        RuntimeError: if there is no aggregate function defined for a
          signature.
      """
      file_suffix = _TT_EVENT_FILE_SUFFIX
      if event_file_suffix is not None:
        file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
                                             separator='.')
      # TODO(deveci): Parametrize max_queue, so that the flushing op can be
      # called less frequently.
      # Setting max_queue to _TT_SUMMARY_MAX_QUEUE appears to be safe even
      # when the number of iterations is much lower, as the destructor of the
      # writer flushes it.
      summary_write_ops = []
      summary_writer = summary.create_file_writer_v2(
          self._parameters.trace_dir,
          filename_suffix=file_suffix,
          max_queue=_TT_SUMMARY_MAX_QUEUE)
      with summary_writer.as_default():
        summary_metadata = summary_pb2.SummaryMetadata(
            plugin_data=summary_pb2.SummaryMetadata.PluginData(
                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
        for key, value in kwargs.items():
          # Check whether we need to compute aggregated statistics that merge
          # all cores' statistics.
          if not self._parameters.collect_summary_per_core:
            # Merge only the statistics tensor; any other tensor is simply
            # concatenated.
            # Also, if there is only a single core (first dim. is 1), then
            # skip the aggregation.
            if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
              value = self.aggregate_global_cache(value)

          with ops.control_dependencies([summary_writer.init()]):
            summary_write_ops.append(summary.write(
                _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata,
                step=step[0]))
      return control_flow_ops.group(summary_write_ops)
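    # A hedged sketch of reading these summaries back (assumes a summary-mode
    # run already produced an event file at `event_file_path`, a hypothetical
    # path; summary_iterator and tensor_util are imported at the top of this
    # file):
    #
    #   for event in summary_iterator.summary_iterator(event_file_path):
    #     for v in event.summary.value:
    #       if v.tag.startswith(_TT_SUMMARY_TAG):
    #         print(v.tag, tensor_util.MakeNdarray(v.tensor).shape)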
    step = array_ops.reshape(training_util.get_or_create_global_step(), [1])
    self._host_call_fn = {}

    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]

    caches_to_write = {}
    with ops.control_dependencies(host_call_deps):
      all_caches = self._get_all_cache_variables()
      for cache_name, cache_variable in all_caches.items():
        # Increase the cache rank by 1, so that when the host call
        # concatenates tensors from different replicas, we can identify them
        # with [core_id].
        new_cache_shape = [1]
        new_cache_shape.extend(cache_variable.shape.as_list())
        cache = array_ops.reshape(cache_variable, new_cache_shape)
        caches_to_write[cache_name] = cache
    # Add step to the parameter dictionary.
    caches_to_write['step'] = step
    # Other options, without adding step to the parameter dictionary, are
    # * host_call_fn = (_write_cache(step, caches_to_write)): fails, as it
    #   considers caches_to_write as a single positional parameter rather
    #   than keyword parameters.
    # * host_call_fn = (_write_cache(step, **caches_to_write)): fails with
    #   a syntax error.
    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)

  def host_call_deps_and_fn(self):
    return self._host_call_fn

  def get_traced_op_names(self):
    """Returns the set of traced op names."""
    return self._traced_op_names

  def _trace_execution(self, graph,
                       tensor_fetches,
                       op_fetches=None,
                       on_tpu=True):
    """Common tracing function for both CPUs and TPUs.

    The caller function should set device_type, num_replicas,
    num_replicas_per_host, num_hosts and replica_id before calling
    _trace_execution.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list, tuple, or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with at least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      on_tpu: True if executing on TPU.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
        dependencies.

    Raises:
      RuntimeError: If tensor_fetches is None or empty.
    """
    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype == dtypes.int64:
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype in (dtypes.bfloat16, dtypes.float16):
        # Since the host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor
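    # A hedged sketch of _cast_unsupported_dtypes on example inputs (the
    # constants below are illustrative placeholders):
    #
    #   _cast_unsupported_dtypes(constant_op.constant(1, dtypes.int64))
    #       # -> an int32 tensor; outside compilation lacks int64 I/O.
    #   _cast_unsupported_dtypes(constant_op.constant(1., dtypes.bfloat16))
    #       # -> a float32 tensor; the host side cannot consume bf16.
    #   _cast_unsupported_dtypes(constant_op.constant(1., dtypes.float32))
    #       # -> returned unchanged.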
    trace_mode = self._parameters.trace_mode
    device_type = self._tt_config.device_type
    # pylint: disable=protected-access
    self._outmost_context = graph._get_control_flow_context()
    # pylint: enable=protected-access

    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
    TensorTracer.check_device_type(device_type)
    TensorTracer.check_trace_mode(device_type, trace_mode)
    # Check in_tensor_fetches and op_fetches, and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter out the operations that won't be executed.
    # If fetches=None, then ops_in_exec_path = set(operations).
    exec_op_set = self._filter_execution_path_operations(
        graph.get_operations(), all_fetches)
    # Write the report file, and determine the traced tensors.
    tensor_trace_order = self._determine_trace_and_create_report(
        graph, exec_op_set)

    tensor_fetch_set = set(processed_t_fetches)
    tracing_ops = []

    sorted_exec_op_list = list(exec_op_set)
    sorted_exec_op_list.sort(key=lambda op: op.name)
    # Trace ops only if they are in the execution path.
    for op in sorted_exec_op_list:
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          continue
        self._traced_op_names.add(op.name)
        # Create the list of consumers before calling
        # _preprocess_traced_tensor. Otherwise, adding the control input below
        # will introduce a cycle in the graph.
        consumers = out_tensor.consumers()
        # Not all consumers may be in the exec path. Filter out the consumers
        # to keep the graph simpler.
        consumers = [cop for cop in consumers if cop in exec_op_set]

        # If there is no consumer of the tensor, there is no need to trace it;
        # unless the tensor itself is one of the fetches.
        is_a_fetched_tensor = out_tensor in tensor_fetch_set
        if (not consumers) and (not is_a_fetched_tensor):
          continue

        op_control_flow_context = self._get_op_control_flow_context(op)
        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(op_control_flow_context)
          # pylint: enable=protected-access

        processed_tensors = self._preprocess_traced_tensor(out_tensor)

        if on_tpu:
          for signature in processed_tensors.keys():
            processed_tensors[signature] = _cast_unsupported_dtypes(
                processed_tensors[signature])

        if self._use_tensor_values_cache():
          # Use a small cache (either the temp cache or a tf local variable)
          # to store the characteristics of the tensor.
          if self._use_temp_cache():
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            self._save_tensor_value_to_tmp_cache(cache_idx, processed_tensors)
            trace_op = None
          else:
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                           processed_tensors)
        elif self._use_tensor_buffer():
          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          processed_out_tensor = list(processed_tensors.values())[0]
          # Store the whole tensor in a buffer.
          trace_op = self._snapshot_tensor(processed_out_tensor)
        else:

          def tpu_wrap_trace_fn(tensor, out_tensor_name):
            """Wraps the trace_fn with outside compilation if on TPUs."""
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                          tensor_trace_order)
            if on_tpu:
              return tpu.outside_compilation(tensor_trace_fn, tensor)
            else:
              return tensor_trace_fn(tensor)

          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          # Collecting multiple statistics is only supported in the summary
          # mode that uses the compact format
          # (self._use_tensor_values_cache = true). Non-compact mode currently
          # allows a single stat per tensor.
          processed_out_tensor = six.next(six.itervalues(processed_tensors))
          trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(self._outmost_context)
          # pylint: enable=protected-access
        if trace_op:
          if is_a_fetched_tensor:
            tracing_ops.append(trace_op)
            continue
          # Add it to all consumers, as some consumers may not be executed if
          # they are in a control flow.
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access

    # pylint: disable=protected-access
    graph._set_control_flow_context(self._outmost_context)
    # pylint: enable=protected-access
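    # The wiring above, as a hedged sketch: for a traced tensor t with an
    # in-path consumer c, the trace op is attached via
    #   c._add_control_input(trace_op)
    # so c cannot run before the trace is recorded. If t is itself fetched,
    # the dependency is instead routed through tracing_ops below, which wraps
    # the fetches with control_flow_ops.tuple.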
    if tracing_ops:
      # If we are tracing fetched tensors, their dependencies are stored in
      # tracing_ops.
      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                   control_inputs=tracing_ops)
    if self._use_tensor_values_cache() or self._use_tensor_buffer():
      if self._use_temp_cache():
        # Create the temporary tf cache variable by concatenating all
        # statistics.
        self._cache_variables[_TT_SUMMARY_TAG] = array_ops.stack(
            self._temp_cache_var, axis=0, name='stack_all_op_signatures')
      if self._create_host_call():
        self._prepare_host_call_fn(processed_t_fetches, op_fetches)
        if not on_tpu:
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          cache_write_op = write_cache(**caches_to_write)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[cache_write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
      else:
        processed_t_fetches = self._flush_tensor_values_cache(
            processed_t_fetches, op_fetches, on_tpu=on_tpu,
            tensor_trace_order=tensor_trace_order)

    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
                                                 processed_t_fetches)

  def trace_tpu(self, graph,
                tensor_fetches,
                op_fetches=None,
                num_replicas=None,
                num_replicas_per_host=None,
                num_hosts=None):
    """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list, tuple, or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with at least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      num_replicas: number of replicas used on the TPU.
      num_replicas_per_host: number of replicas per TPU host.
      num_hosts: total number of TPU hosts.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
        dependencies.

    Raises:
      RuntimeError: If num_replicas_per_host > 8.
      RuntimeError: If tensor_fetches is None or empty.
    """
    if graph in TensorTracer._traced_graphs:
      logging.warning('Graph is already rewritten with tensor tracer, '
                      'ignoring multiple calls.')
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    # Reset the parameters in case the parameters have been changed.
    self._parameters = tensor_tracer_flags.TTParameters()

    self._tt_config.device_type = _DEVICE_TYPE_TPU
    self._tt_config.num_replicas = num_replicas
    self._tt_config.num_replicas_per_host = num_replicas_per_host
    self._tt_config.num_hosts = num_hosts
    if self._tt_config.num_replicas is not None:
      if self._tt_config.num_replicas_per_host is None:
        self._tt_config.num_replicas_per_host = 8
      if self._tt_config.num_hosts is None:
        # Ceiling division: the number of hosts needed to fit all replicas.
        self._tt_config.num_hosts = (
            num_replicas // self._tt_config.num_replicas_per_host +
            (num_replicas % self._tt_config.num_replicas_per_host > 0))

    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
    with graph.as_default():
      self._add_replica_id_to_graph()
      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                             on_tpu=True)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_after_tt.pbtxt')
    return tensor_fetches
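  # A hedged sketch of calling trace_tpu directly (normally unnecessary: when
  # tensor tracer is enabled, tpu.rewrite hooks it in automatically). The
  # `loss` below is a hypothetical tensor built by the user's TPU computation:
  #
  #   tt = TensorTracer()
  #   loss = tt.trace_tpu(ops.get_default_graph(),
  #                       tensor_fetches=loss,
  #                       num_replicas=8)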
  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
    """Traces the tensors generated by CPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the CPU.
      tensor_fetches: a (list, tuple, or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with at least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
        dependencies.

    Raises:
      RuntimeError: If tensor_fetches is None or empty.
    """
    if graph in TensorTracer._traced_graphs:
      logging.warning('Graph is already rewritten with tensor tracer, '
                      'ignoring multiple calls.')
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    # Reset the parameters in case the parameters have been changed.
    self._parameters = tensor_tracer_flags.TTParameters()

    self._tt_config.device_type = _DEVICE_TYPE_CPU
    self._tt_config.num_replicas = 1
    self._tt_config.num_replicas_per_host = 1
    self._tt_config.num_hosts = 1
    self._replica_id = 0
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
    with graph.as_default():
      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                             on_tpu=False)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_after_tt.pbtxt')
    return tensor_fetches
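
# A minimal, hedged usage sketch for CPU tracing (not part of the library
# API; the trace_dir value and the tiny graph below are illustrative
# placeholders, and 'enable'/'trace_mode' values follow the set_parameters
# conventions documented above).
def _example_cpu_trace():
  """Sketch: builds a tiny graph, traces it in 'norm' mode, runs one step."""
  from tensorflow.python.client import session as session_lib
  set_parameters({'enable': '1',
                  'trace_mode': 'norm',
                  'trace_dir': '/tmp/tensor_tracer_demo'})
  graph = ops.Graph()
  with graph.as_default():
    x = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    loss = math_ops.reduce_sum(math_ops.matmul(x, x))
    tt = TensorTracer()
    # trace_cpu returns an exact copy of the fetches with tracing
    # dependencies attached; traces land under trace_dir at session.run time.
    traced_loss = tt.trace_cpu(graph, tensor_fetches=loss)
  with session_lib.Session(graph=graph) as sess:
    print(sess.run(traced_loss))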