# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utilities for cross_device_ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import copy import threading from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nccl_ops from tensorflow.python.platform import tf_logging as logging INSTANCE_KEY_START_NUMBER = 100 def aggregate_gradients_using_nccl(replica_grads): """Aggregate gradients using nccl allreduce.""" agg_all_g_and_v = [] for single_g_and_v in zip(*replica_grads): single_grads = [g for g, _ in single_g_and_v] agg_grads = nccl_ops.all_sum(single_grads) agg_all_g_and_v.append( [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) agg_all_g_and_v = list(zip(*agg_all_g_and_v)) return agg_all_g_and_v def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads): """Aggregate gradients using 
hierarchical copies. Args: avail_devices: available GPU devices. replica_grads: List of lists of (gradient, variable) tuples. The outer list is over replicas. The inner list is over individual gradients. Returns: The list of (aggregated_gradient, variable), where the gradient has been summed across all replicas and the variable is chosen from the first replica. """ # This only works for DGX-1 type of machine topology # Device peer to peer matrix # DMA: 0 1 2 3 4 5 6 7 # 0: Y Y Y Y Y N N N # 1: Y Y Y Y N Y N N # 2: Y Y Y Y N N Y N # 3: Y Y Y Y N N N Y # 4: Y N N N Y Y Y Y # 5: N Y N N Y Y Y Y # 6: N N Y N Y Y Y Y # 7: N N N Y Y Y Y Y agg_grads = [] num_devices = len(avail_devices) # In the special case of DGX-1 machine topology, the two groups have equal # size. group_size = num_devices // 2 for i, single_grads in enumerate(zip(*replica_grads)): group_0_main_device = i % num_devices group_1_main_device = (group_0_main_device + group_size) % num_devices if group_0_main_device < group_size: group_0_begin = 0 group_1_begin = group_size else: group_0_begin = group_size group_1_begin = 0 # Aggregate the first group. group_0_device_grads = single_grads[group_0_begin: group_0_begin + group_size] with ops.device(avail_devices[group_0_main_device]): group_0_agg_grads, _ = aggregate_single_gradient_using_copy( group_0_device_grads, False, False) # Aggregate the second group. group_1_device_grads = single_grads[group_1_begin: group_1_begin + group_size] with ops.device(avail_devices[group_1_main_device]): group_1_agg_grads, _ = aggregate_single_gradient_using_copy( group_1_device_grads, False, False) # Aggregate between the groups. with ops.device(avail_devices[group_0_main_device]): (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( [group_0_agg_grads, group_1_agg_grads], False, False) # Broadcast the result back into the root of each group. 
with ops.device(avail_devices[group_0_main_device]): group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) with ops.device(avail_devices[group_1_main_device]): group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) agg_grads_bcast = [] for j in range(len(single_grads)): with ops.device(avail_devices[j]): # Broadcast the result back to each member in the group from the root. if (group_0_main_device < group_size) == (j < group_size): src_device_grad = group_0_agg_grads_bcast else: src_device_grad = group_1_agg_grads_bcast agg_grads_bcast.append(array_ops.identity(src_device_grad)) agg_grads.append( [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) agg_grads = list(zip(*agg_grads)) return agg_grads def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, check_inf_nan): """Calculate the average gradient for a shared variable across all replicas. Note that this function provides a synchronization point across all replicas. Args: grad_and_vars: A list or tuple of (gradient, variable) tuples. Each (gradient, variable) pair within the outer list represents the gradient of the variable calculated for a single replica, and the number of pairs equals the number of replicas. use_mean: if True, mean is taken, else sum of gradients is taken. check_inf_nan: check grads for nans and infs. Returns: The tuple ([(average_gradient, variable),], has_nan_or_inf) where the gradient has been averaged across all replicas. The variable is chosen from the first replica. The has_nan_or_inf indicates the grads has nan or inf. """ grads = [g for g, _ in grad_and_vars] grad = math_ops.add_n(grads) if use_mean and len(grads) > 1: grad = array_ops.multiply(grad, 1.0 / len(grads)) v = grad_and_vars[0][1] if check_inf_nan: has_nan_or_inf = array_ops.logical_not( array_ops.reduce_all(array_ops.is_finite(grads))) return (grad, v), has_nan_or_inf else: return (grad, v), None # TODO(yuefengz): use random key starts to avoid reusing keys? 
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage two different keys for collective:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that work under the same set of devices must use the same
  group key.

  *Instance key*: an integer key to identify the set of same counterpart of
  tensors on different devices in a device group that need to be all-reduced.

  This class is thread safe.
  """

  def __init__(self, group_key_start=1):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
    """
    self._group_key = group_key_start
    # Maps a sorted tuple of device strings to its group key.
    self._group_key_table = {}
    # Maps a group key to a dict of {device: next instance key}.
    self._instance_key_table = {}
    self._lock = threading.Lock()

  def get_group_key(self, devices):
    """Returns a group key for the set of devices.

    Args:
      devices: a list of canonical device strings in a collective group.

    Returns:
      int key uniquely identifying the set of device names.
    """
    # Key the table on the sorted device tuple itself rather than on its hash,
    # so that two different device sets can never share a group key because of
    # a hash collision.
    key_id = tuple(sorted(devices))
    with self._lock:
      if key_id not in self._group_key_table:
        new_key = self._group_key
        self._group_key += 1
        self._group_key_table[key_id] = new_key
        self._instance_key_table[new_key] = {}
        for device in devices:
          self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
      return self._group_key_table[key_id]

  def get_instance_key(self, group_key, device):
    """Returns a new instance key for use in defining a collective op.

    You should call this once per each collective op of a collective instance.

    Args:
      group_key: the group key returned by get_group_key(). You should not
        assign the group key yourself.
      device: a canonical device string. It should be the device this collective
        op is on.

    Returns:
      a new instance key.

    Raises:
      ValueError: when the group key is invalid or the device is not in the
      group.
    """
    with self._lock:
      group = self._instance_key_table.get(group_key, None)
      if group is None:
        raise ValueError('group {} not found'.format(group_key))
      if device not in group:
        raise ValueError('{} not in group {}'.format(device, group_key))
      v = group[device]
      group[device] += 1
      return v

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveKeys needs to support deep copy as well.
    copied = CollectiveKeys()
    # Register the copy so that other references to this object inside the
    # same deepcopy() call resolve to the same new instance.
    memo[id(self)] = copied
    # Hold the lock so that the copied counters and tables form a consistent
    # snapshot even if another thread is handing out keys concurrently. The
    # copy gets its own fresh lock from the constructor.
    with self._lock:
      copied._group_key = self._group_key
      copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
      copied._instance_key_table = copy.deepcopy(self._instance_key_table,
                                                 memo)
    return copied


class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica."""

  # When True, batch_all_reduce does not concatenate tensors itself; it emits
  # one all-reduce per tensor under a shared name scope instead.
  _use_scoped_allocator = True
  # When True (and executing eagerly outside functions), the V2 collective ops
  # are used.
  _use_collective_v2 = False

  def __init__(self,
               group_key,
               group_size,
               collective_keys,
               device,
               executor=None):
    """Initializes the launcher.

    Args:
      group_key: the group key passed to every collective op launched.
      group_size: the group size passed to every collective op launched.
      collective_keys: a `CollectiveKeys` used to obtain instance keys.
      device: the device the collective ops are placed on.
      executor: optional executor used when running eagerly. Must be async.

    Raises:
      ValueError: if `executor` is given and is not async.
    """
    if executor and not executor.is_async():
      raise ValueError('executor must be async')
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    self._executor = executor

  def _executor_scope(self):
    """Returns a context manager selecting the executor in eager mode."""
    if context.executing_eagerly() and not self._executor:
      raise ValueError('collectives requires a async executor in eager mode')
    if context.executing_eagerly():
      return context.executor_scope(self._executor)
    return ops.NullContextmanager()

  def _control_input(self, control_input):
    """Returns a control-dependency scope on `control_input`, if any."""
    if control_input is not None:
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def _should_use_collective_v2(self):
    """Returns whether the V2 collective ops should be used."""
    if not CollectiveReplicaLauncher._use_collective_v2:
      return False
    if not ops.executing_eagerly_outside_functions():
      return False
    return True

  def _next_instance_key(self):
    """Returns the next instance key."""
    if self._should_use_collective_v2():
      # Assigning instance keys at function building time have issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder as
      # the instance key and feed it at function call time. In this way we also
      # don't reuse instance keys, which allows for per-instance cancellation.
      graph = ops.get_default_graph()
      # Control flow ops don't work with capture_call_time_value, so we put the
      # capture in the function graph of that control flow op.
      while getattr(graph, 'is_control_flow_graph', False):
        graph = graph.outer_graph
      if not context.executing_eagerly() and graph.building_function:
        with graph.as_default():
          # Capture self._next_instance_key so that when building a function
          # that calls another tf.function, the instance key assignment is
          # further delayed until we actually call the function in eager. Note
          # that capture_call_time_value doesn't automatically propagate the
          # deferred capture to the outer function.
          return graph.capture_call_time_value(
              self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
      else:
        instance_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device('CPU:0'):
          return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
    else:
      return self._collective_keys.get_instance_key(self._group_key,
                                                    self._device)

  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    with self._executor_scope(), \
         ops.device(self._device), \
         self._control_input(control_input):
      if self._should_use_collective_v2():
        return collective_ops.all_reduce_v2(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout)
      else:
        return collective_ops.all_reduce(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout)

  def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
    """All-gather a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    with self._executor_scope(), ops.device(self._device):
      if self._should_use_collective_v2():
        return collective_ops.all_gather_v2(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout)
      else:
        return collective_ops.all_gather(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout)

  def batch_all_reduce(self,
                       input_tensor_packs,
                       communication_hint='AUTO',
                       timeout=0):
    """Batch all-reduce dense tensors.

    This takes a list of batches of tensors. Using multiple batches has the
    benefit that it doesn't need to wait for all inputs to be ready to start the
    all-reduce.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      A flat list of reduced tensors.
    """
    # We don't batch with concat in eager. It's easy to get it wrong because
    # we need to avoid any numpy() calls on values produced by the async
    # executor. This effectively disables batching in eager, but it's unlikely
    # to all-reduce a large number of tensors in eager.
    batch_with_concat = (not self._use_scoped_allocator and
                         not context.executing_eagerly())
    outputs = []
    for pack in input_tensor_packs:
      # TODO(b/169168846): inserts a parallel all_gather to verify packings
      # are the same on each replica.
      if batch_with_concat:
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          # With the NCCL hint, chain each all-reduce after the previous
          # output via a control dependency.
          if communication_hint == 'NCCL' and outputs:
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input,
              communication_hint, timeout)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))
      else:
        # By placing all CollectiveReduce ops in a batch under single name
        # scope, we ensure they will be picked up by the `ScopedAllocator`
        # grappler optimizer and packed into a single all-reduce.
        with ops.name_scope('allreduce'):
          for input_tensor in pack:
            if communication_hint == 'NCCL' and outputs:
              control_input = outputs[-1]
            else:
              control_input = None
            outputs.append(
                self.all_reduce(input_tensor, control_input,
                                communication_hint, timeout))

    return outputs

  def all_gather(self,
                 input_tensor,
                 axis,
                 communication_hint='AUTO',
                 timeout=0):
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError('all_gather in eager mode is not supported')

    with ops.device(self._device), \
         ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
      # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
      # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
      # place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      # Gather every replica's shape first, so that each replica can pad its
      # first dimension up to the largest one in the group.
      gathered_shape = self._all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = self._all_gather(
          padded_input_tensor, communication_hint, timeout=timeout)
      # 4. Unpad
      split_tensors = []
      for i in range(self._group_size):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      return array_ops.transpose(out_tensor_t, perm=perm_after)

  def all_reduce_indexed_slices(self,
                                input_slices,
                                communication_hint='AUTO',
                                timeout=0):
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          'all_reduce_indexed_slices in eager mode is not supported')

    # Current CollectiveAllGather implementations require input IndexedSlices to
    # have consistent length across the board, we handle the reduction of
    # IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    with ops.device(self._device):

      def all_gather():
        """Use all_gather to aggregate `IndexedSlices`."""
        all_values = self._all_gather(
            input_slices.values, communication_hint, timeout=timeout)
        # Add control dependency to order the all-gather.
        control = [all_values] if communication_hint == 'NCCL' else []
        with ops.control_dependencies(control):
          all_indices = self._all_gather(
              input_slices.indices, communication_hint, timeout=timeout)
        return ops.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      def densify_and_all_reduce():
        """Use all_reduce to aggregate `IndexedSlices`."""
        densified = ops.convert_to_tensor(input_slices)
        reduced = self.all_reduce(
            densified, communication_hint=communication_hint, timeout=timeout)
        # We have to convert dense grad to IndexedSlice because all_reduce()
        # and all_gather() must have the same return type as required by
        # control_flow_ops.cond.
        return ops.IndexedSlices(
            values=reduced,
            indices=math_ops.range(array_ops.shape(reduced)[0]),
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = self._all_gather(
          length, communication_hint, timeout=timeout)
      return control_flow_ops.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)),
          all_gather,
          densify_and_all_reduce)


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, ops.IndexedSlices) for v in values):
    return backprop.aggregate_indexed_slices_gradients(values)
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divides a tensor, or the values of an IndexedSlices, by `n`."""
  if isinstance(value, ops.IndexedSlices):
    value = backprop.flatten_nested_indexed_slices(value)
    return ops.IndexedSlices(
        value.values / n, value.indices, value.dense_shape)
  else:
    return value / n


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copies a tensor or an IndexedSlices onto `device`.

  Args:
    value: a tensor or an `IndexedSlices`.
    device: the device to place the copy on.

  Returns:
    A copy of `value` created with `array_ops.identity` under `device`. For an
    `IndexedSlices`, the values, the indices and (when present) the dense shape
    are each copied.
  """
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      # dense_shape may be None on an IndexedSlices; identity() cannot be
      # applied to None, so it is passed through unchanged in that case.
      if value.dense_shape is not None:
        copied_shape = array_ops.identity(value.dense_shape)
      else:
        copied_shape = None
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  """Returns whether `value` is, or consists only of, IndexedSlices.

  Args:
    value: an `IndexedSlices` or a `DistributedValues`.

  Returns:
    True if `value` is an `IndexedSlices`, or if it is a `DistributedValues`
    whose components are all `IndexedSlices`.
  """
  if isinstance(value, ops.IndexedSlices):
    return True
  assert isinstance(value, value_lib.DistributedValues)
  return all(isinstance(v, ops.IndexedSlices) for v in value.values)


def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists:
      a list of dense values, a list of their indices in `values` and
      a list of sparse values, a list of their indices in `values`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  # Skip empty entries here as well, consistently with the fill loop below, so
  # that an empty tuple in the list does not raise IndexError.
  length = sum(
      len(values_and_indices[0])
      for values_and_indices in values_and_indices_list
      if values_and_indices)

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        # Every position must be filled at most once.
        assert result[i] is None
        result[i] = v
  return result


def group_by_size(input_tensors, bytes_per_pack):
  """Groups `input_tensors` into chunks of `bytes_per_pack`.

  The method preserves the original order of `input_tensors`. The grouping is
  best effort, each pack could have more or less bytes than `bytes_per_pack`.
  It only groups values with known shape.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs of Tensor. All values are grouped into one pack if
    `bytes_per_pack` is zero or any of the value has unknown shape.
  """

  if bytes_per_pack == 0:
    return [input_tensors]
  packs = []
  last_pack_size = 0
  for value in input_tensors:
    num_elements = value.shape.num_elements()
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [input_tensors]
    size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. I.E. we err on the side of having
    # few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs


def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  # Only the first dimension is padded, and only at its end; every other
  # dimension gets [0, 0] paddings.
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat([
      paddings_axis,
      array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
  ],
                              axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor