# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # pylint: disable=line-too-long """Library for running a computation across multiple devices. The intent of this library is that you can write an algorithm in a stylized way and it will be usable with a variety of different `tf.distribute.Strategy` implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged. The `tf.distribute.Strategy` API works the same way with eager and graph execution. *Guides* * [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) * [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) *Tutorials* * [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) The tutorials cover how to use `tf.distribute.Strategy` to do distributed training with native Keras APIs, custom training loops, and Estimator APIs. They also cover how to save/load a model when using `tf.distribute.Strategy`.
*Glossary* * _Data parallelism_ is where we run multiple copies of the model on different slices of the input data. This is in contrast to _model parallelism_ where we divide up a single copy of a model across multiple devices. Note: we only support data parallelism for now, but hope to add support for model parallelism in the future. * A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple devices on a single machine, or be connected to devices on multiple machines. Devices used to run computations are called _worker devices_. Devices used to store variables are _parameter devices_. For some strategies, such as `tf.distribute.MirroredStrategy`, the worker and parameter devices will be the same (see mirrored variables below). For others they will be different. For example, `tf.distribute.experimental.CentralStorageStrategy` puts the variables on a single device (which may be a worker device or may be the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the variables on separate machines called _parameter servers_ (see below). * A _replica_ is one copy of the model, running on one slice of the input data. Right now each replica is executed on its own worker device, but once we add support for model parallelism a replica may span multiple worker devices. * A _host_ is the CPU device on a machine with worker devices, typically used for running input pipelines. * A _worker_ is defined to be the physical machine(s) containing the physical devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A worker may contain one or more replicas, but contains at least one replica. Typically one worker will correspond to one machine, but in the case of very large models with model parallelism, one worker may span multiple machines. We typically run one input pipeline per worker, feeding all the replicas on that worker. 
* _Synchronous_, or more commonly _sync_, training is where the updates from each replica are aggregated together before updating the model variables. This is in contrast to _asynchronous_, or _async_ training, where each replica updates the model variables independently. You may also have replicas partitioned into groups which are in sync within each group but async between groups. * _Parameter servers_: These are machines that hold a single copy of parameters/variables, used by some strategies (right now just `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want to operate on a variable retrieve it at the beginning of a step and send an update to be applied at the end of the step. These can in principle support either sync or async training, but right now we only have support for async training with parameter servers. Compare to `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables on a single device on the same machine (and does sync training), and `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices (see below). * _Replica context_ vs. _Cross-replica context_ vs. _Update context_ A _replica context_ applies when you execute the computation function that was called with `strategy.run`. Conceptually, you're in replica context when executing the computation function that is being replicated. An _update context_ is entered in a `tf.distribute.StrategyExtended.update` call. A _cross-replica context_ is entered when you enter a `strategy.scope`. This is useful for calling `tf.distribute.Strategy` methods which operate across the replicas (like `reduce_to()`). By default you start in a _replica context_ (the "default single _replica context_") and then some methods can switch you back and forth. * _Distributed value_: A distributed value is represented by the base class `tf.distribute.DistributedValues`.
`tf.distribute.DistributedValues` is useful to represent values on multiple devices, and it contains a map from replica id to values. Two representative kinds of `tf.distribute.DistributedValues` are "PerReplica" and "Mirrored" values. "PerReplica" values exist on the worker devices, with a different value for each replica. They are produced by iterating through a distributed dataset returned by `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the typical result returned by `tf.distribute.Strategy.run`. "Mirrored" values are like "PerReplica" values, except we know that the value on all replicas are the same. We can safely read a "Mirrored" value in a cross-replica context by using the value on any replica. * _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple replicas, like `strategy.run(fn, args=[w])` with an argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return values from `fn()`, which leads to one common object if the returned values are the same object from every replica, or a `DistributedValues` object otherwise. * _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating multiple values into one value, like "sum" or "mean". If a strategy is doing sync training, we will perform a reduction on the gradients to a parameter from all replicas before applying the update. _All-reduce_ is an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices. * _Mirrored variables_: These are variables that are created on multiple devices, where we keep the variables in sync by applying the same updates to every copy. 
Mirrored variables are created with `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. Normally they are only used in synchronous training. * _SyncOnRead variables_ _SyncOnRead variables_ are created by `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and they are created on multiple devices. In replica context, each component variable on the local replica can perform reads and writes without synchronization with each other. When the _SyncOnRead variable_ is read in cross-replica context, the values from component variables are aggregated and returned. _SyncOnRead variables_ bring a lot of custom configuration difficulty to the underlying logic, so we do not encourage users to instantiate and use _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead variables_ for use cases such as batch norm and metrics. For performance reasons, we often don't need to keep these statistics in sync every step and they can be accumulated on each replica independently. The only time we want to sync them is reporting or checkpointing, which typically happens in cross-replica context. _SyncOnRead variables_ are also often used by advanced users who want to control when variable values are aggregated. For example, users sometimes want to maintain gradients independently on each replica for a couple of steps without aggregation. * _Distribute-aware layers_ Layers are generally called in a replica context, except when defining a Keras functional model. `tf.distribute.in_cross_replica_context` will let you determine which case you are in. If in a replica context, the `tf.distribute.get_replica_context` function will return the default replica context outside a strategy scope, `None` within a strategy scope, and a `tf.distribute.ReplicaContext` object inside a strategy scope and within a `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an `all_reduce` method for aggregating across all replicas. 
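As a rough illustration of the _reduction_ and _all-reduce_ entries above, the sketch below models both in plain Python. It is not how TensorFlow implements them; `reduce_values`, `all_reduce`, and `per_replica_grads` are hypothetical names chosen for this example, and each list position stands in for one replica.

```python
# Plain-Python model of a reduction and an all-reduce; no TensorFlow involved.
# `reduce_values` and `all_reduce` are hypothetical helper names.


def reduce_values(per_replica_values, op="sum"):
  # Aggregate one value per replica into a single value.
  total = sum(per_replica_values)
  if op == "sum":
    return total
  if op == "mean":
    return total / len(per_replica_values)
  raise ValueError("Unsupported reduce op: %r" % (op,))


def all_reduce(per_replica_values, op="sum"):
  # Reduce, then make the same result available on every replica.
  reduced = reduce_values(per_replica_values, op)
  return [reduced for _ in per_replica_values]


# Gradients for one parameter, as computed on 4 replicas.
per_replica_grads = [0.5, 1.5, 2.0, 4.0]
summed = all_reduce(per_replica_grads, op="sum")
averaged = all_reduce(per_replica_grads, op="mean")
```

In sync training, every replica would then apply the same reduced gradient, which is what keeps mirrored copies of a variable identical.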
Note that we provide a default version of `tf.distribute.Strategy` that is used when no other strategy is in scope, that provides the same API with reasonable default behavior. """ # pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import enum # pylint: disable=g-bad-import-order import threading import weakref import six from tensorflow.python.autograph.core import ag_ctx as autograph_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context as eager_context from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import loss_reduction from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import deprecation from tensorflow.python.util import nest from 
tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export from tensorflow.tools.docs import doc_controls # ------------------------------------------------------------------------------ # Context tracking whether in a strategy.update() or .update_non_slot() call. _update_replica_id = threading.local() def get_update_replica_id(): """Get the current device if in a `tf.distribute.Strategy.update()` call.""" try: return _update_replica_id.current except AttributeError: return None class UpdateContext(object): """Context manager when you are in `update()` or `update_non_slot()`.""" __slots__ = ["_replica_id", "_old_replica_id"] def __init__(self, replica_id): self._replica_id = replica_id self._old_replica_id = None def __enter__(self): self._old_replica_id = get_update_replica_id() _update_replica_id.current = self._replica_id def __exit__(self, exception_type, exception_value, traceback): del exception_type, exception_value, traceback _update_replica_id.current = self._old_replica_id # ------------------------------------------------------------------------------ # Public utility functions. @tf_export(v1=["distribute.get_loss_reduction"]) def get_loss_reduction(): """`tf.distribute.ReduceOp` corresponding to the last loss reduction. This is used to decide whether loss should be scaled in optimizer (used only for estimator + v1 optimizer use case). Returns: `tf.distribute.ReduceOp` corresponding to the last loss reduction for estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. """ if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access # If we are not in Estimator context then return 'SUM'. We do not need to # scale loss in the optimizer. 
return reduce_util.ReduceOp.SUM last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access if (last_reduction == losses_impl.Reduction.SUM or last_reduction == loss_reduction.ReductionV2.SUM): return reduce_util.ReduceOp.SUM return reduce_util.ReduceOp.MEAN # ------------------------------------------------------------------------------ # Internal API for validating the current thread mode def _require_cross_replica_or_default_context_extended(extended, error_message=None): """Verify in cross-replica context.""" context = _get_per_thread_mode() cross_replica = context.cross_replica_context if cross_replica is not None and cross_replica.extended is extended: return if context is _get_default_replica_mode(): return strategy = extended._container_strategy() # pylint: disable=protected-access # We have an error to report, figure out the right message. if context.strategy is not strategy: _wrong_strategy_scope(strategy, context) assert cross_replica is None if not error_message: error_message = ("Method requires being in cross-replica context, use " "get_replica_context().merge_call()") raise RuntimeError(error_message) def _wrong_strategy_scope(strategy, context): # Figure out the right error message. if not distribution_strategy_context.has_strategy(): raise RuntimeError( 'Need to be inside "with strategy.scope()" for %s' % (strategy,)) else: raise RuntimeError( "Mixing different tf.distribute.Strategy objects: %s is not %s" % (context.strategy, strategy)) def require_replica_context(replica_ctx): """Verify in `replica_ctx` replica context.""" context = _get_per_thread_mode() if context.replica_context is replica_ctx: return # We have an error to report, figure out the right message. if context.replica_context is None: raise RuntimeError("Need to be inside `call_for_each_replica()`") if context.strategy is replica_ctx.strategy: # Two different ReplicaContexts with the same tf.distribute.Strategy. 
raise RuntimeError("Mismatching ReplicaContext.") raise RuntimeError( "Mismatching tf.distribute.Strategy objects: %s is not %s." % (context.strategy, replica_ctx.strategy)) def _require_strategy_scope_strategy(strategy): """Verify in a `strategy.scope()` in this thread.""" context = _get_per_thread_mode() if context.strategy is strategy: return _wrong_strategy_scope(strategy, context) def _require_strategy_scope_extended(extended): """Verify in a `distribution_strategy.scope()` in this thread.""" context = _get_per_thread_mode() if context.strategy.extended is extended: return # Report error. strategy = extended._container_strategy() # pylint: disable=protected-access _wrong_strategy_scope(strategy, context) # ------------------------------------------------------------------------------ # Internal context managers used to implement the DistributionStrategy # base class class _CurrentDistributionContext(object): """Context manager setting the current `tf.distribute.Strategy`. Also: overrides the variable creator and optionally the current device. """ def __init__(self, strategy, var_creator_scope, var_scope=None, default_device=None): self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope if default_device: self._device_scope = ops.device(default_device) else: self._device_scope = None self._same_scope_again_count = 0 def __enter__(self): # Allow this scope to be entered if this strategy is already in scope. 
if distribution_strategy_context.has_strategy(): _require_cross_replica_or_default_context_extended( self._context.strategy.extended) self._same_scope_again_count += 1 else: _push_per_thread_mode(self._context) if self._var_scope: self._var_scope.__enter__() self._var_creator_scope.__enter__() if self._device_scope: self._device_scope.__enter__() return self._context.strategy def __exit__(self, exception_type, exception_value, traceback): if self._same_scope_again_count > 0: self._same_scope_again_count -= 1 return if self._device_scope: try: self._device_scope.__exit__(exception_type, exception_value, traceback) except RuntimeError as e: six.raise_from( RuntimeError("Device scope nesting error: move call to " "tf.distribute.set_strategy() out of `with` scope."), e) try: self._var_creator_scope.__exit__( exception_type, exception_value, traceback) except RuntimeError as e: six.raise_from( RuntimeError("Variable creator scope nesting error: move call to " "tf.distribute.set_strategy() out of `with` scope."), e) if self._var_scope: try: self._var_scope.__exit__(exception_type, exception_value, traceback) except RuntimeError as e: six.raise_from( RuntimeError("Variable scope nesting error: move call to " "tf.distribute.set_strategy() out of `with` scope."), e) _pop_per_thread_mode() # TODO(yuefengz): add more replication modes. @tf_export("distribute.InputReplicationMode") class InputReplicationMode(enum.Enum): """Replication mode for input function. * `PER_WORKER`: The input function will be called on each worker independently, creating as many input pipelines as there are workers. Replicas will dequeue from the local Dataset on their worker. `tf.distribute.Strategy` doesn't manage any state sharing between such separate input pipelines. * `PER_REPLICA`: The input function will be called on each replica separately. `tf.distribute.Strategy` doesn't manage any state sharing between such separate input pipelines.
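Because no state is shared between separate input pipelines, each pipeline usually has to pick its own slice of the data. The following TensorFlow-free sketch shows one way an input function might do that. It assumes a context object with the attributes `num_input_pipelines`, `input_pipeline_id`, and `num_replicas_in_sync`; `FakeInputContext` and `input_fn` are hypothetical names used only for illustration, and Python lists stand in for datasets.

```python
import collections

# Hypothetical stand-in for the context passed to an input function, so the
# sketch runs without TensorFlow.
FakeInputContext = collections.namedtuple(
    "FakeInputContext",
    ["num_input_pipelines", "input_pipeline_id", "num_replicas_in_sync"])


def input_fn(ctx, examples, global_batch_size):
  # Derive the per-replica batch size from the global batch size.
  if global_batch_size % ctx.num_replicas_in_sync != 0:
    raise ValueError("global_batch_size must be divisible by the number of "
                     "replicas in sync")
  batch_size = global_batch_size // ctx.num_replicas_in_sync
  # Shard: this pipeline keeps every `num_input_pipelines`-th example.
  shard = examples[ctx.input_pipeline_id::ctx.num_input_pipelines]
  # Batch the shard into per-replica batches.
  return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


ctx = FakeInputContext(num_input_pipelines=2, input_pipeline_id=1,
                       num_replicas_in_sync=4)
batches = input_fn(ctx, list(range(16)), global_batch_size=8)
```

With two pipelines, the pipeline with id 1 sees only the odd-numbered examples, so the two pipelines never read the same element even though they do not coordinate.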
""" PER_WORKER = "PER_WORKER" PER_REPLICA = "PER_REPLICA" @tf_export("distribute.InputContext") class InputContext(object): """A class wrapping information needed by an input function. This is a context class that is passed to the user's input function and contains information about the compute replicas and input pipelines. The number of compute replicas (in sync training) helps compute the local batch size from the desired global batch size for each replica. The input pipeline information can be used to return a different subset of the input in each replica (for e.g. shard the input pipeline, use a different input source etc). """ __slots__ = [ "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" ] def __init__(self, num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1): """Initializes an InputContext object. Args: num_input_pipelines: the number of input pipelines in a cluster. input_pipeline_id: the current input pipeline id, should be an int in [0,`num_input_pipelines`). num_replicas_in_sync: the number of replicas that are in sync. """ self._num_input_pipelines = num_input_pipelines self._input_pipeline_id = input_pipeline_id self._num_replicas_in_sync = num_replicas_in_sync @property def num_replicas_in_sync(self): """Returns the number of compute replicas in sync.""" return self._num_replicas_in_sync @property def input_pipeline_id(self): """Returns the input pipeline ID.""" return self._input_pipeline_id @property def num_input_pipelines(self): """Returns the number of input pipelines.""" return self._num_input_pipelines def get_per_replica_batch_size(self, global_batch_size): """Returns the per-replica batch size. Args: global_batch_size: the global batch size which should be divisible by `num_replicas_in_sync`. Returns: the per-replica batch size. Raises: ValueError: if `global_batch_size` not divisible by `num_replicas_in_sync`. 
""" if global_batch_size % self._num_replicas_in_sync != 0: raise ValueError("The `global_batch_size` %r is not divisible by " "`num_replicas_in_sync` %r " % (global_batch_size, self._num_replicas_in_sync)) return global_batch_size // self._num_replicas_in_sync def __str__(self): return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( self.input_pipeline_id, self.num_input_pipelines) @tf_export("distribute.experimental.ValueContext", v1=[]) class ValueContext(object): """A class wrapping information needed by a distribute function. This is a context class that is passed to the `value_fn` in `strategy.experimental_distribute_values_from_function` and contains information about the compute replicas. The `num_replicas_in_sync` and `replica_id` can be used to customize the value on each replica. Example usage: 1. Directly constructed. >>> def value_fn(context): ... return context.replica_id_in_sync_group/context.num_replicas_in_sync >>> context = tf.distribute.experimental.ValueContext( ... replica_id_in_sync_group=2, num_replicas_in_sync=4) >>> per_replica_value = value_fn(context) >>> per_replica_value 0.5 2. Passed in by `experimental_distribute_values_from_function`. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def value_fn(value_context): ... return value_context.num_replicas_in_sync >>> distributed_values = ( ... strategy.experimental_distribute_values_from_function( ... value_fn)) >>> local_result = strategy.experimental_local_results(distributed_values) >>> local_result (2, 2) """ __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] def __init__(self, replica_id_in_sync_group=0, num_replicas_in_sync=1): """Initializes an ValueContext object. Args: replica_id_in_sync_group: the current replica_id, should be an int in [0,`num_replicas_in_sync`). num_replicas_in_sync: the number of replicas that are in sync. 
""" self._replica_id_in_sync_group = replica_id_in_sync_group self._num_replicas_in_sync = num_replicas_in_sync @property def num_replicas_in_sync(self): """Returns the number of compute replicas in sync.""" return self._num_replicas_in_sync @property def replica_id_in_sync_group(self): """Returns the replica ID.""" return self._replica_id_in_sync_group def __str__(self): return (("tf.distribute.ValueContext(replica id {}, " " total replicas in sync: ""{})") .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) @tf_export("distribute.RunOptions") class RunOptions( collections.namedtuple("RunOptions", [ "experimental_enable_dynamic_batch_size", "experimental_bucketizing_dynamic_shape", ])): """Run options for `strategy.run`. This can be used to hold some strategy specific configs. Attributes: experimental_enable_dynamic_batch_size: Boolean. Only applies to TPUStrategy. Default to True. If True, TPUStrategy will enable dynamic padder to support dynamic batch size for the inputs. Otherwise only static shape inputs are allowed. experimental_bucketizing_dynamic_shape: Boolean. Only applies to TPUStrategy. Default to False. If True, TPUStrategy will automatic bucketize inputs passed into `run` if the input shape is dynamic. This is a performance optimization to reduce XLA recompilation, which should not have impact on correctness. """ def __new__(cls, experimental_enable_dynamic_batch_size=True, experimental_bucketizing_dynamic_shape=False): return super(RunOptions, cls).__new__(cls, experimental_enable_dynamic_batch_size, experimental_bucketizing_dynamic_shape) @tf_export("distribute.InputOptions", v1=[]) class InputOptions( collections.namedtuple("InputOptions", [ "experimental_prefetch_to_device", "experimental_replication_mode", "experimental_place_dataset_on_device", ])): """Run options for `experimental_distribute_dataset(s_from_function)`. This can be used to hold some strategy specific configs. 
```python # Setup TPUStrategy resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.TPUStrategy(resolver) dataset = tf.data.Dataset.range(16) distributed_dataset_on_host = ( strategy.experimental_distribute_dataset( dataset, tf.distribute.InputOptions( experimental_replication_mode= tf.distribute.InputReplicationMode.PER_WORKER, experimental_place_dataset_on_device=False))) ``` Attributes: experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset elements will be prefetched to accelerator device memory. When False, dataset elements are prefetched to host device memory. Must be False when using the TPUEmbedding API. experimental_prefetch_to_device can only be used with experimental_replication_mode=PER_WORKER. experimental_replication_mode: Replication mode for the input function. Currently, InputReplicationMode.PER_REPLICA is only supported with tf.distribute.MirroredStrategy.experimental_distribute_datasets_from_function. The default value is InputReplicationMode.PER_WORKER. experimental_place_dataset_on_device: Boolean. Defaults to False. When True, the dataset will be placed on the device, otherwise it will remain on the host. experimental_place_dataset_on_device=True can only be used with experimental_replication_mode=PER_REPLICA. """ def __new__(cls, experimental_prefetch_to_device=True, experimental_replication_mode=InputReplicationMode.PER_WORKER, experimental_place_dataset_on_device=False): return super(InputOptions, cls).__new__(cls, experimental_prefetch_to_device, experimental_replication_mode, experimental_place_dataset_on_device) # ------------------------------------------------------------------------------ # Base classes for all distribution strategies. # Base class for v1 Strategy and v2 Strategy classes. For APIs specific to # v1/v2 Strategy, add to implementing classes of StrategyBase.
# pylint: disable=line-too-long class StrategyBase(object): """A state & compute distribution policy on a list of devices. See [the guide](https://www.tensorflow.org/guide/distributed_training) for an overview and examples. See `tf.distribute.StrategyExtended` and [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) for a glossary of concepts mentioned on this page such as "per-replica", _replica_, and _reduce_. In short: * To use it with Keras `compile`/`fit`, [please read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). * You may pass a descendant of `tf.distribute.Strategy` to `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` should distribute its computation. See [the guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). * Otherwise, use `tf.distribute.Strategy.scope` to specify that a strategy should be used when building and executing your model. (This puts you in the "cross-replica context" for this strategy, which means the strategy is put in control of things like variable placement.) * If you are writing a custom training loop, you will need to call a few more methods, [see the guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): * Start by creating a `tf.data.Dataset` normally. * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert a `tf.data.Dataset` to something that produces "per-replica" values. If you want to manually specify how the dataset should be partitioned across replicas, use `tf.distribute.Strategy.distribute_datasets_from_function` instead. * Use `tf.distribute.Strategy.run` to run a function once per replica, taking values that may be "per-replica" (e.g. from a `tf.distribute.DistributedDataset` object) and returning "per-replica" values.
This function is executed in "replica context", which means each operation is performed separately on each replica. * Finally use a method (such as `tf.distribute.Strategy.reduce`) to convert the resulting "per-replica" values into ordinary `Tensor`s. A custom training loop can be as simple as: ``` with my_strategy.scope(): @tf.function def distribute_train_epoch(dataset): def replica_fn(input): # process input and return result return result total_result = 0 for x in dataset: per_replica_result = my_strategy.run(replica_fn, args=(x,)) total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_result, axis=None) return total_result dist_dataset = my_strategy.experimental_distribute_dataset(dataset) for _ in range(EPOCHS): train_result = distribute_train_epoch(dist_dataset) ``` This takes an ordinary `dataset` and `replica_fn` and runs it distributed using a particular `tf.distribute.Strategy` named `my_strategy` above. Any variables created in `replica_fn` are created using `my_strategy`'s policy, and library functions called by `replica_fn` can use the `get_replica_context()` API to implement distributed-specific behavior. You can use the `reduce` API to aggregate results across replicas and use this as a return value from one iteration over a `tf.distribute.DistributedDataset`. Or you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to accumulate metrics across steps in a given epoch. See the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) for a more detailed example. Note: `tf.distribute.Strategy` currently does not support TensorFlow's partitioned variables (where a single variable is split across multiple devices) at this time. 
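The control flow of the custom training loop above (distribute a batch, run a function once per replica, reduce the per-replica results) can be mimicked without TensorFlow. The following is only a toy sketch under that assumption: `ToyStrategy` and its `distribute_batch` method are hypothetical stand-ins, not a real `tf.distribute.Strategy`, and lists play the role of "per-replica" values.

```python
# Toy, TensorFlow-free model of the loop above. `ToyStrategy` is a
# hypothetical stand-in, not a real tf.distribute.Strategy.


class ToyStrategy(object):

  def __init__(self, num_replicas):
    self.num_replicas_in_sync = num_replicas

  def distribute_batch(self, global_batch):
    # Split one global batch into equal per-replica slices.
    size = len(global_batch) // self.num_replicas_in_sync
    return [global_batch[i * size:(i + 1) * size]
            for i in range(self.num_replicas_in_sync)]

  def run(self, fn, args):
    # Call `fn` once per replica, each with that replica's slice.
    (per_replica_input,) = args
    return [fn(x) for x in per_replica_input]

  def reduce(self, per_replica_values):
    # SUM reduction across replicas into one ordinary value.
    return sum(per_replica_values)


def replica_fn(batch_slice):
  # Process this replica's slice and return a result.
  return sum(v * v for v in batch_slice)


my_strategy = ToyStrategy(num_replicas=2)
global_batches = [[1, 2, 3, 4], [5, 6, 7, 8]]

total_result = 0
for batch in global_batches:
  per_replica_input = my_strategy.distribute_batch(batch)
  per_replica_result = my_strategy.run(replica_fn, args=(per_replica_input,))
  total_result += my_strategy.reduce(per_replica_result)
```

The result equals what a single device would compute over all the data, which is the property data parallelism with a sum reduction is meant to preserve.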
""" # pylint: enable=line-too-long # TODO(josh11b): Partitioned computations, state; sharding # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling def __init__(self, extended): self._extended = extended # Flag that is used to indicate whether distribution strategy is used with # Estimator. This is required for backward compatibility of loss scaling # when using v1 optimizer with estimator. self._scale_loss_for_estimator = False if not hasattr(extended, "_retrace_functions_for_each_device"): # pylint: disable=protected-access # `extended._retrace_functions_for_each_device` dictates # whether the same function will be retraced when it is called on # different devices. try: extended._retrace_functions_for_each_device = ( len(extended.worker_devices) > 1) distribution_strategy_replica_gauge.get_cell("num_replicas").set( self.num_replicas_in_sync) except: # pylint: disable=bare-except # Default for the case where extended.worker_devices can't return # a sensible value. extended._retrace_functions_for_each_device = True # Below are the dicts of axis(int) -> `tf.function`. self._mean_reduce_helper_fns = {} self._reduce_sum_fns = {} @property def extended(self): """`tf.distribute.StrategyExtended` with additional methods.""" return self._extended @tf_contextlib.contextmanager def _scale_loss_for_estimator_enabled(self): """Scope which sets a flag used for scaling losses in optimizer. Yields: `_scale_loss_for_estimator_enabled` is a context manager with a side effect, but doesn't return a value. """ self._scale_loss_for_estimator = True try: yield finally: self._scale_loss_for_estimator = False # pylint: disable=line-too-long def scope(self): """Context manager to make the strategy current and distribute variables. This method returns a context manager, and is used as follows: >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> # Variable created inside scope: >>> with strategy.scope(): ... mirrored_variable = tf.Variable(1.) 
>>> mirrored_variable MirroredVariable:{ 0: , 1: } >>> # Variable created outside scope: >>> regular_variable = tf.Variable(1.) >>> regular_variable _What happens when Strategy.scope is entered?_ * `strategy` is installed in the global context as the "current" strategy. Inside this scope, `tf.distribute.get_strategy()` will now return this strategy. Outside this scope, it returns the default no-op strategy. * Entering the scope also enters the "cross-replica context". See `tf.distribute.StrategyExtended` for an explanation of cross-replica and replica contexts. * Variable creation inside `scope` is intercepted by the strategy. Each strategy defines how it wants to affect the variable creation. Sync strategies like `MirroredStrategy`, `TPUStrategy` and `MultiWorkerMirroredStrategy` create variables replicated on each replica, whereas `ParameterServerStrategy` creates variables on the parameter servers. This is done using a custom `tf.variable_creator_scope`. * In some strategies, a default device scope may also be entered: in `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is entered on each worker. Note: Entering a scope does not automatically distribute a computation, except when using a high-level training framework such as Keras `model.fit`. If you're not using `model.fit`, you need to use the `strategy.run` API to explicitly distribute that computation. See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training). _What should be in scope and what should be outside?_ There are a number of requirements on what needs to happen inside the scope. However, in places where we have information about which strategy is in use, we often enter the scope for the user, so they don't have to do it explicitly (i.e. calling those either inside or outside the scope is OK). * Anything that creates variables that should be distributed variables must be in `strategy.scope`.
This can be either by directly putting it in scope, or relying on another API like `strategy.run` or `model.fit` to enter it for you. Any variable that is created outside scope will not be distributed and may have performance implications. Common things that create variables in TF: models, optimizers, metrics. These should always be created inside the scope. Another source of variable creation can be a checkpoint restore - when variables are created lazily. Note that any variable created inside a strategy captures the strategy information. So reading and writing to these variables outside the `strategy.scope` can also work seamlessly, without the user having to enter the scope. * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which require to be in a strategy's scope, enter the scope for you automatically, which means when using those APIs you don't need to enter the scope yourself. * When a `tf.keras.Model` is created inside a `strategy.scope`, we capture this information. When high level training frameworks methods such as `model.compile`, `model.fit` etc are then called on this model, we automatically enter the scope, as well as use this strategy to distribute the training etc. See detailed example in [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). Note that simply calling the `model(..)` is not impacted - only high level training framework APIs are. `model.compile`, `model.fit`, `model.evaluate`, `model.predict` and `model.save` can all be called inside or outside the scope. * The following can be either inside or outside the scope: * Creating the input datasets * Defining `tf.function`s that represent your training step * Saving APIs such as `tf.saved_model.save`. Loading creates variables, so that should go inside the scope if you want to train the model in a distributed way. * Checkpoint saving. As mentioned above - `checkpoint.restore` may sometimes need to be inside scope if it creates variables. 
Returns: A context manager. """ return self._extended._scope(self) # pylint: disable=protected-access # pylint: enable=line-too-long @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` def colocate_vars_with(self, colocate_with_variable): """DEPRECATED: use extended.colocate_vars_with() instead.""" return self._extended.colocate_vars_with(colocate_with_variable) @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only def make_dataset_iterator(self, dataset): """DEPRECATED TF 1.x ONLY.""" return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only def make_input_fn_iterator(self, input_fn, replication_mode=InputReplicationMode.PER_WORKER): """DEPRECATED TF 1.x ONLY.""" if replication_mode != InputReplicationMode.PER_WORKER: raise ValueError( "Input replication mode not supported: %r" % replication_mode) with self.scope(): return self.extended._make_input_fn_iterator( # pylint: disable=protected-access input_fn, replication_mode=replication_mode) @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only def experimental_run(self, fn, input_iterator=None): """DEPRECATED TF 1.x ONLY.""" with self.scope(): args = (input_iterator.get_next(),) if input_iterator is not None else () return self.run(fn, args=args) def experimental_distribute_dataset(self, dataset, options=None): # pylint: disable=line-too-long """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. The returned `tf.distribute.DistributedDataset` can be iterated over similar to regular datasets. NOTE: The user cannot add any more transformations to a `tf.distribute.DistributedDataset`. You can only create an iterator or examine the `tf.TypeSpec` of the data generated by it. See API docs of `tf.distribute.DistributedDataset` to learn more. The following is an example: >>> global_batch_size = 2 >>> # Passing the devices is optional. ... 
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) >>> # Create a dataset ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) >>> # Distribute that dataset ... dist_dataset = strategy.experimental_distribute_dataset(dataset) >>> @tf.function ... def replica_fn(input): ... return input*2 >>> result = [] >>> # Iterate over the `tf.distribute.DistributedDataset` ... for x in dist_dataset: ... # process dataset elements ... result.append(strategy.run(replica_fn, args=(x,))) >>> print(result) [PerReplica:{ 0: , 1: }, PerReplica:{ 0: , 1: }] Three key actions happening under the hood of this method are batching, sharding, and prefetching. In the code snippet above, `dataset` is batched by `global_batch_size`, and calling `experimental_distribute_dataset` on it rebatches `dataset` to a new batch size that is equal to the global batch size divided by the number of replicas in sync. We iterate through it using a Pythonic for loop. `x` is a `tf.distribute.DistributedValues` containing data for all replicas, and each replica gets data of the new batch size. `tf.distribute.Strategy.run` will take care of feeding the right per-replica data in `x` to the right `replica_fn` executed on each replica. Sharding covers both autosharding across multiple workers and sharding within each worker. First, in multi-worker distributed training (i.e. when you use `tf.distribute.experimental.MultiWorkerMirroredStrategy` or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to ensure that at each step, a global batch size of non-overlapping dataset elements will be processed by each worker. Autosharding has a couple of different options that can be specified using `tf.data.experimental.DistributeOptions`.
Then, sharding within each worker means the method will split the data among all the worker devices (if more than one is present). This will happen regardless of multi-worker autosharding. Note: for autosharding across multiple workers, the default mode is `tf.data.experimental.AutoShardPolicy.AUTO`. This mode will attempt to shard the input dataset by files if the dataset is being created out of reader datasets (e.g. `tf.data.TFRecordDataset`, `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data, where each of the workers will read the entire dataset and only process the shard assigned to it. However, if you have fewer than one input file per worker, we suggest that you disable dataset autosharding across workers by setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be `tf.data.experimental.AutoShardPolicy.OFF`. By default, this method adds a prefetch transformation at the end of the user provided `tf.data.Dataset` instance. The `buffer_size` argument of the prefetch transformation is equal to the number of replicas in sync. If the above batch splitting and dataset sharding logic is undesirable, please use `tf.distribute.Strategy.distribute_datasets_from_function` instead, which does not do any automatic batching or sharding for you. Note: If you are using TPUStrategy, the order in which the data is processed by the workers when using `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed. A guaranteed order is typically required if you are using `tf.distribute` to scale prediction. You can however insert an index for each element in the batch and order outputs accordingly. Refer to [this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) for an example of how to order outputs.
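The index-and-reorder idea from the note above can be sketched in plain Python, without TensorFlow. The data and the arrival order below are made up for illustration, and this is not the linked tutorial's snippet: each element carries its original position, the outputs arrive in an arbitrary order, and sorting on the carried index recovers the input order.

```python
# Each input element is paired with its original position (made-up data).
inputs = [(0, 10), (1, 20), (2, 30), (3, 40)]

# Pretend the replicas produced their outputs in an arbitrary order.
arrival_order = [inputs[2], inputs[0], inputs[3], inputs[1]]

# The per-element computation keeps the index next to the result.
outputs = [(index, value * 2) for index, value in arrival_order]

# Sorting on the carried index restores the original input order.
ordered = [value for _, value in sorted(outputs)]
```

The same pattern applies to predictions: as long as the index travels with each element through the computation, the outputs can be reordered afterwards regardless of which replica processed which element.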
Note: Stateful dataset transformations are currently not supported with `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`. Any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then you have a dataset graph that depends on state (i.e. the random seed) on the local machine where the Python process is being executed. For more on the usage and properties of this method, refer to the [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). If you are interested in how the last partial batch is handled, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). Args: dataset: `tf.data.Dataset` that will be sharded across all replicas using the rules stated above. options: `tf.distribute.InputOptions` used to control options on how this dataset is distributed. Returns: A `tf.distribute.DistributedDataset`. """ # pylint: enable=line-too-long return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access def distribute_datasets_from_function(self, dataset_fn, options=None): # pylint: disable=line-too-long """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. The argument `dataset_fn` that users pass in is an input function that has a `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` instance. It is expected that the returned dataset from `dataset_fn` is already batched by per-replica batch size (i.e. global batch size divided by the number of replicas in sync) and sharded. `tf.distribute.Strategy.distribute_datasets_from_function` does not batch or shard the `tf.data.Dataset` instance returned from the input function.
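To make the batching and sharding expectation above concrete, here is a plain-Python sketch that uses lists instead of `tf.data.Dataset`. The helpers `per_replica_batch_size` and `toy_dataset_fn` are hypothetical and exist only for this illustration; the argument names merely echo fields of `tf.distribute.InputContext`, and the divisibility check is an assumption of the sketch.

```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    # Global batch size divided by the number of replicas in sync.
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError("global batch size must be divisible by replica count")
    return global_batch_size // num_replicas_in_sync


def toy_dataset_fn(elements, global_batch_size, num_replicas_in_sync,
                   num_input_pipelines, input_pipeline_id):
    batch_size = per_replica_batch_size(global_batch_size, num_replicas_in_sync)
    # Shard first: this input pipeline keeps every n-th element.
    shard = elements[input_pipeline_id::num_input_pipelines]
    # Then batch the shard by the per-replica batch size.
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


# Two workers with two replicas each: 4 replicas in sync, global batch of 8.
elements = list(range(16))
worker0 = toy_dataset_fn(elements, 8, 4, 2, 0)
worker1 = toy_dataset_fn(elements, 8, 4, 2, 1)
```

With these numbers each worker's list holds batches of two elements. A worker with two replicas takes two batches per step, so the two workers together consume eight non-overlapping elements per step, which is the global batch size.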
`dataset_fn` will be called on the CPU device of each of the workers and each generates a dataset where every replica on that worker will dequeue one batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued from the `Dataset` every step). This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, `tf.distribute.Strategy.experimental_distribute_dataset` does batching and sharding for you.) For example, where `experimental_distribute_dataset` is unable to shard the input files, this method might be used to manually shard the dataset (avoiding the slow fallback behavior in `experimental_distribute_dataset`). In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed. The `dataset_fn` should take a `tf.distribute.InputContext` instance where information about batching and input replication can be accessed. You can use the `element_spec` property of the `tf.distribute.DistributedDataset` returned by this API to query the `tf.TypeSpec` of the elements returned by the iterator. This can be used to set the `input_signature` property of a `tf.function`. See `tf.distribute.DistributedDataset.element_spec` for an example. IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a per-replica batch size, unlike `experimental_distribute_dataset`, which uses the global batch size. This may be computed using `input_context.get_per_replica_batch_size`. Note: If you are using TPUStrategy, the order in which the data is processed by the workers when using `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed. A guaranteed order is typically required if you are using `tf.distribute` to scale prediction. You can however insert an index for each element in the batch and order outputs accordingly.
Refer to [this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) for an example of how to order outputs. Note: Stateful dataset transformations are currently not supported with `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`. Any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then you have a dataset graph that depends on state (i.e. the random seed) on the local machine where the Python process is being executed. For more on the usage and properties of this method, refer to the [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function). If you are interested in how the last partial batch is handled, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). Args: dataset_fn: A function taking a `tf.distribute.InputContext` instance and returning a `tf.data.Dataset`. options: `tf.distribute.InputOptions` used to control options on how this dataset is distributed. Returns: A `tf.distribute.DistributedDataset`. """ # pylint: enable=line-too-long return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access dataset_fn, options) # TODO(b/162776748): Remove deprecated symbol. @doc_controls.do_not_doc_inheritable @deprecation.deprecated(None, "rename to distribute_datasets_from_function") def experimental_distribute_datasets_from_function(self, dataset_fn, options=None): return self.distribute_datasets_from_function(dataset_fn, options) def run(self, fn, args=(), kwargs=None, options=None): """Invokes `fn` on each replica, with the given arguments. This method is the primary way to distribute your computation with a tf.distribute object. It invokes `fn` on each replica.
If `args` or `kwargs` have `tf.distribute.DistributedValues`, such as those produced by a `tf.distribute.DistributedDataset` from `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`, when `fn` is executed on a particular replica, it will be executed with the component of `tf.distribute.DistributedValues` that correspond to that replica. `fn` is invoked under a replica context. `fn` may call `tf.distribute.get_replica_context()` to access members such as `all_reduce`. Please see the module-level docstring of tf.distribute for the concept of replica context. All arguments in `args` or `kwargs` should either be Python values of a nested structure of tensors, e.g. a list of tensors, in which case `args` and `kwargs` will be passed to the `fn` invoked on each replica. Or `args` or `kwargs` can be `tf.distribute.DistributedValues` containing tensors or composite tensors, i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call will get the component of a `tf.distribute.DistributedValues` corresponding to its replica. IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and whether eager execution is enabled, `fn` may be called one or more times. If `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is called inside a `tf.function` (eager execution is disabled inside a `tf.function` by default), `fn` is called once per replica to generate a Tensorflow graph, which will then be reused for execution with new inputs. Otherwise, if eager execution is enabled, `fn` will be called once per replica every step just like regular python code. Example usage: 1. Constant tensor input. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> tensor_input = tf.constant(3.0) >>> @tf.function ... def replica_fn(input): ... return input*2.0 >>> result = strategy.run(replica_fn, args=(tensor_input,)) >>> result PerReplica:{ 0: , 1: } 2. 
DistributedValues input. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> @tf.function ... def run(): ... def value_fn(value_context): ... return value_context.num_replicas_in_sync ... distributed_values = ( ... strategy.experimental_distribute_values_from_function( ... value_fn)) ... def replica_fn2(input): ... return input*2 ... return strategy.run(replica_fn2, args=(distributed_values,)) >>> result = run() >>> result 3. Use `tf.distribute.ReplicaContext` to allreduce values. >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"]) >>> @tf.function ... def run(): ... def value_fn(value_context): ... return tf.constant(value_context.replica_id_in_sync_group) ... distributed_values = ( ... strategy.experimental_distribute_values_from_function( ... value_fn)) ... def replica_fn(input): ... return tf.distribute.get_replica_context().all_reduce("sum", input) ... return strategy.run(replica_fn, args=(distributed_values,)) >>> result = run() >>> result PerReplica:{ 0: , 1: } Args: fn: The function to run on each replica. args: Optional positional arguments to `fn`. Its element can be a Python value, a tensor or a `tf.distribute.DistributedValues`. kwargs: Optional keyword arguments to `fn`. Its element can be a Python value, a tensor or a `tf.distribute.DistributedValues`. options: An optional instance of `tf.distribute.RunOptions` specifying the options to run `fn`. Returns: Merged return value of `fn` across replicas. The structure of the return value is the same as the return value from `fn`. Each element in the structure can either be `tf.distribute.DistributedValues`, `Tensor` objects, or `Tensor`s (for example, if running on a single replica). """ del options if not isinstance(args, (list, tuple)): raise ValueError( "positional args must be a list or tuple, got {}".format(type(args))) with self.scope(): # tf.distribute supports Eager functions, so AutoGraph should not be # applied when when the caller is also in Eager mode. 
fn = autograph.tf_convert( fn, autograph_ctx.control_status_ctx(), convert_by_default=False) return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) def reduce(self, reduce_op, value, axis): """Reduce `value` across replicas and return result on current device. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def step_fn(): ... i = tf.distribute.get_replica_context().replica_id_in_sync_group ... return tf.identity(i) >>> >>> per_replica_result = strategy.run(step_fn) >>> total = strategy.reduce("SUM", per_replica_result, axis=None) >>> total To see how this would look with multiple replicas, consider the same example with MirroredStrategy with 2 GPUs: ```python strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) def step_fn(): i = tf.distribute.get_replica_context().replica_id_in_sync_group return tf.identity(i) per_replica_result = strategy.run(step_fn) # Check devices on which per replica result is: strategy.experimental_local_results(per_replica_result)[0].device # /job:localhost/replica:0/task:0/device:GPU:0 strategy.experimental_local_results(per_replica_result)[1].device # /job:localhost/replica:0/task:0/device:GPU:1 total = strategy.reduce("SUM", per_replica_result, axis=None) # Check device on which reduced result is: total.device # /job:localhost/replica:0/task:0/device:CPU:0 ``` This API is typically used for aggregating the results returned from different replicas, for reporting etc. For example, loss computed from different replicas can be averaged using this API before printing. Note: The result is copied to the "current" device - which would typically be the CPU of the worker on which the program is running. For `TPUStrategy`, it is the first TPU host. For multi client `MultiWorkerMirroredStrategy`, this is CPU of each worker. 
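The loss-averaging use case mentioned above can be pictured with a plain-Python sketch. The function `toy_reduce` is a hypothetical stand-in that only mimics combining one scalar per replica; it does not model devices, tensors or the real implementation, and the loss values are made up.

```python
def toy_reduce(reduce_op, per_replica_values):
    # Accept the op as a string in either case, e.g. "sum" or "SUM".
    op = reduce_op.upper()
    total = sum(per_replica_values)
    if op == "SUM":
        return total
    if op == "MEAN":
        return total / len(per_replica_values)
    raise ValueError("unsupported reduce_op: %r" % reduce_op)


# One scalar loss per replica (made-up values for two replicas).
per_replica_loss = [0.5, 1.5]
mean_loss = toy_reduce("mean", per_replica_loss)
total_loss = toy_reduce("SUM", per_replica_loss)
```

Here `mean_loss` is `1.0` and `total_loss` is `2.0`: the per-replica values are collapsed into a single ordinary number that can be printed or logged.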
There are a number of different tf.distribute APIs for reducing values across replicas: * `tf.distribute.ReplicaContext.all_reduce`: This differs from `Strategy.reduce` in that it is for replica context and does not copy the results to the host device. `all_reduce` should be typically used for reductions inside the training step such as gradients. * `tf.distribute.StrategyExtended.reduce_to` and `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more advanced versions of `Strategy.reduce` as they allow customizing the destination of the result. They are also called in cross replica context. _What should axis be?_ Given a per-replica value returned by `run`, say a per-example loss, the batch will be divided across all the replicas. This function allows you to aggregate across replicas and optionally also across batch elements by specifying the axis parameter accordingly. For example, if you have a global batch size of 8 and 2 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful when each replica is computing a scalar or some other value that doesn't have a "batch" dimension (like a gradient or loss). ``` strategy.reduce("sum", per_replica_result, axis=None) ``` Sometimes, you will want to aggregate across both the global batch _and_ all replicas. You can get this behavior by specifying the batch dimension as the `axis`, typically `axis=0`. In this case it would return a scalar `0+1+2+3+4+5+6+7`. ``` strategy.reduce("sum", per_replica_result, axis=0) ``` If there is a last partial batch, you will need to specify an axis so that the resulting shape is consistent across replicas. So if the last batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you would get a shape mismatch unless you specify `axis=0`. 
If you specify `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct denominator of 6. Contrast this with computing `reduce_mean` to get a scalar value on each replica and this function to average those means, which will weigh some values `1/8` and others `1/4`. Args: reduce_op: a `tf.distribute.ReduceOp` value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". value: a `tf.distribute.DistributedValues` instance, e.g. returned by `Strategy.run`, to be combined into a single tensor. It can also be a regular tensor when used with `OneDeviceStrategy` or default strategy. axis: specifies the dimension to reduce along within each replica's tensor. Should typically be set to the batch dimension, or `None` to only reduce across replicas (e.g. if the tensor has no batch dimension). Returns: A `Tensor`. """ # TODO(josh11b): support `value` being a nest. _require_cross_replica_or_default_context_extended(self._extended) if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) if axis is None: return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access if reduce_op == reduce_util.ReduceOp.SUM: def reduce_sum(v): return math_ops.reduce_sum(v, axis=axis) if eager_context.executing_eagerly(): # As some strategies (e.g. TPUStrategy) doesn't support pure eager # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be # run from eager mode. Cache the tf.function by `axis` to avoid the # same function to be traced again. 
if axis not in self._reduce_sum_fns: def reduce_sum_fn(v): return self.run(reduce_sum, args=(v,)) self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) value = self._reduce_sum_fns[axis](value) else: value = self.run(reduce_sum, args=(value,)) return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access if reduce_op != reduce_util.ReduceOp.MEAN: raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " "not: %r" % reduce_op) # TODO(josh11b): Support list/tuple and tensor axis values. if not isinstance(axis, six.integer_types): raise TypeError("Expected `axis` to be an integer not: %r" % axis) def mean_reduce_helper(v, axis=axis): """Computes the numerator and denominator on each replica.""" numer = math_ops.reduce_sum(v, axis=axis) if v.shape.rank is not None: # Note(joshl): We support axis < 0 to be consistent with the # tf.math.reduce_* operations. if axis < 0: if axis + v.shape.rank < 0: raise ValueError( "`axis` = %r out of range for `value` with rank %d" % (axis, v.shape.rank)) axis += v.shape.rank elif axis >= v.shape.rank: raise ValueError( "`axis` = %r out of range for `value` with rank %d" % (axis, v.shape.rank)) # TF v2 returns `None` for unknown dimensions and an integer for # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) # or tensor_shape.Dimension(integer). `dimension_value` hides this # difference, always returning `None` or an integer. dim = tensor_shape.dimension_value(v.shape[axis]) if dim is not None: # By returning a python value in the static shape case, we can # maybe get a fast path for reducing the denominator. # TODO(b/151871486): Remove array_ops.identity after we fallback to # simple reduction if inputs are all on CPU. return numer, array_ops.identity( constant_op.constant(dim, dtype=dtypes.int64)) elif axis < 0: axis = axis + array_ops.rank(v) # TODO(b/151871486): Remove array_ops.identity after we fallback to simple # reduction if inputs are all on CPU. 
denom = array_ops.identity( array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) # TODO(josh11b): Should we cast denom to v.dtype here instead of after the # reduce is complete? return numer, denom if eager_context.executing_eagerly(): # As some strategies (e.g. TPUStrategy) doesn't support pure eager # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can # be run from eager mode. Cache the tf.function by `axis` to avoid the # same function to be traced again. if axis not in self._mean_reduce_helper_fns: def mean_reduce_fn(v): return self.run(mean_reduce_helper, args=(v,)) self._mean_reduce_helper_fns[axis] = def_function.function( mean_reduce_fn) numer, denom = self._mean_reduce_helper_fns[axis](value) else: numer, denom = self.run(mean_reduce_helper, args=(value,)) # TODO(josh11b): Should batch reduce here instead of doing two. numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access denom = math_ops.cast(denom, numer.dtype) return math_ops.truediv(numer, denom) @doc_controls.do_not_doc_inheritable # DEPRECATED def unwrap(self, value): """Returns the list of all local per-replica values contained in `value`. DEPRECATED: Please use `experimental_local_results` instead. Note: This only returns values on the workers initiated by this client. When using a `tf.distribute.Strategy` like `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker will be its own client, and this function will only return values computed on that worker. Args: value: A value returned by `experimental_run()`, `extended.call_for_each_replica()`, or a variable created in `scope`. Returns: A tuple of values contained in `value`. 
If `value` represents a single value, this returns `(value,).` """ return self._extended._local_results(value) # pylint: disable=protected-access def experimental_local_results(self, value): """Returns the list of all local per-replica values contained in `value`. Note: This only returns values on the worker initiated by this client. When using a `tf.distribute.Strategy` like `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker will be its own client, and this function will only return values computed on that worker. Args: value: A value returned by `experimental_run()`, `run()`, `extended.call_for_each_replica()`, or a variable created in `scope`. Returns: A tuple of values contained in `value`. If `value` represents a single value, this returns `(value,).` """ return self._extended._local_results(value) # pylint: disable=protected-access @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only def group(self, value, name=None): """Shortcut for `tf.group(self.experimental_local_results(value))`.""" return self._extended._group(value, name) # pylint: disable=protected-access @property def num_replicas_in_sync(self): """Returns number of replicas over which gradients are aggregated.""" return self._extended._num_replicas_in_sync # pylint: disable=protected-access @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string def configure(self, session_config=None, cluster_spec=None, task_type=None, task_id=None): # pylint: disable=g-doc-return-or-yield,g-doc-args """DEPRECATED: use `update_config_proto` instead. Configures the strategy class. DEPRECATED: This method's functionality has been split into the strategy constructor and `update_config_proto`. In the future, we will allow passing cluster and config_proto to the constructor to configure the strategy. And `update_config_proto` can be used to update the config_proto based on the specific strategy. 
""" return self._extended._configure( # pylint: disable=protected-access session_config, cluster_spec, task_type, task_id) @doc_controls.do_not_generate_docs # DEPRECATED def update_config_proto(self, config_proto): """DEPRECATED TF 1.x ONLY.""" return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access def __deepcopy__(self, memo): # First do a regular deepcopy of `self`. cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): setattr(result, k, copy.deepcopy(v, memo)) # One little fix-up: we want `result._extended` to reference `result` # instead of `self`. result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access return result def __copy__(self): raise RuntimeError("Must only deepcopy DistributionStrategy.") @property def cluster_resolver(self): """Returns the cluster resolver associated with this strategy. In general, when using a multi-worker `tf.distribute` strategy such as `tf.distribute.experimental.MultiWorkerMirroredStrategy` or `tf.distribute.TPUStrategy()`, there is a `tf.distribute.cluster_resolver.ClusterResolver` associated with the strategy used, and such an instance is returned by this property. Strategies that intend to have an associated `tf.distribute.cluster_resolver.ClusterResolver` must set the relevant attribute, or override this property; otherwise, `None` is returned by default. Those strategies should also provide information regarding what is returned by this property. Single-worker strategies usually do not have a `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this property will return `None`. The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the user needs to access information such as the cluster spec, task type or task id. 
    For example,

    ```python
    import json
    import os

    import tensorflow as tf

    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': ["localhost:12345", "localhost:23456"],
            'ps': ["localhost:34567"]
        },
        'task': {'type': 'worker', 'index': 0}
    })

    # This implicitly uses TF_CONFIG for the cluster and current task info.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    ...

    if strategy.cluster_resolver.task_type == 'worker':
      # Perform something that's only applicable on workers. Since we set this
      # as a worker above, this block will run on this particular instance.
      pass
    elif strategy.cluster_resolver.task_type == 'ps':
      # Perform something that's only applicable on parameter servers. Since we
      # set this as a worker above, this block will not run on this particular
      # instance.
      pass
    ```

    For more information, please see
    `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.

    Returns:
      The cluster resolver associated with this strategy. Returns `None` if a
      cluster resolver is not applicable or available in this strategy.
    """
    if hasattr(self.extended, "_cluster_resolver"):
      return self.extended._cluster_resolver  # pylint: disable=protected-access
    return None


@tf_export("distribute.Strategy", v1=[])  # pylint: disable=g-missing-docstring
class Strategy(StrategyBase):

  __doc__ = StrategyBase.__doc__

  def experimental_distribute_values_from_function(self, value_fn):
    """Generates `tf.distribute.DistributedValues` from `value_fn`.

    This function generates `tf.distribute.DistributedValues` to pass into
    `run`, `reduce`, or other methods that take distributed values when not
    using datasets.

    Args:
      value_fn: The function to run to generate values. It is called for
        each replica with `tf.distribute.ValueContext` as the sole argument. It
        must return a Tensor or a type that can be converted to a Tensor.

    Returns:
      A `tf.distribute.DistributedValues` containing a value for each replica.

    Example usage:

    1.
        Return constant value per replica:

        >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        >>> def value_fn(ctx):
        ...   return tf.constant(1.)
        >>> distributed_values = (
        ...     strategy.experimental_distribute_values_from_function(
        ...         value_fn))
        >>> local_result = strategy.experimental_local_results(distributed_values)
        >>> local_result
        (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
        <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

    2.  Distribute values in array based on replica_id:

        >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        >>> array_value = np.array([3., 2., 1.])
        >>> def value_fn(ctx):
        ...   return array_value[ctx.replica_id_in_sync_group]
        >>> distributed_values = (
        ...     strategy.experimental_distribute_values_from_function(
        ...         value_fn))
        >>> local_result = strategy.experimental_local_results(distributed_values)
        >>> local_result
        (3.0, 2.0)

    3.  Specify values using num_replicas_in_sync:

        >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        >>> def value_fn(ctx):
        ...   return ctx.num_replicas_in_sync
        >>> distributed_values = (
        ...     strategy.experimental_distribute_values_from_function(
        ...         value_fn))
        >>> local_result = strategy.experimental_local_results(distributed_values)
        >>> local_result
        (2, 2)

    4.  Place values on devices and distribute:

        ```
        strategy = tf.distribute.TPUStrategy()
        worker_devices = strategy.extended.worker_devices
        multiple_values = []
        for i in range(strategy.num_replicas_in_sync):
          with tf.device(worker_devices[i]):
            multiple_values.append(tf.constant(1.0))

        def value_fn(ctx):
          return multiple_values[ctx.replica_id_in_sync_group]

        # Parentheses keep the wrapped call a single valid expression.
        distributed_values = (
            strategy.experimental_distribute_values_from_function(value_fn))
        ```

    """
    return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
        value_fn)

  def gather(self, value, axis):
    # pylint: disable=line-too-long, protected-access
    """Gather `value` across replicas along `axis` to the current device.
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like object `value`, this API gathers and concatenates `value` across replicas along the `axis`-th dimension. The result is copied to the "current" device - which would typically be the CPU of the worker on which the program is running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker. This API can only be called in the cross-replica context. For a counterpart in the replica context, see `tf.distribute.ReplicaContext.all_gather`. Note: For all strategies except `tf.distribute.TPUStrategy`, the input `value` on different replicas must have the same rank, and their shapes must be the same in all dimensions except the `axis`-th dimension. In other words, their shapes cannot be different in a dimension `d` where `d` does not equal to the `axis` argument. For example, given a `tf.distribute.DistributedValues` with component tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, all tensors must have exactly the same rank and same shape. Note: Given a `tf.distribute.DistributedValues` `value`, its component tensors must have a non-zero rank. Otherwise, consider using `tf.expand_dims` before gathering them. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> # A DistributedValues with component tensor of shape (2, 1) on each replica ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) >>> @tf.function ... def run(): ... 
return strategy.gather(distributed_values, axis=0) >>> run() Consider the following example for more combinations: >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) >>> @tf.function ... def run(axis): ... return strategy.gather(distributed_values, axis=axis) >>> axis=0 >>> run(axis) >>> axis=1 >>> run(axis) >>> axis=2 >>> run(axis) Args: value: a `tf.distribute.DistributedValues` instance, e.g. returned by `Strategy.run`, to be combined into a single tensor. It can also be a regular tensor when used with `tf.distribute.OneDeviceStrategy` or the default strategy. The tensors that constitute the DistributedValues can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the range [0, rank(value)). Returns: A `Tensor` that's the concatenation of `value` across replicas along `axis` dimension. """ # pylint: enable=line-too-long error_message = ("tf.distribute.Strategy.gather method requires " "cross-replica context, use " "get_replica_context().all_gather() instead") _require_cross_replica_or_default_context_extended(self._extended, error_message) dst = device_util.current( ) or self._extended._default_device or "/device:CPU:0" if isinstance(value, ops.IndexedSlices): raise NotImplementedError("gather does not support IndexedSlices") return self._extended._local_results( self._extended._gather_to(value, dst, axis))[0] # TF v1.x version has additional deprecated APIs @tf_export(v1=["distribute.Strategy"]) class StrategyV1(StrategyBase): """A list of devices with a state & compute distribution policy. See [the guide](https://www.tensorflow.org/guide/distribute_strategy) for overview and examples. 
Note: Not all `tf.distribute.Strategy` implementations currently support TensorFlow's partitioned variables (where a single variable is split across multiple devices) at this time. """ def make_dataset_iterator(self, dataset): """Makes an iterator for input provided via `dataset`. DEPRECATED: This method is not available in TF 2.x. Data from the given dataset will be distributed evenly across all the compute replicas. We will assume that the input dataset is batched by the global batch size. With this assumption, we will make a best effort to divide each batch across all the replicas (one or more workers). If this effort fails, an error will be thrown, and the user should instead use `make_input_fn_iterator` which provides more control to the user, and does not try to divide a batch across replicas. The user could also use `make_input_fn_iterator` if they want to customize which input is fed to which replica/worker etc. Args: dataset: `tf.data.Dataset` that will be distributed evenly across all replicas. Returns: An `tf.distribute.InputIterator` which returns inputs for each step of the computation. User should call `initialize` on the returned iterator. """ return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation input_fn, replication_mode=InputReplicationMode.PER_WORKER): """Returns an iterator split across replicas created from an input function. DEPRECATED: This method is not available in TF 2.x. 
The `input_fn` should take an `tf.distribute.InputContext` object where information about batching and input sharding can be accessed: ``` def input_fn(input_context): batch_size = input_context.get_per_replica_batch_size(global_batch_size) d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) return d.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) with strategy.scope(): iterator = strategy.make_input_fn_iterator(input_fn) replica_results = strategy.experimental_run(replica_fn, iterator) ``` The `tf.data.Dataset` returned by `input_fn` should have a per-replica batch size, which may be computed using `input_context.get_per_replica_batch_size`. Args: input_fn: A function taking a `tf.distribute.InputContext` object and returning a `tf.data.Dataset`. replication_mode: an enum value of `tf.distribute.InputReplicationMode`. Only `PER_WORKER` is supported currently, which means there will be a single call to `input_fn` per worker. Replicas will dequeue from the local `tf.data.Dataset` on their worker. Returns: An iterator object that should first be `.initialize()`-ed. It may then either be passed to `strategy.experimental_run()` or you can `iterator.get_next()` to get the next value to pass to `strategy.extended.call_for_each_replica()`. """ return super(StrategyV1, self).make_input_fn_iterator( input_fn, replication_mode) def experimental_make_numpy_dataset(self, numpy_input, session=None): """Makes a tf.data.Dataset for input provided via a numpy array. This avoids adding `numpy_input` as a large constant in the graph, and copies the data to the machine or machines that will be processing the input. Note that you will likely need to use tf.distribute.Strategy.experimental_distribute_dataset with the returned dataset to further distribute it with the strategy. 
Example: ``` numpy_input = np.ones([10], dtype=np.float32) dataset = strategy.experimental_make_numpy_dataset(numpy_input) dist_dataset = strategy.experimental_distribute_dataset(dataset) ``` Args: numpy_input: A nest of NumPy input arrays that will be converted into a dataset. Note that lists of Numpy arrays are stacked, as that is normal `tf.data.Dataset` behavior. session: (TensorFlow v1.x graph execution only) A session used for initialization. Returns: A `tf.data.Dataset` representing `numpy_input`. """ return self.extended.experimental_make_numpy_dataset( numpy_input, session=session) def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation """Runs ops in `fn` on each replica, with inputs from `input_iterator`. DEPRECATED: This method is not available in TF 2.x. Please switch to using `run` instead. When eager execution is enabled, executes ops specified by `fn` on each replica. Otherwise, builds a graph to execute the ops on each replica. Each replica will take a single, different input from the inputs provided by one `get_next` call on the input iterator. `fn` may call `tf.distribute.get_replica_context()` to access members such as `replica_id_in_sync_group`. IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being used, and whether eager execution is enabled, `fn` may be called one or more times (once for each replica). Args: fn: The function to run. The inputs to the function must match the outputs of `input_iterator.get_next()`. The output must be a `tf.nest` of `Tensor`s. input_iterator: (Optional) input iterator from which the inputs are taken. Returns: Merged return value of `fn` across replicas. The structure of the return value is the same as the return value from `fn`. Each element in the structure can either be `PerReplica` (if the values are unsynchronized), `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a single replica). 
""" return super(StrategyV1, self).experimental_run( fn, input_iterator) def reduce(self, reduce_op, value, axis=None): return super(StrategyV1, self).reduce(reduce_op, value, axis) reduce.__doc__ = StrategyBase.reduce.__doc__ def update_config_proto(self, config_proto): """Returns a copy of `config_proto` modified for use with this strategy. DEPRECATED: This method is not available in TF 2.x. The updated config has something needed to run a strategy, e.g. configuration to run collective ops, or device filters to improve distributed training performance. Args: config_proto: a `tf.ConfigProto` object. Returns: The updated copy of the `config_proto`. """ return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access # NOTE(josh11b): For any strategy that needs to support tf.compat.v1, # instead descend from StrategyExtendedV1. @tf_export("distribute.StrategyExtended", v1=[]) class StrategyExtendedV2(object): """Additional APIs for algorithms that need to be distribution-aware. Note: For most usage of `tf.distribute.Strategy`, there should be no need to call these methods, since TensorFlow libraries (such as optimizers) already call these methods when needed on your behalf. Some common use cases of functions on this page: * _Locality_ `tf.distribute.DistributedValues` can have the same _locality_ as a _distributed variable_, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to `tf.distribute.StrategyExtended.update` to update the value of a variable. You may use `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or `tf.distribute.StrategyExtended.batch_reduce_to`. * _How to update a distributed variable_ A distributed variable is variables created on multiple devices. 
  As discussed in the
  [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
  mirrored variables and SyncOnRead variables are two examples. The standard
  pattern for updating distributed variables is to:

  1. In your function passed to `tf.distribute.Strategy.run`, compute a list of
     (update, variable) pairs. For example, the update might be a gradient of
     the loss with respect to the variable.
  2. Switch to cross-replica mode by calling
     `tf.distribute.get_replica_context().merge_call()` with the updates and
     variables as arguments.
  3. Call
     `tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v)`
     (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
     (for a list of variables) to sum the updates. Note that `reduce_to`
     expects a `tf.distribute.ReduceOp`, not a `tf.VariableAggregation`.
  4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
     its value.

  Steps 2 through 4 are done automatically by class
  `tf.keras.optimizers.Optimizer` if you call its
  `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.

  In fact, a higher-level way to update a distributed variable is to call
  `assign` on the variable as you would on a regular `tf.Variable`.
  You can call the method in both _replica context_ and _cross-replica context_.
  For a _mirrored variable_, calling `assign` in _replica context_ requires you
  to specify the `aggregation` type in the variable constructor. In that case,
  the context switching and sync described in steps 2 through 4 are handled for
  you. If you call `assign` on a _mirrored variable_ in _cross-replica
  context_, you can only assign a single value or assign values from another
  mirrored variable or a mirrored `tf.distribute.DistributedValues`. For a
  _SyncOnRead variable_, in _replica context_, you can simply call `assign` on
  it and no aggregation happens under the hood. In _cross-replica context_, you
  can only assign a single value to a SyncOnRead variable.
One example case is restoring from a checkpoint: if the `aggregation` type of the variable is `tf.VariableAggregation.SUM`, it is assumed that replica values were added before checkpointing, so at the time of restoring, the value is divided by the number of replicas and then assigned to each replica; if the `aggregation` type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica directly. """ def __init__(self, container_strategy): self._container_strategy_weakref = weakref.ref(container_strategy) self._default_device = None # This property is used to determine if we should set drop_remainder=True # when creating Datasets from numpy array inputs. self._require_static_shapes = False def _container_strategy(self): """Get the containing `tf.distribute.Strategy`. This should not generally be needed except when creating a new `ReplicaContext` and to validate that the caller is in the correct `scope()`. Returns: The `tf.distribute.Strategy` such that `strategy.extended` is `self`. """ container_strategy = self._container_strategy_weakref() assert container_strategy is not None return container_strategy def _scope(self, strategy): """Implementation of tf.distribute.Strategy.scope().""" def creator_with_resource_vars(next_creator, **kwargs): """Variable creator to use in `_CurrentDistributionContext`.""" _require_strategy_scope_extended(self) kwargs["use_resource"] = True kwargs["distribute_strategy"] = strategy # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid # dereferencing a `Tensor` that is without a `name`. We still need to # propagate the metadata it's holding. 
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): checkpoint_restore_uid = kwargs[ "initial_value"].checkpoint_position.restore_uid kwargs["initial_value"] = kwargs["initial_value"].wrapped_value elif isinstance(kwargs["initial_value"], trackable.CheckpointInitialValueCallable): checkpoint_restore_uid = kwargs[ "initial_value"].checkpoint_position.restore_uid else: checkpoint_restore_uid = None created = self._create_variable(next_creator, **kwargs) if checkpoint_restore_uid is not None: # pylint: disable=protected-access # Let the checkpointing infrastructure know that the variable was # already restored so it doesn't waste memory loading the value again. # In this case of CheckpointInitialValueCallable this may already be # done by the final variable creator, but it doesn't hurt to do it # again. created._maybe_initialize_trackable() created._update_uid = checkpoint_restore_uid # pylint: enable=protected-access return created def distributed_getter(getter, *args, **kwargs): if not self._allow_variable_partition(): if kwargs.pop("partitioner", None) is not None: tf_logging.log_first_n( tf_logging.WARN, "Partitioned variables are disabled when using " "current tf.distribute.Strategy.", 1) return getter(*args, **kwargs) return _CurrentDistributionContext( strategy, variable_scope.variable_creator_scope(creator_with_resource_vars), variable_scope.variable_scope( variable_scope.get_variable_scope(), custom_getter=distributed_getter), self._default_device) def _allow_variable_partition(self): return False def _create_variable(self, next_creator, **kwargs): # Note: should support "colocate_with" argument. raise NotImplementedError("must be implemented in descendants") def variable_created_in_scope(self, v): """Tests whether `v` was created while this strategy scope was active. Variables created inside the strategy scope are "owned" by it: >>> strategy = tf.distribute.MirroredStrategy() >>> with strategy.scope(): ... v = tf.Variable(1.) 
    >>> strategy.extended.variable_created_in_scope(v)
    True

    Variables created outside the strategy are not owned by it:

    >>> strategy = tf.distribute.MirroredStrategy()
    >>> v = tf.Variable(1.)
    >>> strategy.extended.variable_created_in_scope(v)
    False

    Args:
      v: A `tf.Variable` instance.

    Returns:
      True if `v` was created inside the scope, False if not.
    """
    return v._distribute_strategy == self._container_strategy_weakref()  # pylint: disable=protected-access

  def colocate_vars_with(self, colocate_with_variable):
    """Scope that controls which devices variables will be created on.

    No operations should be added to the graph inside this scope; it should
    only be used when creating variables (some implementations work by
    changing variable creation, others work by using a
    `tf.compat.v1.colocate_with()` scope).

    This may only be used inside `self.scope()`.

    Example usage:

    ```
    with strategy.scope():
      var1 = tf.Variable(...)
      with strategy.extended.colocate_vars_with(var1):
        # var2 and var3 will be created on the same device(s) as var1
        var2 = tf.Variable(...)
        var3 = tf.Variable(...)

      def fn(v1, v2, v3):
        # operates on v1 from var1, v2 from var2, and v3 from var3
        pass

      # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
      # too.
      strategy.extended.update(var1, fn, args=(var2, var3))
    ```

    Args:
      colocate_with_variable: A variable created in this strategy's `scope()`.
        Variables created while in the returned context manager will be on the
        same set of devices as `colocate_with_variable`.

    Returns:
      A context manager.
""" def create_colocated_variable(next_creator, **kwargs): _require_strategy_scope_extended(self) kwargs["use_resource"] = True kwargs["colocate_with"] = colocate_with_variable return next_creator(**kwargs) _require_strategy_scope_extended(self) self._validate_colocate_with_variable(colocate_with_variable) return variable_scope.variable_creator_scope(create_colocated_variable) def _validate_colocate_with_variable(self, colocate_with_variable): """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" pass def _make_dataset_iterator(self, dataset): raise NotImplementedError("must be implemented in descendants") def _make_input_fn_iterator(self, input_fn, replication_mode): raise NotImplementedError("must be implemented in descendants") def _experimental_distribute_dataset(self, dataset, options): raise NotImplementedError("must be implemented in descendants") def _distribute_datasets_from_function(self, dataset_fn, options): raise NotImplementedError("must be implemented in descendants") def _experimental_distribute_values_from_function(self, value_fn): raise NotImplementedError("must be implemented in descendants") def _reduce(self, reduce_op, value): # Default implementation until we have an implementation for each strategy. dst = device_util.current() or self._default_device or "/device:CPU:0" return self._local_results(self.reduce_to(reduce_op, value, dst))[0] def reduce_to(self, reduce_op, value, destinations, options=None): """Combine (via e.g. sum or mean) values across replicas. `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed variables. It supports both dense values and `tf.IndexedSlices`. This API currently can only be called in cross-replica context. Other variants to reduce values across replicas are: * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of this API. * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API in replica context. 
It supports both batched and non-batched all-reduce. * `tf.distribute.Strategy.reduce`: a more convenient method to reduce to the host in cross-replica context. `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can also pass in a `Tensor`, and the destinations will be the device of that tensor. For all-reduce, pass the same to `value` and `destinations`. It can be used in `tf.distribute.ReplicaContext.merge_call` to write code that works for all `tf.distribute.Strategy`. >>> @tf.function ... def step_fn(var): ... ... def merge_fn(strategy, value, var): ... # All-reduce the value. Note that `value` here is a ... # `tf.distribute.DistributedValues`. ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, ... value, destinations=var) ... strategy.extended.update(var, lambda var, value: var.assign(value), ... args=(reduced,)) ... ... value = tf.identity(1.) ... tf.distribute.get_replica_context().merge_call(merge_fn, ... args=(value, var)) >>> >>> def run(strategy): ... with strategy.scope(): ... v = tf.Variable(0.) ... strategy.run(step_fn, args=(v,)) ... return v >>> >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) MirroredVariable:{ 0: , 1: } >>> run(tf.distribute.experimental.CentralStorageStrategy( ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) Args: reduce_op: a `tf.distribute.ReduceOp` value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a `tf.Tensor` alike object, or a device string. It specifies the devices to reduce to. To perform an all-reduce, pass the same to `value` and `destinations`. Note that if it's a `tf.Variable`, the value is reduced to the devices of that variable, and this method doesn't update the variable. 
options: a `tf.distribute.experimental.CommunicationOptions`. Options to perform collective operations. This overrides the default options if the `tf.distribute.Strategy` takes one in the constructor. See `tf.distribute.experimental.CommunicationOptions` for details of the options. Returns: A tensor or value reduced to `destinations`. """ if options is None: options = collective_util.Options() _require_cross_replica_or_default_context_extended(self) assert not isinstance(destinations, (list, tuple)) assert not isinstance(reduce_op, variable_scope.VariableAggregation) if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) assert (reduce_op == reduce_util.ReduceOp.SUM or reduce_op == reduce_util.ReduceOp.MEAN) return self._reduce_to(reduce_op, value, destinations, options) def _reduce_to(self, reduce_op, value, destinations, options): raise NotImplementedError("must be implemented in descendants") def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): """Combine multiple `reduce_to` calls into one for faster execution. Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. It's more efficient than reduce each value separately. This API currently can only be called in cross-replica context. Other variants to reduce values across replicas are: * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of this API. * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API in replica context. It supports both batched and non-batched all-reduce. * `tf.distribute.Strategy.reduce`: a more convenient method to reduce to the host in cross-replica context. See `reduce_to` for more information. >>> @tf.function ... def step_fn(var): ... ... def merge_fn(strategy, value, var): ... # All-reduce the value. Note that `value` here is a ... # `tf.distribute.DistributedValues`. ... reduced = strategy.extended.batch_reduce_to( ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] ... 
strategy.extended.update(var, lambda var, value: var.assign(value), ... args=(reduced,)) ... ... value = tf.identity(1.) ... tf.distribute.get_replica_context().merge_call(merge_fn, ... args=(value, var)) >>> >>> def run(strategy): ... with strategy.scope(): ... v = tf.Variable(0.) ... strategy.run(step_fn, args=(v,)) ... return v >>> >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) MirroredVariable:{ 0: , 1: } >>> run(tf.distribute.experimental.CentralStorageStrategy( ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) Args: reduce_op: a `tf.distribute.ReduceOp` value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". value_destination_pairs: a sequence of (value, destinations) pairs. See `tf.distribute.Strategy.reduce_to` for descriptions. options: a `tf.distribute.experimental.CommunicationOptions`. Options to perform collective operations. This overrides the default options if the `tf.distribute.Strategy` takes one in the constructor. See `tf.distribute.experimental.CommunicationOptions` for details of the options. Returns: A list of reduced values, one per pair in `value_destination_pairs`. """ if options is None: options = collective_util.Options() _require_cross_replica_or_default_context_extended(self) assert not isinstance(reduce_op, variable_scope.VariableAggregation) if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) return self._batch_reduce_to(reduce_op, value_destination_pairs, options) def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): return [ self.reduce_to(reduce_op, t, destinations=v, options=options) for t, v in value_destination_pairs ] def _gather_to(self, value, destinations, axis, options=None): """Gather `value` across replicas along axis-th dimension to `destinations`. 
    `_gather_to` gathers a `tf.distribute.DistributedValues` or `tf.Tensor`-like
    object along the `axis`-th dimension. It supports only dense tensors, NOT
    sparse tensors. This API can only be called in cross-replica context.

    Args:
      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor`-like object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      options: a `tf.distribute.experimental.CommunicationOptions`. Options to
        perform collective operations. This overrides the default options if the
        `tf.distribute.Strategy` takes one in the constructor. See
        `tf.distribute.experimental.CommunicationOptions` for details of the
        options.

    Returns:
      A tensor or value gathered to `destinations`.
    """
    _require_cross_replica_or_default_context_extended(self)
    assert not isinstance(destinations, (list, tuple))
    if options is None:
      options = collective_util.Options()
    return self._gather_to_implementation(value, destinations, axis, options)

  def _gather_to_implementation(self, value, destinations, axis, options):
    raise NotImplementedError("_gather_to must be implemented in descendants")

  def _batch_gather_to(self, value_destination_pairs, axis, options=None):
    _require_cross_replica_or_default_context_extended(self)
    if options is None:
      options = collective_util.Options()
    return [
        self._gather_to(t, destinations=v, axis=axis, options=options)
        for t, v in value_destination_pairs
    ]

  def update(self, var, fn, args=(), kwargs=None, group=True):
    """Run `fn` to update `var` using inputs mirrored to the same devices.
    `tf.distribute.StrategyExtended.update` takes a distributed variable `var` to
    be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
    applies `fn` to each component variable of `var` and passes corresponding
    values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
    per-replica values. If they contain mirrored values, they will be unwrapped
    before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
    a mirrored DistributedValues where each component contains the value to be
    added to this mirrored variable `var`. Calling `update` will call
    `assign_add` on each component variable of `var` with the corresponding
    tensor value on that device.

    Example usage:

    ```python
    # With 2 devices
    strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
    with strategy.scope():
      v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)

    def update_fn(v):
      return v.assign(1.0)

    result = strategy.extended.update(v, update_fn)
    # result is
    # Mirrored:{
    #  0: tf.Tensor(1.0, shape=(), dtype=float32),
    #  1: tf.Tensor(1.0, shape=(), dtype=float32)
    # }
    ```

    If `var` is mirrored across multiple devices, then this method implements
    logic as follows:

    ```python
    results = {}
    for device, v in var:
      with tf.device(device):
        # args and kwargs will be unwrapped if they are mirrored.
        results[device] = fn(v, *args, **kwargs)
    return merged(results)
    ```

    Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
    `var`.

    Args:
      var: Variable, possibly mirrored to multiple devices, to operate on.
      fn: Function to call. Should take the variable as the first argument.
      args: Tuple or list. Additional positional arguments to pass to `fn()`.
      kwargs: Dict with keyword arguments to pass to `fn()`.
      group: Boolean. Defaults to True. If False, the return value will be
        unwrapped.

    Returns:
      By default, the merged return value of `fn` across all replicas.
The merged result has dependencies to make sure that if it is evaluated at all, the side effects (updates) will happen on every replica. If instead "group=False" is specified, this function will return a nest of lists where each list has an element per replica, and the caller is responsible for ensuring all elements are executed. """ _require_cross_replica_or_default_context_extended(self) if kwargs is None: kwargs = {} fn = autograph.tf_convert( fn, autograph_ctx.control_status_ctx(), convert_by_default=False) with self._container_strategy().scope(): return self._update(var, fn, args, kwargs, group) def _update(self, var, fn, args, kwargs, group): raise NotImplementedError("must be implemented in descendants") def _local_results(self, distributed_value): raise NotImplementedError("must be implemented in descendants") def value_container(self, value): """Returns the container that this per-replica `value` belongs to. Args: value: A value returned by `run()` or a variable created in `scope()`. Returns: A container that `value` belongs to. If value does not belong to any container (including the case of container having been destroyed), returns the value itself. `value in experimental_local_results(value_container(value))` will always be true. """ raise NotImplementedError("must be implemented in descendants") def _group(self, value, name=None): """Implementation of `group`.""" value = nest.flatten(self._local_results(value)) if len(value) != 1 or name is not None: return control_flow_ops.group(value, name=name) # Special handling for the common case of one op. 
v, = value if hasattr(v, "op"): v = v.op return v @property def experimental_require_static_shapes(self): """Returns `True` if static shape is required; `False` otherwise.""" return self._require_static_shapes @property def _num_replicas_in_sync(self): """Returns number of replicas over which gradients are aggregated.""" raise NotImplementedError("must be implemented in descendants") @property def worker_devices(self): """Returns the tuple of all devices used for compute replica execution. """ # TODO(josh11b): More docstring raise NotImplementedError("must be implemented in descendants") @property def parameter_devices(self): """Returns the tuple of all devices used to place variables.""" # TODO(josh11b): More docstring raise NotImplementedError("must be implemented in descendants") def _configure(self, session_config=None, cluster_spec=None, task_type=None, task_id=None): """Configures the strategy class.""" del session_config, cluster_spec, task_type, task_id def _update_config_proto(self, config_proto): return copy.deepcopy(config_proto) def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings. Multi-worker training refers to the setup where the training is distributed across multiple workers, as opposed to the case where only a local process performs the training. This function is used by higher-level APIs such as Keras' `model.fit()` to infer, for example, whether or not a distribute coordinator should be run, and thus TensorFlow servers should be started for communication with other servers in the cluster, or whether or not saving/restoring checkpoints is relevant for preemption fault tolerance. Subclasses should override this to provide whether the strategy is currently in multi-worker setup. Experimental. Signature and implementation are subject to change.
""" raise NotImplementedError("must be implemented in descendants") @tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring class StrategyExtendedV1(StrategyExtendedV2): __doc__ = StrategyExtendedV2.__doc__ def experimental_make_numpy_dataset(self, numpy_input, session=None): """Makes a dataset for input provided via a numpy array. This avoids adding `numpy_input` as a large constant in the graph, and copies the data to the machine or machines that will be processing the input. Args: numpy_input: A nest of NumPy input arrays that will be distributed evenly across all replicas. Note that lists of Numpy arrays are stacked, as that is normal `tf.data.Dataset` behavior. session: (TensorFlow v1.x graph execution only) A session used for initialization. Returns: A `tf.data.Dataset` representing `numpy_input`. """ _require_cross_replica_or_default_context_extended(self) return self._experimental_make_numpy_dataset(numpy_input, session=session) def _experimental_make_numpy_dataset(self, numpy_input, session): raise NotImplementedError("must be implemented in descendants") def broadcast_to(self, tensor, destinations): """Mirror a tensor on one device to all worker devices. Args: tensor: A Tensor value to broadcast. destinations: A mirrored variable or device string specifying the destination devices to copy `tensor` to. Returns: A value mirrored to `destinations` devices. """ assert destinations is not None # from old strategy.broadcast() # TODO(josh11b): More docstring _require_cross_replica_or_default_context_extended(self) assert not isinstance(destinations, (list, tuple)) return self._broadcast_to(tensor, destinations) def _broadcast_to(self, tensor, destinations): raise NotImplementedError("must be implemented in descendants") def experimental_run_steps_on_iterator(self, fn, iterator, iterations=1, initial_loop_values=None): """DEPRECATED: please use `run` instead. Run `fn` with input from `iterator` for `iterations` times. 
This method can be used to run a step function for training a number of times using input from a dataset. Args: fn: function to run using this distribution strategy. The function must have the following signature: `def fn(context, inputs)`. `context` is an instance of `MultiStepContext` that will be passed when `fn` is run. `context` can be used to specify the outputs to be returned from `fn` by calling `context.set_last_step_output`. It can also be used to capture non tensor outputs by `context.set_non_tensor_output`. See `MultiStepContext` documentation for more information. `inputs` will have same type/structure as `iterator.get_next()`. Typically, `fn` will use `call_for_each_replica` method of the strategy to distribute the computation over multiple replicas. iterator: Iterator of a dataset that represents the input for `fn`. The caller is responsible for initializing the iterator as needed. iterations: (Optional) Number of iterations that `fn` should be run. Defaults to 1. initial_loop_values: (Optional) Initial values to be passed into the loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove initial_loop_values argument when we have a mechanism to infer the outputs of `fn`. Returns: Returns the `MultiStepContext` object which has the following properties, among other things: - run_op: An op that runs `fn` `iterations` times. - last_step_outputs: A dictionary containing tensors set using `context.set_last_step_output`. Evaluating this returns the value of the tensors after the last iteration. - non_tensor_outputs: A dictionary containing anything that was set by `fn` by calling `context.set_non_tensor_output`. 
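The contract between `fn` and the returned context can be traced with a plain-Python sketch. It does not use TensorFlow: `FakeMultiStepContext` and `run_steps` are hypothetical stand-ins written for illustration only, and they assume that a later call to `set_last_step_output` overwrites the value recorded by an earlier step.

```python
# Plain-Python sketch (no TensorFlow). `FakeMultiStepContext` and `run_steps`
# are hypothetical stand-ins, not part of the TensorFlow API.
class FakeMultiStepContext(object):

  def __init__(self):
    self.last_step_outputs = {}
    self.non_tensor_outputs = {}

  def set_last_step_output(self, name, output):
    # Later steps overwrite earlier ones, so only the last value survives.
    self.last_step_outputs[name] = output

  def set_non_tensor_output(self, name, output):
    self.non_tensor_outputs[name] = output


def run_steps(fn, iterator, iterations=1):
  # Calls `fn(context, inputs)` once per iteration with the next input.
  context = FakeMultiStepContext()
  for _ in range(iterations):
    fn(context, next(iterator))
  return context


def step_fn(context, inputs):
  context.set_last_step_output("loss", inputs * 0.5)
  context.set_non_tensor_output("note", "ran step_fn")


ctx = run_steps(step_fn, iter([4.0, 2.0, 1.0]), iterations=3)
print(ctx.last_step_outputs)   # {'loss': 0.5}
print(ctx.non_tensor_outputs)  # {'note': 'ran step_fn'}
```

Only the output of the third step is kept in `last_step_outputs`, which mirrors the "value of the tensors after the last iteration" wording above.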
""" _require_cross_replica_or_default_context_extended(self) with self._container_strategy().scope(): return self._experimental_run_steps_on_iterator(fn, iterator, iterations, initial_loop_values) def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, initial_loop_values): raise NotImplementedError("must be implemented in descendants") def call_for_each_replica(self, fn, args=(), kwargs=None): """Run `fn` once per replica. `fn` may call `tf.get_replica_context()` to access methods such as `replica_id_in_sync_group` and `merge_call()`. `merge_call()` is used to communicate between the replicas and re-enter the cross-replica context. All replicas pause their execution having encountered a `merge_call()` call. After that the `merge_fn`-function is executed. Its results are then unwrapped and given back to each replica call. After that execution resumes until `fn` is complete or encounters another `merge_call()`. Example: ```python # Called once in "cross-replica" context. def merge_fn(distribution, three_plus_replica_id): # sum the values across replicas return sum(distribution.experimental_local_results(three_plus_replica_id)) # Called once per replica in `distribution`, in a "replica" context. def fn(three): replica_ctx = tf.get_replica_context() v = three + replica_ctx.replica_id_in_sync_group # Computes the sum of the `v` values across all replicas. s = replica_ctx.merge_call(merge_fn, args=(v,)) return s + v with distribution.scope(): # in "cross-replica" context ... merged_results = distribution.run(fn, args=[3]) # merged_results has the values from every replica execution of `fn`. # This statement prints a list: print(distribution.experimental_local_results(merged_results)) ``` Args: fn: function to run (will be run once per replica). args: Tuple or list with positional arguments for `fn`. kwargs: Dict with keyword arguments for `fn`. Returns: Merged return value of `fn` across all replicas. 
""" _require_cross_replica_or_default_context_extended(self) if kwargs is None: kwargs = {} with self._container_strategy().scope(): return self._call_for_each_replica(fn, args, kwargs) def _call_for_each_replica(self, fn, args, kwargs): raise NotImplementedError("must be implemented in descendants") def read_var(self, v): """Reads the value of a variable. Returns the aggregate value of a replica-local variable, or the (read-only) value of any other variable. Args: v: A variable allocated within the scope of this `tf.distribute.Strategy`. Returns: A tensor representing the value of `v`, aggregated across replicas if necessary. """ raise NotImplementedError("must be implemented in descendants") def update_non_slot( self, colocate_with, fn, args=(), kwargs=None, group=True): """Runs `fn(*args, **kwargs)` on `colocate_with` devices. Used to update non-slot variables. DEPRECATED: TF 1.x ONLY. Args: colocate_with: Devices returned by `non_slot_devices()`. fn: Function to execute. args: Tuple or list. Positional arguments to pass to `fn()`. kwargs: Dict with keyword arguments to pass to `fn()`. group: Boolean. Defaults to True. If False, the return value will be unwrapped. Returns: Return value of `fn`, possibly merged across devices. """ _require_cross_replica_or_default_context_extended(self) if kwargs is None: kwargs = {} fn = autograph.tf_convert( fn, autograph_ctx.control_status_ctx(), convert_by_default=False) with self._container_strategy().scope(): return self._update_non_slot(colocate_with, fn, args, kwargs, group) def _update_non_slot(self, colocate_with, fn, args, kwargs, group): raise NotImplementedError("must be implemented in descendants") def non_slot_devices(self, var_list): """Device(s) for non-slot variables. DEPRECATED: TF 1.x ONLY. This method returns non-slot devices where non-slot variables are placed. 
Users can create non-slot variables on these devices by using a block: ```python with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): ... ``` Args: var_list: The list of variables being optimized, needed with the default `tf.distribute.Strategy`. Returns: A sequence of devices for non-slot variables. """ raise NotImplementedError("must be implemented in descendants") @property def experimental_between_graph(self): """Whether the strategy uses between-graph replication or not. This is expected to return a constant value that will not be changed throughout its life cycle. """ raise NotImplementedError("must be implemented in descendants") @property def experimental_should_init(self): """Whether initialization is needed.""" raise NotImplementedError("must be implemented in descendants") @property def should_checkpoint(self): """Whether checkpointing is needed.""" raise NotImplementedError("must be implemented in descendants") @property def should_save_summary(self): """Whether saving summaries is needed.""" raise NotImplementedError("must be implemented in descendants") # A note about the difference between the context managers # `ReplicaContext` (defined here) and `_CurrentDistributionContext` # (defined above) used by `tf.distribute.Strategy.scope()`: # # * a ReplicaContext is only present during a `run()` # call (except during a `merge_call` call) and in such a scope it # will be returned by calls to `get_replica_context()`. Implementers of new # Strategy descendants will frequently also need to # define a descendant of ReplicaContext, and are responsible for # entering and exiting this context. # # * Strategy.scope() sets up a variable_creator scope that # changes variable creation calls (e.g. to make mirrored # variables). This is intended as an outer scope that users enter once # around their model creation and graph definition.
There is no # anticipated need to define descendants of _CurrentDistributionContext. # It sets the current Strategy for purposes of # `get_strategy()` and `has_strategy()` # and switches the thread mode to a "cross-replica context". class ReplicaContextBase(object): """A class with a collection of APIs that can be called in a replica context. You can use `tf.distribute.get_replica_context` to get an instance of `ReplicaContext`, which can only be called inside the function passed to `tf.distribute.Strategy.run`. >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) >>> def func(): ... replica_context = tf.distribute.get_replica_context() ... return replica_context.replica_id_in_sync_group >>> strategy.run(func) PerReplica:{ 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> } """ def __init__(self, strategy, replica_id_in_sync_group): """Creates a ReplicaContext. Args: strategy: A `tf.distribute.Strategy`. replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an integer whenever possible to avoid issues with nested `tf.function`. It accepts a `Tensor` only to be compatible with `tpu.replicate`. """ self._strategy = strategy self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access self) if not (replica_id_in_sync_group is None or tensor_util.is_tensor(replica_id_in_sync_group) or isinstance(replica_id_in_sync_group, int)): raise ValueError( "replica_id_in_sync_group can only be an integer, a Tensor or None.") self._replica_id_in_sync_group = replica_id_in_sync_group # We need this check because TPUContext extends from ReplicaContext and # does not pass a strategy object since it is used by TPUEstimator.
if strategy: self._local_replica_id = strategy.extended._get_local_replica_id( replica_id_in_sync_group) self._summary_recording_distribution_strategy = None @doc_controls.do_not_generate_docs def __enter__(self): _push_per_thread_mode(self._thread_context) def replica_id_is_zero(): return math_ops.equal(self.replica_id_in_sync_group, constant_op.constant(0)) summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access self._summary_recording_distribution_strategy = ( summary_state.is_recording_distribution_strategy) summary_state.is_recording_distribution_strategy = replica_id_is_zero @doc_controls.do_not_generate_docs def __exit__(self, exception_type, exception_value, traceback): summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access summary_state.is_recording_distribution_strategy = ( self._summary_recording_distribution_strategy) _pop_per_thread_mode() def merge_call(self, merge_fn, args=(), kwargs=None): """Merge args across replicas and run `merge_fn` in a cross-replica context. This allows communication and coordination when there are multiple calls to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. See `tf.distribute.Strategy.run` for an explanation. If not inside a distributed scope, this is equivalent to: ``` strategy = tf.distribute.get_strategy() with cross-replica-context(strategy): return merge_fn(strategy, *args, **kwargs) ``` Args: merge_fn: Function that joins arguments from threads that are given as PerReplica. It accepts `tf.distribute.Strategy` object as the first argument. args: List or tuple with positional per-thread arguments for `merge_fn`. kwargs: Dict with keyword per-thread arguments for `merge_fn`. Returns: The return value of `merge_fn`, except for `PerReplica` values which are unpacked. 
""" require_replica_context(self) if kwargs is None: kwargs = {} merge_fn = autograph.tf_convert( merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) return self._merge_call(merge_fn, args, kwargs) def _merge_call(self, merge_fn, args, kwargs): """Default implementation for single replica.""" _push_per_thread_mode( # thread-local, so not needed with multiple threads distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access try: return merge_fn(self._strategy, *args, **kwargs) finally: _pop_per_thread_mode() @property def num_replicas_in_sync(self): """Returns number of replicas that are kept in sync.""" return self._strategy.num_replicas_in_sync @property def replica_id_in_sync_group(self): """Returns the id of the replica. This identifies the replica among all replicas that are kept in sync. The value of the replica id can range from 0 to `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. NOTE: This is not guaranteed to be the same ID as the XLA replica ID use for low-level operations such as collective_permute. Returns: a `Tensor`. """ # It's important to prefer making the Tensor at call time whenever possible. # Keeping Tensors in global states doesn't work well with nested # tf.function, since it's possible that the tensor is generated in one func # graph, and gets captured by another, which will result in a subtle "An op # outside of the function building code is being passed a Graph tensor" # error. Making the tensor at call time to ensure it is the same graph where # it's used. However to be compatible with tpu.replicate(), # self._replica_id_in_sync_group can also be a Tensor. 
if tensor_util.is_tensor(self._replica_id_in_sync_group): return self._replica_id_in_sync_group return constant_op.constant( self._replica_id_in_sync_group, dtypes.int32, name="replica_id_in_sync_group") @property def _replica_id(self): """This is the local replica id in a given sync group.""" return self._local_replica_id @property def strategy(self): """The current `tf.distribute.Strategy` object.""" return self._strategy @property @deprecation.deprecated(None, "Please avoid relying on devices property.") def devices(self): """Returns the devices this replica is to be executed on, as a tuple of strings. NOTE: For `tf.distribute.MirroredStrategy` and `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a nested list of device strings, e.g., [["GPU:0"]]. """ require_replica_context(self) return (device_util.current(),) def all_reduce(self, reduce_op, value, options=None): """All-reduces `value` across all replicas. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def step_fn(): ... ctx = tf.distribute.get_replica_context() ... value = tf.identity(1.) ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) >>> strategy.experimental_local_results(strategy.run(step_fn)) (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) It supports batched operations. You can pass a list of values and it attempts to batch them when possible. You can also specify `options` to indicate the desired batching behavior, e.g. batch the values into multiple packs so that they can better overlap with computations. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def step_fn(): ... ctx = tf.distribute.get_replica_context() ... value1 = tf.identity(1.) ... value2 = tf.identity(2.) ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) >>> strategy.experimental_local_results(strategy.run(step_fn)) ([PerReplica:{ 0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0> }, PerReplica:{ 0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>, 1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0> }],) Note that all replicas need to participate in the all-reduce, otherwise this operation hangs.
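The values in the examples above can be checked with a plain-Python sketch of what an all-reduce computes. It does not use TensorFlow and does not model batching or communication: `simulate_all_reduce` is a hypothetical helper that takes one flat list of numbers per replica and returns the list each replica would end up with.

```python
# Plain-Python sketch (no TensorFlow); `simulate_all_reduce` is hypothetical.
def simulate_all_reduce(reduce_op, per_replica_values):
  # per_replica_values[i] is the flat list of values held by replica i.
  op = reduce_op.upper()
  if op not in ("SUM", "MEAN"):
    raise ValueError("This sketch only handles SUM and MEAN.")
  # Every replica must contribute the same number of values, in the same order.
  if len({len(values) for values in per_replica_values}) != 1:
    raise ValueError("All replicas must pass the same structure.")
  # Reduce position by position across replicas.
  reduced = [sum(column) for column in zip(*per_replica_values)]
  if op == "MEAN":
    reduced = [total / len(per_replica_values) for total in reduced]
  # Each replica receives its own copy of the full reduced result.
  return [list(reduced) for _ in per_replica_values]


print(simulate_all_reduce("SUM", [[1.0, 2.0], [1.0, 2.0]]))
# [[2.0, 4.0], [2.0, 4.0]]
```

With two replicas each holding `1.0` and `2.0`, every replica ends up with `2.0` and `4.0`, matching the batched example.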
Note that if there're multiple all-reduces, they need to execute in the same order on all replicas. Dispatching all-reduce based on conditions is usually error-prone. This API currently can only be called in the replica context. Other variants to reduce values across replicas are: * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API in the cross-replica context. * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and all-reduce API in the cross-replica context. * `tf.distribute.Strategy.reduce`: a more convenient method to reduce to the host in cross-replica context. Args: reduce_op: a `tf.distribute.ReduceOp` value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts. The structure and the shapes of the `tf.Tensor` need to be same on all replicas. options: a `tf.distribute.experimental.CommunicationOptions`. Options to perform collective operations. This overrides the default options if the `tf.distribute.Strategy` takes one in the constructor. See `tf.distribute.experimental.CommunicationOptions` for details of the options. Returns: A nested structure of `tf.Tensor` with the reduced values. The structure is the same as `value`. """ if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) if options is None: options = collective_util.Options() def batch_all_reduce(strategy, *value_flat): return strategy.extended.batch_reduce_to( reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], options) if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. @custom_gradient.custom_gradient def grad_wrapper(*xs): ys = self.merge_call(batch_all_reduce, args=xs) # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 
return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) else: # TODO(cjfj): Implement gradients for other reductions. reduced = nest.pack_sequence_as( value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) return nest.map_structure(array_ops.prevent_gradient, reduced) # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient # all-reduce. It would return a function returning the result of reducing `t` # across all replicas. The caller would wait to call this function until they # needed the reduce result, allowing an efficient implementation: # * With eager execution, the reduction could be performed asynchronously # in the background, not blocking until the result was needed. # * When constructing a graph, it could batch up all reduction requests up # to that point that the first result is needed. Most likely this can be # implemented in terms of `merge_call()` and `batch_reduce_to()`. @tf_export("distribute.ReplicaContext", v1=[]) class ReplicaContext(ReplicaContextBase): __doc__ = ReplicaContextBase.__doc__ def all_gather(self, value, axis, options=None): """All-gathers `value` across all replicas along `axis`. Note: An `all_gather` method can only be called in replica context. For a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. All replicas need to participate in the all-gather, otherwise this operation hangs. So if `all_gather` is called in any replica, it must be called in all replicas. Note: If there are multiple `all_gather` calls, they need to be executed in the same order on all replicas. Dispatching `all_gather` based on conditions is usually error-prone. For all strategies except `tf.distribute.TPUStrategy`, the input `value` on different replicas must have the same rank, and their shapes must be the same in all dimensions except the `axis`-th dimension. 
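In other words, their shapes cannot differ in any dimension `d` other than the `axis` dimension. For example, given a `tf.distribute.DistributedValues` with component tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`. However, with `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and same shape. Note: The input `value` must have a non-zero rank. Otherwise, consider using `tf.expand_dims` before gathering them. You can pass in a single tensor to all-gather: >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> @tf.function ... def gather_value(): ... ctx = tf.distribute.get_replica_context() ... local_value = tf.constant([1, 2, 3]) ... return ctx.all_gather(local_value, axis=0) >>> result = strategy.run(gather_value) >>> result PerReplica:{ 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> } >>> strategy.experimental_local_results(result) (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>) You can also pass in a nested structure of tensors to all-gather, say, a list: >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> @tf.function ... def gather_nest(): ... ctx = tf.distribute.get_replica_context() ... value_1 = tf.constant([1, 2, 3]) ... value_2 = tf.constant([[1, 2], [3, 4]]) ... # all_gather a nest of `tf.distribute.DistributedValues` ... return ctx.all_gather([value_1, value_2], axis=0) >>> result = strategy.run(gather_nest) >>> result [PerReplica:{ 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> }, PerReplica:{ 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)>, 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)> }] >>> strategy.experimental_local_results(result) ([PerReplica:{ 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> }, PerReplica:{ 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)>, 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)> }],) What if you are all-gathering tensors with different shapes on different replicas? Consider the following example with two replicas, where you have `value` as a nested structure consisting of two items to all-gather, `a` and `b`.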
On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]} On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]} Result for `all_gather` with `axis`=0 (on each of the replicas): {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]} Args: value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, or a `tf.distribute.DistributedValues` instance. The structure of the `tf.Tensor` needs to be the same on all replicas. The underlying tensor constructs can only be dense tensors with non-zero rank, NOT `tf.IndexedSlices`. axis: 0-D int32 Tensor. Dimension along which to gather. options: a `tf.distribute.experimental.CommunicationOptions`. Options to perform collective operations. This overrides the default options if the `tf.distribute.Strategy` takes one in the constructor. See `tf.distribute.experimental.CommunicationOptions` for details of the options. Returns: A nested structure of `tf.Tensor` with the gathered values. The structure is the same as `value`. """ for v in nest.flatten(value): if isinstance(v, ops.IndexedSlices): raise NotImplementedError("all_gather does not support IndexedSlices") if options is None: options = collective_util.Options() def batch_all_gather(strategy, *value_flat): return strategy.extended._batch_gather_to( # pylint: disable=protected-access [(v, _batch_reduce_destination(v)) for v in value_flat], axis, options) @custom_gradient.custom_gradient def grad_wrapper(*xs): ys = self.merge_call(batch_all_gather, args=xs) def grad(*dy_s): grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) new_grads = [] for i, grad in enumerate(grads): input_shape = array_ops.shape(xs[i]) axis_dim = array_ops.reshape(input_shape[axis], [1]) with ops.control_dependencies([array_ops.identity(grads)]): d = self.all_gather(axis_dim, axis=0) begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) end_dim = begin_dim + array_ops.shape(xs[i])[axis] new_grad = array_ops.gather( grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) 
new_grads.append(new_grad) return new_grads return ys, grad return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) @tf_export(v1=["distribute.ReplicaContext"]) class ReplicaContextV1(ReplicaContextBase): __doc__ = ReplicaContextBase.__doc__ def _batch_reduce_destination(x): """Returns the destinations for batch all-reduce.""" if isinstance(x, ops.Tensor): # If this is a one device strategy. return x.device else: return x # ------------------------------------------------------------------------------ _creating_default_strategy_singleton = False class _DefaultDistributionStrategyV1(StrategyV1): """Default `tf.distribute.Strategy` if none is explicitly selected.""" def __init__(self): if not _creating_default_strategy_singleton: raise RuntimeError("Should only create a single instance of " "_DefaultDistributionStrategy") super(_DefaultDistributionStrategyV1, self).__init__(_DefaultDistributionExtended(self)) def __deepcopy__(self, memo): del memo raise RuntimeError("Should only create a single instance of " "_DefaultDistributionStrategy") class _DefaultDistributionStrategy(Strategy): """Default `tf.distribute.Strategy` if none is explicitly selected.""" def __init__(self): if not _creating_default_strategy_singleton: raise RuntimeError("Should only create a single instance of " "_DefaultDistributionStrategy") super(_DefaultDistributionStrategy, self).__init__( _DefaultDistributionExtended(self)) def __deepcopy__(self, memo): del memo raise RuntimeError("Should only create a single instance of " "_DefaultDistributionStrategy") class _DefaultDistributionContext(object): """Context manager setting the default `tf.distribute.Strategy`.""" __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] def __init__(self, strategy): def creator(next_creator, **kwargs): _require_strategy_scope_strategy(strategy) return next_creator(**kwargs) self._var_creator_scope = variable_scope.variable_creator_scope(creator) self._strategy = strategy 
self._nested_count = 0 def __enter__(self): # Allow this scope to be entered if this strategy is already in scope. if distribution_strategy_context.has_strategy(): raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") if self._nested_count == 0: self._var_creator_scope.__enter__() self._nested_count += 1 return self._strategy def __exit__(self, exception_type, exception_value, traceback): self._nested_count -= 1 if self._nested_count == 0: try: self._var_creator_scope.__exit__( exception_type, exception_value, traceback) except RuntimeError as e: six.raise_from( RuntimeError("Variable creator scope nesting error: move call to " "tf.distribute.set_strategy() out of `with` scope."), e) class _DefaultDistributionExtended(StrategyExtendedV1): """Implementation of _DefaultDistributionStrategy.""" def __init__(self, container_strategy): super(_DefaultDistributionExtended, self).__init__(container_strategy) self._retrace_functions_for_each_device = False def _scope(self, strategy): """Context manager setting a variable creator and `self` as current.""" return _DefaultDistributionContext(strategy) def colocate_vars_with(self, colocate_with_variable): """Does not require `self.scope`.""" _require_strategy_scope_extended(self) return ops.colocate_with(colocate_with_variable) def variable_created_in_scope(self, v): return v._distribute_strategy is None # pylint: disable=protected-access def _experimental_distribute_dataset(self, dataset, options): return dataset def _distribute_datasets_from_function(self, dataset_fn, options): return dataset_fn(InputContext()) def _experimental_distribute_values_from_function(self, value_fn): return value_fn(ValueContext()) def _make_dataset_iterator(self, dataset): return _DefaultDistributionExtended.DefaultInputIterator(dataset) def _make_input_fn_iterator(self, input_fn, replication_mode=InputReplicationMode.PER_WORKER): dataset = input_fn(InputContext()) return _DefaultDistributionExtended.DefaultInputIterator(dataset) def 
_experimental_make_numpy_dataset(self, numpy_input, session): numpy_flat = nest.flatten(numpy_input) vars_flat = tuple( variable_scope.variable(array_ops.zeros(i.shape, i.dtype), trainable=False, use_resource=True) for i in numpy_flat ) for v, i in zip(vars_flat, numpy_flat): numpy_dataset.init_var_from_numpy(v, i, session) vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) return dataset_ops.Dataset.from_tensor_slices(vars_nested) def _broadcast_to(self, tensor, destinations): if destinations is None: return tensor else: raise NotImplementedError("TODO") def _call_for_each_replica(self, fn, args, kwargs): with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): return fn(*args, **kwargs) def _reduce_to(self, reduce_op, value, destinations, options): # TODO(josh11b): Use destinations? del reduce_op, destinations, options return value def _gather_to_implementation(self, value, destinations, axis, options): del destinations, axis, options return value def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): # TODO(josh11b): Figure out what we should be passing to UpdateContext() # once that value is used for something. 
    with UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    return array_ops.identity(replica_local_var)

  def _local_results(self, distributed_value):
    return (distributed_value,)

  def value_container(self, value):
    return value

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    raise RuntimeError("worker_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  @property
  def parameter_devices(self):
    raise RuntimeError("parameter_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # Default strategy doesn't indicate multi-worker training.
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id

  # TODO(priyag): This should inherit from `InputIterator`, once dependency
  # issues have been resolved.
  class DefaultInputIterator(object):
    """Default implementation of `InputIterator` for default strategy."""

    def __init__(self, dataset):
      self._dataset = dataset
      if eager_context.executing_eagerly():
        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
      else:
        self._iterator = dataset_ops.make_initializable_iterator(dataset)

    def get_next(self):
      return self._iterator.get_next()

    def get_next_as_optional(self):
      return self._iterator.get_next_as_optional()

    @deprecated(None, "Use the iterator's `initializer` property instead.")
    def initialize(self):
      """Initialize underlying iterators.

      Returns:
        A list of any initializer ops that should be run.
      """
      if eager_context.executing_eagerly():
        # Use the module-level helper, as in `__init__`, so this also works for
        # datasets that lack a `make_one_shot_iterator()` method.
        self._iterator = dataset_ops.make_one_shot_iterator(self._dataset)
        return []
      else:
        return [self._iterator.initializer]

    @property
    def initializer(self):
      """Returns a list of ops that initialize the iterator."""
      return self.initialize()

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for this strategy."""
    return True


class _DefaultReplicaContext(ReplicaContext):
  """ReplicaContext for _DefaultDistributionStrategy."""

  @property
  def replica_id_in_sync_group(self):
    # Return 0 instead of a constant tensor to avoid creating a new node for
    # users who don't use distribution strategy.
    return 0


# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = resource_variable_ops._from_proto_fn


def _from_proto_fn(v, import_scope=None):
  if distribution_strategy_context.has_strategy():
    raise NotImplementedError(
        "Deserialization of variables is not yet supported when using a "
        "tf.distribute.Strategy.")
  else:
    return _original_from_proto(v, import_scope=import_scope)

resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access


#-------------------------------------------------------------------------------
# Shorthand for some methods from distribution_strategy_context.
_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
_get_default_replica_mode = (
    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access


# ------------------------------------------------------------------------------
# Metrics to track which distribution strategy is being called
distribution_strategy_gauge = monitoring.StringGauge(
    "/tensorflow/api/distribution_strategy",
    "Gauge to track the type of distribution strategy used.", "TFVersion")
distribution_strategy_replica_gauge = monitoring.IntGauge(
    "/tensorflow/api/distribution_strategy/replica",
    "Gauge to track the number of replicas each distribution strategy used.",
    "CountType")