# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= """xla is an experimental library that provides XLA support APIs.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.jit.ops import xla_ops from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.distribute import summary_op_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export _XLA_COMPILE_ATTR = '_xla_compile_id' _MAX_WARNING_LINES = 5 # Operations that indicate some error in the users graph. For example, XLA # computation should not have any Placeholder op. _DENYLISTED_OPS = set([ 'Placeholder', ]) # XLA doesn't currently support reading of intermediate tensors, thus some ops # are not supported. _UNSUPPORTED_OPS = set([ 'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary', 'MergeSummary', 'Print', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2', ]) @tf_export('xla.experimental.compile') @deprecated( None, 'xla.experimental.compile is deprecated. Consider using ' 'tf.function(experimental_compile=True)', warn_once=True) def compile(computation, inputs=None): # pylint: disable=redefined-builtin """Builds an operator that compiles and runs `computation` with XLA. NOTE: In eager mode, `computation` will have `@tf.function` semantics. Args: computation: A Python function that builds a computation to apply to the input. If the function takes n inputs, 'inputs' should be a list of n tensors. `computation` may return a list of operations and tensors. Tensors must come before operations in the returned list. The return value of `compile` is a list of tensors corresponding to the tensors from the output of `computation`. All `Operation`s returned from `computation` will be executed when evaluating any of the returned output tensors. inputs: A list of inputs or `None` (equivalent to an empty list). Each input can be a nested structure containing values that are convertible to tensors. Note that passing an N-dimension list of compatible values will result in a N-dimension list of scalar tensors rather than a single Rank-N tensors. If you need different behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. Returns: Same data structure as if computation(*inputs) is called directly with some exceptions for correctness. Exceptions include: 1) None output: a NoOp would be returned which control-depends on computation. 2) Single value output: A tuple containing the value would be returned. 3) Operation-only outputs: a NoOp would be returned which control-depends on computation. TODO(b/121383831): Investigate into removing these special cases. Raises: RuntimeError: if called when eager execution is enabled. Known issues: When a tf.random operation is built with XLA, the implementation doesn't pass the user provided seed to the XLA compiler. As such, the XLA compiler generates a random number and uses it as a seed when compiling the operation. This implementation causes a violation of the Tensorflow defined semantics in two aspects. First, changing the value of the user defined seed doesn't change the numbers generated by the operation. Second, when a seed is not specified, running the program multiple times will generate the same numbers. """ if context.executing_eagerly(): @def_function.function def xla_compile_wrapper(): return _compile_internal(computation, inputs) return xla_compile_wrapper() return _compile_internal(computation, inputs) class XLACompileContext(control_flow_ops.XLAControlFlowContext): """A `ControlFlowContext` for nodes inside an XLA computation cluster. THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. The primary role of `XLACompileContext` is to mark operators inside a xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is a unique name. `ControlFlowContext` is used to perform the annotation since it integrates with Tensorflow constructs like ResourceVariables. For example, if a `ResourceVariable` is constructed inside a xla.compile() block, the `ResourceVariable` implementation can use `with ops.control_dependencies(None)` to build the variable's definition outside the compiled computation. """ def __init__(self, name, pivot): """Builds a new XLACompileContext. Args: name: a unique name for the context, used to populate the `_xla_compile_id` attribute. pivot: a pivot node. Nodes in the XLACompileContext that do not have any inputs will have a control dependency on the pivot node. This ensures that nodes are correctly included in any enclosing control flow contexts. """ super(XLACompileContext, self).__init__() self._name = name self._name_as_bytes = compat.as_bytes(name) self._unsupported_ops = [] self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: op_str = '\n'.join([ ' %s (%s)' % (op.type, op.name) for op in self._unsupported_ops[:_MAX_WARNING_LINES] ]) logging.warning('%d unsupported operations found: \n%s', len(self._unsupported_ops), op_str) if len(self._unsupported_ops) > _MAX_WARNING_LINES: logging.warning('... and %d more', len(self._unsupported_ops) - _MAX_WARNING_LINES) def _RemoveExternalControlEdges(self, op): """Remove any external control dependency on this op.""" internal_control_inputs = [] external_control_inputs = [] for x in op.control_inputs: # pylint: disable=protected-access is_internal_op = False ctxt = x._get_control_flow_context() while ctxt is not None: if ctxt == self: is_internal_op = True break ctxt = ctxt._outer_context if is_internal_op: internal_control_inputs.append(x) else: external_control_inputs.append(x) # pylint: enable=protected-access # pylint: disable=protected-access op._remove_all_control_inputs() op._add_control_inputs(internal_control_inputs) # pylint: enable=protected-access return internal_control_inputs, external_control_inputs def AddOp(self, op): """Create op in XLACompileContext and notifies outer context recursively.""" # pylint: disable=protected-access if op.type in _DENYLISTED_OPS: logging.error( 'Operation of type %s (%s) is not supported in XLA. Execution will ' 'fail if this op is used in the graph. ', op.type, op.name) # TODO(ycao): Automatically disable summaries instead of reporting them. if op.type in _UNSUPPORTED_OPS: self._unsupported_ops.append(op) if any(x.dtype._is_ref_dtype for x in op.inputs): raise NotImplementedError( 'Non-resource Variables are not supported inside XLA computations ' '(operator name: %s)' % op.name) if _XLA_COMPILE_ATTR in op.node_def.attr: raise ValueError('XLA compiled computations cannot be nested, (operator ' 'name: %s)' % op.name) op._set_attr( _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) # Remove any control edges from outer control flow contexts. These may cause # mismatched frame errors. An example is when one of op's inputs is # generated in a different While control flow context. (internal_control_inputs, external_control_inputs) = self._RemoveExternalControlEdges(op) if not op.inputs: # Add a control edge from the control pivot to this op. if not internal_control_inputs: # pylint: disable=protected-access op._add_control_input(self._pivot) # pylint: enable=protected-access else: for index in xrange(len(op.inputs)): x = op.inputs[index] real_x = self.AddValue(x) if real_x is not x: op._update_input(index, real_x) # pylint: disable=protected-access if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. with ops.control_dependencies(None): self.Enter() external_control_inputs = [ array_ops.identity(x.outputs[0]).op for x in external_control_inputs if x.outputs ] self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access # Mark op's outputs as seen by this context and any outer contexts. output_names = [x.name for x in op.outputs] context = self while context is not None: # pylint: disable=protected-access context._values.update(output_names) context = context._outer_context # pylint: enable=protected-access if self._outer_context: self._outer_context.AddInnerOp(op) def AddValue(self, val): """Add `val` to the current context and its outer context recursively.""" if val.name in self._values: # Use the real value if it comes from outer context. result = self._external_values.get(val.name) return val if result is None else result result = val self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) self._values.add(result.name) self._external_values[val.name] = result return result def AddInnerOp(self, op): self.AddOp(op) if self._outer_context: self._outer_context.AddInnerOp(op) @property def grad_state(self): # Define the gradient loop state associated with the XLACompileContext to # be None as the XLACompileContext does not get nested nor does the # grad_state outside the XLACompileContext affect the graph inside so the # grad_state should be as if this is the top-level gradient state. return None @property def back_prop(self): """Forwards to the enclosing while context, if any.""" if self.GetWhileContext(): return self.GetWhileContext().back_prop return False def _compile_internal(computation, inputs=None): """Builds graph operators that compiles and symbolically executes computation. Args: computation: A Python function that builds the computation to compile and execute. inputs: A list of inputs or `None` (equivalent to an empty list). Each input can be a nested structure containing values that are convertible to tensors. Note that passing an N-dimension list of compatible values will result in a N-dimension list of scalar tensors rather than a single Rank-N tensors. If you need different behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. Returns: Same data structure as if computation(*inputs) is called directly with some exceptions for correctness. Exceptions include: 1) None output 2) Single value output 3) Operation-only outputs Raises: ValueError: If any element in computation outputs is neither an operations or a value that can be converted to tensor. ValueError: If computation outputs is non-flat and contains any Operations. TypeError: If `inputs` is not a list or tuple. """ if inputs is None: inputs = [] if not isinstance(inputs, collections_abc.Sequence): raise TypeError('inputs must be a list') # Flatten inputs. flat_inputs = nest.flatten(inputs) # Converts inputs to Tensors. flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') context = XLACompileContext(name=cluster_name, pivot=pivot) try: context.Enter() # Add identity ops so even unused inputs are 'consumed' by the # computation. flat_inputs = [ array_ops.identity(x, name='input_{}'.format(i)) for i, x in enumerate(flat_inputs) ] # Re-pack flat_inputs in same structure as 'inputs'. computation_inputs = nest.pack_sequence_as( structure=inputs, flat_sequence=flat_inputs) # Only resource variables work inside an XLA computation, so turn on # resource variables for the computation. vscope = variable_scope.get_variable_scope() saved_use_resource = vscope.use_resource vscope.set_use_resource(True) with _disable_summary_context(): outputs = computation(*computation_inputs) # Restore variable scope after computation. vscope.set_use_resource(saved_use_resource) outputs_is_flat = is_flat(outputs) if outputs_is_flat: output_tensors, control_deps = _postprocess_flat_outputs(outputs) else: output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() # When XLA computation returns only operations and no tensors, a NoOp # dependent on the operations in outputs is returned. Otherwise final # outputs would be empty and there is no way to trigger returned # operations. if not output_tensors: return control_flow_ops.group(control_deps, name='output_0') output_tensors = [ xla_ops.xla_cluster_output(o, name='output{}'.format(i)) for i, o in enumerate(output_tensors) ] with ops.control_dependencies(control_deps): # Wraps the outputs in identity operators that carries control # dependencies. output_tensors = [ array_ops.identity(o, name='output_%d' % i) for i, o in enumerate(output_tensors) ] # If `computation` returned non-flat output structure, pack output tensors # back into same structure. if not outputs_is_flat: output_tensors = nest.pack_sequence_as( structure=outputs, flat_sequence=output_tensors) return output_tensors def is_flat(outputs): """Checks if outputs is a flat structure. Following structures and values are considered flat: 1) None 2) A single object 3) A list or tuple of Tensors/Operations The only structures that this function understands are sequences, dictionaries and types defined using the attrs library. E.g. this means that if outputs contains a single user-defined Object, it is considered to be flat. Errors are raised later on if that Object cannot be converted to a Tensor. Args: outputs: Output from `computation` inside `xla.compile`. Returns: A boolean indicates whether outputs is flat. """ # If outputs is a list or tuple, check if it has any nested structure. If # there is, then outputs is non-flat. if isinstance(outputs, collections_abc.Sequence): for o in outputs: if (isinstance(o, collections_abc.Sequence) or isinstance(o, collections_abc.Mapping) or hasattr(o.__class__, '__attrs_attrs__')): return False # If outputs is a dict, it is non-flat. if isinstance(outputs, collections_abc.Mapping): return False # If outputs is from the attrs library, it is non-flat. if hasattr(outputs.__class__, '__attrs_attrs__'): return False # Getting here means either outputs itself is a single non-structured value # or it is a flat list of single non-structured values. return True def _postprocess_flat_outputs(outputs): """Validates flat outputs and adds back device assignments. Args: outputs: Output from `computation` inside `xla.compile`. Returns: Tensors and Operations extracted from outputs. """ # Following code segment is to preserve legacy behavior. Previously we only # supported flat outputs and thus for consistency it was nice to convert even # single element into a tuple. But now that we support arbitrary output # structure, this is no longer necessary. # TODO(b/121383831): Migrate all legacy use cases and delete this special # case. # If the computation returns `None`, make it an empty tuple. if outputs is None: outputs = tuple() # If the computation only returned one value, make it a tuple. if not isinstance(outputs, collections_abc.Sequence): outputs = (outputs,) # Append `no_op` here so that return value of this function always contains # at least one op that can trigger XlaLaunch node. outputs += (control_flow_ops.no_op(),) try: outputs = [ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) for o in outputs ] except Exception as e: raise ValueError( 'XLA computation function return values must all either be Operations' ' or convertible to Tensors. Got error: "%s"' % str(e)) # Separates the returned Operations and Tensors. output_operations = [o for o in outputs if isinstance(o, ops.Operation)] output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] if outputs != output_tensors + output_operations: raise ValueError( 'XLA computation function must return zero or more Tensor values ' 'followed by zero or more Operations.') new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else ''): new_output_tensors.append(array_ops.identity(t)) return new_output_tensors, output_operations def _postprocess_non_flat_outputs(outputs): """Validates non-flat outputs and adds back device assignments. Args: outputs: Output from `computation` inside `xla.compile`. Returns: Tensors extracted from outputs and an empty list because Operations are not allowed in non-flat outputs.. """ # Convert all non-Operation outputs to Tensors. new_output_tensors = [] for o in nest.flatten(outputs): if isinstance(o, ops.Operation): raise ValueError( 'xla.compile does not support Operation as return value in non-flat ' 'output structure. You can set returned Operations as control ' 'dependencies of returned Tensors so Operations are triggered when ' 'Tensors are evaluated. Operation found: "%s"' % o.name) try: o = ops.convert_to_tensor(o) except Exception as e: raise ValueError( 'XLA computation function return values must all either be ' 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) # Makes sure even pass-through inputs/outputs are touched in compile # context by creating an Identity node inside compile context. with ops.device(o.device if o.device else ''): new_output_tensors.append(array_ops.identity(o)) return new_output_tensors, [] @contextlib.contextmanager def _disable_summary_context(): """Enters a context where all summary ops are skipped. Summaries are not yet supported in xla.compile(). So we provide this context manager that can skip creating summary ops. This is a temporary workaround due to XLA not supporting summary ops. Yields: None. """ original_skip_summary_func = summary_op_util.skip_summary summary_op_util.skip_summary = lambda: True try: yield finally: summary_op_util.skip_summary = original_skip_summary_func class _CapturedObject(object): """A placeholder to capture an object.""" def __init__(self): self._object = None def capture(self, o): if self._object: raise RuntimeError( 'InternalError: _CapturedObject can capture only once. Please file ' 'bug.') self._object = o def get(self): return self._object def _get_scaffold(captured_scaffold_fn): """Retrieves the Scaffold from `captured_scaffold_fn`.""" scaffold_fn = captured_scaffold_fn.get() if not scaffold_fn: return None scaffold = scaffold_fn() if scaffold is None: raise ValueError( 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') return scaffold def check_function_argument_count(func, input_arity, infeed_queue): """Validate the number of input arguments to an XLA function. Args: func: the Python function that will be called to generate the body of an XLA computation graph. input_arity: the number of explicit arguments supplied by the caller. infeed_queue: if not None, the infeed queue that will supply additional arguments to the function. Returns: None if function can be called with the supplied number of arguments, or an error string if it cannot. """ def format_error(complaint, quantity): return '%s %d argument%s' % (complaint, quantity, '' if quantity == 1 else 's') num_args_supplied = input_arity if infeed_queue is not None: num_args_supplied += infeed_queue.number_of_tuple_elements arg_spec = tf_inspect.getargspec(func) num_func_args = len(arg_spec.args) if arg_spec.defaults is None: num_func_defaults = 0 else: num_func_defaults = len(arg_spec.defaults) min_func_args = num_func_args - num_func_defaults if num_args_supplied < min_func_args: # The required number of arguments is not enough to call the function. if num_func_defaults == 0 and arg_spec.varargs is None: return format_error('exactly', num_func_args) else: return format_error('at least', min_func_args) if arg_spec.varargs is None and num_args_supplied > num_func_args: # The required number of arguments is too many to call the function. if num_func_defaults == 0: return format_error('exactly', num_func_args) else: return format_error('at most', num_func_args) # Reaching here means either # 1) There are varargs, func can accept any number of arguments greater than # the minimum. # 2) Number of supplied arguments falls in range of acceptable argument count # of func. return None