# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest of per-replica values into PerReplica/Mirrored values.

  Args:
    values: Values to regroup.
    wrap_class: Class that `values` will be wrapped in.
    always_wrap: If `True`, always wrap `values` in `wrap_class`, even when
      every value is the same object; `DistributedVariable`s are still
      returned unwrapped.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values), wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap`
  #   is true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)


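# Example (illustrative sketch, not part of the original file): regrouping a
# per-replica pair of dicts into a dict of PerReplica values. The plain floats
# below are hypothetical stand-ins for per-replica tensors:
#
#   r0 = {"w": 1.0, "b": 2.0}
#   r1 = {"w": 3.0, "b": 4.0}
#   regrouped = regroup((r0, r1))
#   # regrouped == {"w": PerReplica((1.0, 3.0)), "b": PerReplica((2.0, 4.0))}

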
def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # A `DistributedValues` is sliced according to replica, unless it is a
    # `DistributedVariable`, because a `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)


def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""

  def _get_mirrored(x):
    if isinstance(x, values_lib.DistributedValues):
      if not is_mirrored(x):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
      packed_var = getattr(x, "_packed_variable", None)
      if packed_var is not None:
        return packed_var
      return x.values[replica_id]
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be
    # performing the same computation.
    if not all(tensor_util.is_tensor(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


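# Example (illustrative sketch, not part of the original file): selecting the
# component of a nested structure for replica 1. Regular (non-distributed)
# values pass through unchanged:
#
#   per_replica = values_lib.PerReplica((1.0, 2.0))
#   structured = {"loss": per_replica, "step": 7}
#   select_replica(1, structured)
#   # -> {"loss": 2.0, "step": 7}

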
def value_container(val):
  """Returns the container that this per-replica `val` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created
      in `scope()`.

  Returns:
    A container that `val` belongs to. If `val` does not belong to any
    container (including the case of the container having been destroyed),
    returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, values_lib.DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val


def is_distributed_variable(v):
  """Determines whether a variable is a distributed or TPU mirrored variable."""
  return isinstance(v, values_lib.DistributedVariable)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


# Variable creation functions for sync strategies.
def _get_and_validate_synchronization(kwargs):
  """Validates that the given `synchronization` value is valid."""
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))
  return synchronization


def _validate_aggregation(kwargs):
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))
  return aggregation


def _get_variable_policy_class(synchronization, aggregation, policy_mapping):
  if synchronization == vs.VariableSynchronization.AUTO:
    if aggregation == vs.VariableAggregation.NONE:
      # Use AutoPolicy.
      return policy_mapping.get(synchronization)
    else:
      # Revert to OnWritePolicy.
      return policy_mapping.get(vs.VariableSynchronization.ON_WRITE)
  return policy_mapping.get(synchronization)


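# Example (illustrative sketch, not part of the original file): resolving the
# policy class for an AUTO-synchronization variable with MEAN aggregation,
# using the VARIABLE_POLICY_MAPPING table defined at the bottom of this file:
#
#   cls = _get_variable_policy_class(vs.VariableSynchronization.AUTO,
#                                    vs.VariableAggregation.MEAN,
#                                    VARIABLE_POLICY_MAPPING)
#   # cls is values_lib.OnWritePolicy, since AUTO with a non-NONE aggregation
#   # falls back to ON_WRITE.

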
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Creates distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _get_and_validate_synchronization(kwargs)
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    if use_var_policy:
      var_policy_cls = _get_variable_policy_class(synchronization, aggregation,
                                                  policy_mapping)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation,
                       var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result


# Utility functions.
# Returns True if the value is Mirrored, or the variable is replicated and
# kept in sync.
def is_mirrored(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return val._policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)


def is_sync_on_read(val):
  if isinstance(val, values_lib.DistributedVariable):
    if val._policy:  # pylint: disable=protected-access
      return not val._policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)


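# Example (illustrative sketch, not part of the original file): how a strategy
# implementation might call `create_mirrored_variable`. The creator below is a
# hypothetical stand-in that builds one component variable per device;
# `devices`, `name`, `init`, and `next_creator` are assumed to come from the
# enclosing strategy code:
#
#   def _real_mirrored_creator(**kwargs):
#     value_list = []
#     for i, d in enumerate(devices):
#       with ops.device(d):
#         kwargs["name"] = "%s/replica_%d" % (name, i)
#         value_list.append(next_creator(**kwargs))
#     return value_list
#
#   mirrored = create_mirrored_variable(
#       strategy, _real_mirrored_creator, VARIABLE_CLASS_MAPPING,
#       VARIABLE_POLICY_MAPPING, name=name, initial_value=init,
#       trainable=True)

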
# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# AutoPolicy is used for:
#   (synchronization=AUTO, aggregation=NONE)
# OnWritePolicy is used for:
#   (synchronization=AUTO, aggregation=SUM,MEAN,ONLY_FIRST_REPLICA)
#   (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
#   (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.AUTO: values_lib.AutoPolicy,
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.AUTO: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.AUTO: tpu_values_lib.TPUAutoPolicy,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.AUTO: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}


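# Example (illustrative sketch, not part of the original file): how the class
# mapping resolves to a wrapper type when variable policies are disabled:
#
#   var_cls = VARIABLE_CLASS_MAPPING.get(vs.VariableSynchronization.ON_READ)
#   # var_cls is values_lib.SyncOnReadVariable; the TPU tables resolve the
#   # same keys to their tpu_values_lib counterparts.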