# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Mid level API for TPU Embeddings.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import functools from typing import Any, Dict, Callable, List, Optional, Text, Tuple from absl import logging from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import sharded_variable from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.saved_model import save_context from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_embedding_v2_utils from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.training.saving import saveable_hook from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import tracking from tensorflow.python.types import core from tensorflow.python.types import internal as internal_types from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export _HOOK_KEY = "TPUEmbedding_saveable" _NAME_KEY = "_tpu_embedding_layer" # TODO(bfontain): Cleanup and remove this once there is an implementation of # sharded variables that can be used in the PSStrategy with optimizers. # We implement just enough of the of a tf.Variable so that this could be passed # to an optimizer. 
class TPUShardedVariable(sharded_variable.ShardedVariableMixin):
  """A ShardedVariable class for TPU."""

  @property
  def _in_graph_mode(self):
    return self.variables[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self.variables[0]._unique_id  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name


def _add_key_attr(op, name):
  op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name)))  # pylint: disable=protected-access


@tf_export("tpu.experimental.embedding.TPUEmbedding")
class TPUEmbedding(tracking.AutoTrackable):
  """The TPUEmbedding mid level API.

  NOTE: When instantiated under a TPUStrategy, this class can only be created
  once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
  re-initialize the embedding engine you must re-initialize the TPU as well.
  Doing this will clear any variables from TPU, so ensure you have
  checkpointed before you do this. If further instances of the class are
  needed, set the `initialize_tpu_embedding` argument to `False`.

  This class can be used to support training large embeddings on TPU. When
  creating an instance of this class, you must specify the complete set of
  tables and features you expect to look up in those tables. See the
  documentation of `tf.tpu.experimental.embedding.TableConfig` and
  `tf.tpu.experimental.embedding.FeatureConfig` for more details on the
  complete set of options. We will cover the basic usage here.

  NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
  allowing different features to share the same table:

  ```python
  table_config_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  table_config_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)
  feature_config = {
      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_one),
      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
          table=table_config_two)}
  ```

  There are two modes under which the `TPUEmbedding` class can be used,
  depending on whether the class was created under a `TPUStrategy` scope or
  not.

  Under `TPUStrategy`, we allow access to the methods `enqueue`, `dequeue` and
  `apply_gradients`. We will show examples below of how to use these to train
  and evaluate your model. Under CPU, we only have access to the
  `embedding_tables` property, which allows access to the embedding tables so
  that you can use them to run model evaluation/prediction on CPU.

  First let's look at the `TPUStrategy` mode. Initial setup looks like:

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the enqueue
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_prefetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  NOTE: All batches passed to the layer must have the same batch size for each
  input; moreover, once you have called the layer with one batch size, all
  subsequent calls must use the same batch_size.
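
  If you need to initialize the layer before the first call to `enqueue`, you
  can compute the per replica batch size from your global batch size and pass
  it to `build` directly. This is a minimal sketch; the global batch size value
  here is purely illustrative:

  ```python
  global_batch_size = 128
  per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
  embedding.build(per_replica_batch_size)
  ```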

  In the event that the batch size cannot be automatically determined by the
  enqueue method, you must call the build method (as sketched above) with the
  batch size to initialize the layer.

  To use this API on TPU you should use a custom training loop. Below is an
  example of a training and evaluation step:

  ```python
  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=True)
      strategy.run(tpu_step, args=(tpu_features, ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(tpu_features):
      activations = embedding.dequeue()
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      embedding_features, tpu_features = next(dataset_iterator)
      embedding.enqueue(embedding_features, training=False)
      strategy.run(tpu_step, args=(tpu_features, ))
  ```

  NOTE: The calls to `enqueue` have `training` set to `True` when
  `embedding.apply_gradients` is used and set to `False` when
  `embedding.apply_gradients` is not present in the function. If you don't
  follow this pattern you may cause an error to be raised or the TPU may
  deadlock.

  In the above examples, we assume that the user has a dataset which returns a
  tuple where the first element of the tuple matches the structure of what was
  passed as the `feature_config` argument to the object initializer. We also
  utilize `tf.range` to get a `tf.while_loop` in order to increase performance.

  When checkpointing your model, you should include your
  `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is
  a trackable object and saving it will save the embedding tables and their
  optimizer slot variables:

  ```python
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.save(...)
  ```

  On CPU, only the `embedding_tables` property is usable. This will allow you
  to restore a checkpoint to the object and have access to the table variables:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)
  tables = embedding.embedding_tables
  ```

  You can now use these tables in functions like `tf.nn.embedding_lookup` to
  perform your embedding lookups and pass the results to your model.
  """

  def __init__(
      self,
      feature_config: Any,
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
      pipeline_execution_with_tensor_core: bool = False):
    """Creates the TPUEmbedding mid level API object.

    ```python
    strategy = tf.distribute.TPUStrategy(...)
    with strategy.scope():
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=tf.tpu.experimental.embedding.FeatureConfig(
              table=tf.tpu.experimental.embedding.TableConfig(
                  dim=...,
                  vocabulary_size=...)),
          optimizer=tf.tpu.experimental.embedding.SGD(0.1))
    ```

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. When not created under
        TPUStrategy, this may be set to None to avoid the creation of the
        optimizer slot variables, which is useful for optimizing memory
        consumption when exporting the model for serving where slot variables
        aren't needed.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old). Set to True for improved performance.

    Raises:
      ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD,
        Adam or Adagrad) or None when created under a TPUStrategy.
    """
    self._strategy = distribution_strategy_context.get_strategy()
    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
                                                  tpu_strategy.TPUStrategyV2))
    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)

    self._feature_config = feature_config

    # The TPU embedding ops are slightly inconsistent with how they refer to
    # tables:
    # * The enqueue op takes a parallel list of tensors for input, one of those
    #   is the table id for the feature which matches the integer index of the
    #   table in the proto created by _create_config_proto().
    # * The recv_tpu_embedding_activations op emits lookups per table in the
    #   order from the config proto.
    # * The send_tpu_embedding_gradients expects input tensors to be per table
    #   in the same order as the config proto.
    # * Per optimizer load and retrieve ops are specified per table and take
    #   the table name rather than the table id.
    # Thus we must fix a common order to tables and ensure they have unique
    # names.

    # Set table order here
    self._table_config = list(
        {feature.table for feature in nest.flatten(feature_config)})

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if ((table.optimizer is not None or self._using_tpu) and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Multiple tables with name {} found.".format(
            table.name))
      table_names.append(table.name)

    if self._using_tpu:
      # Extract a list of callable learning rates also in fixed order. Each
      # table in the config proto will get an index into this list, and we will
      # pass this list in the same order after evaluation to the
      # send_tpu_embedding_gradients op.
      self._dynamic_learning_rates = list({
          table.optimizer.learning_rate for table in self._table_config if
          callable(table.optimizer.learning_rate)})

      # We need the list of host devices for the load/retrieve operations.
      self._hosts = get_list_of_hosts(self._strategy)

    self._built = False

  def build(self, per_replica_batch_size: Optional[int] = None):
    """Creates the underlying variables and initializes the TPU for embeddings.

    This method creates the underlying variables (including slot variables). If
    created under a TPUStrategy, this will also initialize the TPU for
    embeddings.

    This function will automatically get called by enqueue, which will try to
    determine your batch size automatically. If this fails, you must manually
    call this method before you call enqueue.

    Args:
      per_replica_batch_size: The per replica batch size that you intend to
        use. Note that this is fixed and the same batch size must be used for
        both training and evaluation. If you want to calculate this from the
        global batch size, you can use the `num_replicas_in_sync` property of
        your strategy object. May be set to None if not created under a
        TPUStrategy.

    Raises:
      ValueError: If per_replica_batch_size is None and object was created in a
        TPUStrategy scope.
    """
    if self._built:
      return

    if self._using_tpu:
      if per_replica_batch_size is None:
        raise ValueError("You must specify a per_replica_batch_size when "
                         "calling build if object is created under a "
                         "TPUStrategy.")

      self._batch_size = per_replica_batch_size

      self._config_proto = self._create_config_proto()

      logging.info("Initializing TPU Embedding engine with config: %s",
                   self._config_proto)

      @def_function.function
      def load_config():
        tpu.initialize_system_for_tpu_embedding(self._config_proto)

      load_config()

      logging.info("Done initializing TPU Embedding engine.")

    # Create and load variables and slot variables into the TPU.
    # Note that this is a dict of dicts. Keys to the first dict are table
    # names. We would prefer to use TableConfigs, but then these variables
    # won't be properly tracked by the tracking API.
    self._variables = self._create_variables_and_slots()

    self._built = True

    # This is internally conditioned on self._built and self._using_tpu
    self._load_variables()

  def _maybe_build(self, batch_size: Optional[int]):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly, this means that
      # it will not be included in the function graph generated by tracing so
      # that we can be sure that we only initialize the TPU for embeddings
      # exactly once.
      with ops.init_scope():
        self.build(batch_size)

  @property
  def embedding_tables(
      self
  ) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`.

    This property only works when the `TPUEmbedding` object is created under a
    non-TPU strategy. This is intended to be used for CPU-based lookup when
    creating a serving checkpoint.

    Returns:
      A dict of embedding tables, keyed by `TableConfig`.

    Raises:
      RuntimeError: If object was created under a `TPUStrategy`.
    """
    # We don't support returning tables on TPU due to their sharded nature and
    # the fact that when using a TPUStrategy:
    # 1. Variables are stale and are only updated when a checkpoint is made.
    # 2. Updating the variables won't affect the actual tables on the TPU.
    if self._using_tpu:
      if save_context.in_save_context():
        return {table: self._variables[table.name]["parameters"].variables[0]
                for table in self._table_config}
      raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
                         "strategy. If you need access, save your model, "
                         "create this object under a CPU strategy and restore.")

    self._maybe_build(None)

    # Only return the tables and not the slot variables. On CPU these are
    # honest tf.Variables.
return {table: self._variables[table.name]["parameters"] for table in self._table_config} def _create_config_proto( self ) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration: """Creates the TPUEmbeddingConfiguration proto. This proto is used to initialize the TPU embedding engine. Returns: A TPUEmbeddingConfiguration proto. """ config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() # There are several things that need to be computed here: # 1. Each table has a num_features, which corresponds to the number of # output rows per example for this table. Sequence features count for # their maximum sequence length. # 2. Learning rate index: the index of the dynamic learning rate for this # table (if it exists) in the list we created at initialization. # We don't simply create one learning rate index per table as this has # extremely bad performance characteristics. The more separate # optimization configurations we have, the worse the performance will be. num_features = {table: 0 for table in self._table_config} for feature in nest.flatten(self._feature_config): num_features[feature.table] += (1 if feature.max_sequence_length == 0 else feature.max_sequence_length) # Map each callable dynamic learning rate to its in index in the list. learning_rate_index = {r: i for i, r in enumerate( self._dynamic_learning_rates)} for table in self._table_config: table_descriptor = config_proto.table_descriptor.add() table_descriptor.name = table.name # For small tables, we pad to the number of hosts so that at least one # id will be assigned to each host. table_descriptor.vocabulary_size = max(table.vocabulary_size, self._strategy.extended.num_hosts) table_descriptor.dimension = table.dim table_descriptor.num_features = num_features[table] parameters = table_descriptor.optimization_parameters # We handle the learning rate separately here and don't allow the # optimization class to handle this, as it doesn't know about dynamic # rates. if callable(table.optimizer.learning_rate): parameters.learning_rate.dynamic.tag = ( learning_rate_index[table.optimizer.learning_rate]) else: parameters.learning_rate.constant = table.optimizer.learning_rate # Use optimizer to handle the rest of the parameters. table.optimizer._set_optimization_parameters(parameters) # pylint: disable=protected-access # Always set mode to training, we override the mode during enqueue. config_proto.mode = ( tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING) config_proto.batch_size_per_tensor_core = self._batch_size config_proto.num_hosts = self._strategy.extended.num_hosts config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync # TODO(bfontain): Allow users to pick MOD for the host sharding. config_proto.sharding_strategy = ( tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT) config_proto.pipeline_execution_with_tensor_core = ( self._pipeline_execution_with_tensor_core) return config_proto def _compute_per_table_gradients( self, gradients ) -> Dict[Text, List[core.Tensor]]: """Computes a dict of lists of gradients, keyed by table name. Args: gradients: A nested structure of Tensors (and Nones) with the same structure as the feature config. Returns: A dict of lists of tensors, keyed by the table names, containing the gradients in the correct order with None gradients replaced by zeros. 
""" nest.assert_same_structure(self._feature_config, gradients) per_table_gradients = {table: [] for table in self._table_config} for (path, gradient), feature in zip( nest.flatten_with_joined_string_paths(gradients), nest.flatten(self._feature_config)): if gradient is not None and not isinstance(gradient, ops.Tensor): raise ValueError( "Found {} at path {} in gradients. Expected Tensor.".format( type(gradient), path)) # Expected tensor shape differs for sequence and non-sequence features. if feature.max_sequence_length > 0: shape = [self._batch_size, feature.max_sequence_length, feature.table.dim] else: shape = [self._batch_size, feature.table.dim] if gradient is not None: if gradient.shape != shape: raise ValueError("Found gradient of shape {} at path {}. Expected " "shape {}.".format(gradient.shape, path, shape)) # We expand dims on non-sequence features so that all features are # of rank 3 and we can concat on axis=1. if len(shape) == 2: gradient = array_ops.expand_dims(gradient, axis=1) else: # No gradient for this feature, since we must give a gradient for all # features, pass in a zero tensor here. Note that this is not correct # for all optimizers. logging.warn("No gradient passed for feature %s, sending zero " "gradient. This may not be correct behavior for certain " "optimizers like Adam.", path) # Create a shape to mimic the expand_dims above for non-sequence # features. if len(shape) == 2: shape = [shape[0], 1, shape[1]] gradient = array_ops.zeros(shape, dtype=dtypes.float32) per_table_gradients[feature.table].append(gradient) return per_table_gradients def apply_gradients(self, gradients, name: Text = None): """Applies the gradient update to the embedding tables. If a gradient of `None` is passed in any position of the nested structure, then an gradient update with a zero gradient is applied for that feature. For optimizers like SGD or Adagrad, this is the same as applying no update at all. For lazy Adam and other sparsely applied optimizers with decay, ensure you understand the effect of applying a zero gradient. ```python strategy = tf.distribute.TPUStrategy(...) with strategy.scope(): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function def training_step(): def tpu_step(tpu_features): with tf.GradientTape() as tape: activations = embedding.dequeue() tape.watch(activations) loss = ... # some computation involving activations embedding_gradients = tape.gradient(loss, activations) embedding.apply_gradients(embedding_gradients) embedding_features, tpu_features = next(dataset_iterator) embedding.enqueue(embedding_features, training=True) strategy.run(tpu_step, args=(embedding_features, )) training_step() ``` Args: gradients: A nested structure of gradients, with structure matching the `feature_config` passed to this object. name: A name for the underlying op. Raises: RuntimeError: If called when object wasn't created under a `TPUStrategy` or if not built (either by manually calling build or calling enqueue). ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a `tf.Tensor` of the incorrect shape is passed in. Also if the size of any sequence in `gradients` does not match corresponding sequence in `feature_config`. TypeError: If the type of any sequence in `gradients` does not match corresponding sequence in `feature_config`. 
""" if not self._using_tpu: raise RuntimeError("apply_gradients is not valid when TPUEmbedding " "object is not created under a TPUStrategy.") if not self._built: raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding " "object. Please either call enqueue first or manually " "call the build method.") # send_tpu_embedding_gradients requires per table gradient, if we only have # one feature per table this isn't an issue. When multiple features share # the same table, the order of the features in per table tensor returned by # recv_tpu_embedding_activations matches the order in which they were passed # to enqueue. # In all three places, we use the fixed order given by nest.flatten to have # a consistent feature order. # First construct a dict of tensors one for each table. per_table_gradients = self._compute_per_table_gradients(gradients) # Now that we have a list of gradients we can compute a list of gradients # in the fixed order of self._table_config which interleave the gradients of # the individual features. We concat on axis 1 and then reshape into a 2d # tensor. The send gradients op expects a tensor of shape # [num_features*batch_size, dim] for each table. interleaved_gradients = [] for table in self._table_config: interleaved_gradients.append(array_ops.reshape( array_ops.concat(per_table_gradients[table], axis=1), [-1, table.dim])) op = tpu_ops.send_tpu_embedding_gradients( inputs=interleaved_gradients, learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32) for fn in self._dynamic_learning_rates], config=self._config_proto.SerializeToString()) # Apply the name tag to the op. if name is not None: _add_key_attr(op, name) def dequeue(self, name: Text = None): """Get the embedding results. Returns a nested structure of `tf.Tensor` objects, matching the structure of the `feature_config` argument to the `TPUEmbedding` class. The output shape of the tensors is `(batch_size, dim)`, where `batch_size` is the per core batch size, `dim` is the dimension of the corresponding `TableConfig`. If the feature's corresponding `FeatureConfig` has `max_sequence_length` greater than 0, the output will be a sequence of shape `(batch_size, max_sequence_length, dim)` instead. ```python strategy = tf.distribute.TPUStrategy(...) with strategy.scope(): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function def training_step(): def tpu_step(tpu_features): with tf.GradientTape() as tape: activations = embedding.dequeue() tape.watch(activations) loss = ... # some computation involving activations embedding_gradients = tape.gradient(loss, activations) embedding.apply_gradients(embedding_gradients) embedding_features, tpu_features = next(dataset_iterator) embedding.enqueue(embedding_features, training=True) strategy.run(tpu_step, args=(embedding_features, )) training_step() ``` Args: name: A name for the underlying op. Returns: A nested structure of tensors, with the same structure as `feature_config` passed to this instance of the `TPUEmbedding` object. Raises: RuntimeError: If called when object wasn't created under a `TPUStrategy` or if not built (either by manually calling build or calling enqueue). 
""" if not self._using_tpu: raise RuntimeError("dequeue is not valid when TPUEmbedding object is not " "created under a TPUStrategy.") if not self._built: raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. " "Please either call enqueue first or manually call " "the build method.") # The activations returned by this op are per table. So we must separate # them out into per feature activations. The activations are interleaved: # for each table, we expect a [num_features*batch_size, dim] tensor. # E.g. we expect the slice [:num_features, :] to contain the lookups for the # first example of all features using this table. activations = tpu_ops.recv_tpu_embedding_activations( num_outputs=len(self._table_config), config=self._config_proto.SerializeToString()) # Apply the name tag to the op. if name is not None: _add_key_attr(activations[0].op, name) # Compute the number of features for this table. num_features = {table: 0 for table in self._table_config} for feature in nest.flatten(self._feature_config): num_features[feature.table] += (1 if feature.max_sequence_length == 0 else feature.max_sequence_length) # Activations are reshaped so that they are indexed by batch size and then # by the 'feature' index within the batch. The final dimension should equal # the dimension of the table. table_to_activation = { table: array_ops.reshape(activation, [self._batch_size, num_features[table], -1]) for table, activation in zip(self._table_config, activations)} # We process the features in the same order we enqueued them. # For each feature we take the next slice of the activations, so need to # track the activations and the current position we are in. table_to_position = {table: 0 for table in self._table_config} per_feature_activations = [] for feature in nest.flatten(self._feature_config): activation = table_to_activation[feature.table] feature_index = table_to_position[feature.table] # We treat non-sequence and sequence features differently here as sequence # features have rank 3 while non-sequence features have rank 2. if feature.max_sequence_length == 0: per_feature_activations.append( activation[:, feature_index, :]) table_to_position[feature.table] += 1 else: per_feature_activations.append( activation[:, feature_index:( feature_index+feature.max_sequence_length), :]) table_to_position[feature.table] += feature.max_sequence_length # Pack the list back into the same nested structure as the features. return nest.pack_sequence_as(self._feature_config, per_feature_activations) def _create_variables_and_slots( self ) -> Dict[Text, Dict[Text, tf_variables.Variable]]: """Create variables for TPU embeddings. Note under TPUStrategy this will ensure that all creations happen within a variable creation scope of the sharded variable creator. Returns: A dict of dicts. The outer dict is keyed by the table names and the inner dicts are keyed by 'parameters' and the slot variable names. """ def create_variables(table): """Create all variables.""" variable_shape = (table.vocabulary_size, table.dim) def getter(name, shape, dtype, initializer, trainable): del shape # _add_variable_with_custom_getter clears the shape sometimes, so we # take the global shape from outside the getter. 
initial_value = functools.partial(initializer, variable_shape, dtype=dtype) return tf_variables.Variable( name=name, initial_value=initial_value, shape=variable_shape, dtype=dtype, trainable=trainable) def variable_creator(name, initializer, trainable=True): # use add_variable_with_custom_getter here so that we take advantage of # the checkpoint loading to allow restore before the variables get # created which avoids double initialization. return self._add_variable_with_custom_getter( name=name, initializer=initializer, shape=variable_shape, dtype=dtypes.float32, getter=getter, trainable=trainable) parameters = variable_creator(table.name, table.initializer, trainable=not self._using_tpu) def slot_creator(name, initializer): return variable_creator(table.name + "/" + name, initializer, False) if table.optimizer is not None: slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access else: slot_vars = {} slot_vars["parameters"] = parameters return slot_vars # Store tables based on name rather than TableConfig as we can't track # through dicts with non-string keys, i.e. we won't be able to save. variables = {} for table in self._table_config: if not self._using_tpu: variables[table.name] = create_variables(table) else: with variable_scope.variable_creator_scope( make_sharded_variable_creator(self._hosts)): variables[table.name] = create_variables(table) return variables def _load_variables(self): # Only load the variables if we are: # 1) Using TPU # 2) Variables are created # 3) Not in save context (except if running eagerly) if self._using_tpu and self._built and not ( not context.executing_eagerly() and save_context.in_save_context()): _load_variables_impl(self._config_proto.SerializeToString(), self._hosts, self._variables, self._table_config) def _retrieve_variables(self): # Only retrieve the variables if we are: # 1) Using TPU # 2) Variables are created # 3) Not in save context (except if running eagerly) if self._using_tpu and self._built and not ( not context.executing_eagerly() and save_context.in_save_context()): _retrieve_variables_impl(self._config_proto.SerializeToString(), self._hosts, self._variables, self._table_config) def _gather_saveables_for_checkpoint( self ) -> Dict[Text, Callable[[Text], "TPUEmbeddingSaveable"]]: """Overrides default Trackable implementation to add load/retrieve hook.""" # This saveable should be here in both TPU and CPU checkpoints, so when on # CPU, we add the hook with no functions. # TODO(bfontain): Update restore logic in saver so that these hooks are # always executed. Once that is done, we can output an empty list when on # CPU. def factory(name=_HOOK_KEY): return TPUEmbeddingSaveable(name, self._load_variables, self._retrieve_variables) return {_HOOK_KEY: factory} # Some helper functions for the below enqueue function. def _add_data_for_tensor(self, tensor, weight, indices, values, weights, int_zeros, float_zeros, path): if weight is not None: raise ValueError( "Weight specified for dense input {}, which is not allowed. " "Weight will always be 1 in this case.".format(path)) # For tensors, there are no indices and no weights. indices.append(int_zeros) values.append(math_ops.cast(tensor, dtypes.int32)) weights.append(float_zeros) def _add_data_for_sparse_tensor(self, tensor, weight, indices, values, weights, int_zeros, float_zeros, path): indices.append(math_ops.cast(tensor.indices, dtypes.int32)) values.append(math_ops.cast(tensor.values, dtypes.int32)) # If we have weights they must be a SparseTensor. 
if weight is not None: if not isinstance(weight, sparse_tensor.SparseTensor): raise ValueError("Weight for {} is type {} which does not match " "type input which is SparseTensor.".format( path, type(weight))) weights.append(math_ops.cast(weight.values, dtypes.float32)) else: weights.append(float_zeros) def _add_data_for_ragged_tensor(self, tensor, weight, indices, values, weights, int_zeros, float_zeros, path): indices.append(math_ops.cast(tensor.row_splits, dtypes.int32)) values.append(math_ops.cast(tensor.values, dtypes.int32)) # If we have weights they must be a RaggedTensor. if weight is not None: if not isinstance(weight, ragged_tensor.RaggedTensor): raise ValueError("Weight for {} is type {} which does not match " "type input which is RaggedTensor.".format( path, type(weight))) weights.append(math_ops.cast(weight.values, dtypes.float32)) else: weights.append(float_zeros) def _generate_enqueue_op( self, flat_inputs: List[internal_types.NativeObject], flat_weights: List[Optional[internal_types.NativeObject]], flat_features: List[tpu_embedding_v2_utils.FeatureConfig], device_ordinal: int, mode_override: Text ) -> ops.Operation: """Outputs a the enqueue op given the inputs and weights. Args: flat_inputs: A list of input tensors. flat_weights: A list of input weights (or None) of the same length as flat_inputs. flat_features: A list of FeatureConfigs of the same length as flat_inputs. device_ordinal: The device to create the enqueue op for. mode_override: A tensor containing the string "train" or "inference". Returns: The enqueue op. """ # First we need to understand which op to use. This depends on if sparse # or ragged tensors are in the flat_inputs. sparse = False ragged = False for inp in flat_inputs: if isinstance(inp, sparse_tensor.SparseTensor): sparse = True elif isinstance(inp, ragged_tensor.RaggedTensor): ragged = True if sparse and ragged: raise ValueError( "Found both SparseTensors and RaggedTensors in the input to the " "enqueue operation. Please ensure that your data does not include " "both SparseTensors and RaggedTensors. It is ok to have Tensors in " "combination with one of the previous types.") # Combiners are per table, list in the same order as the table order. combiners = [table.combiner for table in self._table_config] # Reverse mapping of self._table_config, so that we can lookup the table # index. table_to_id = {table: i for i, table in enumerate(self._table_config)} # These parallel arrays will be the inputs to the enqueue op. indices = [] # sample_indices for sparse, sample_splits for ragged. values = [] weights = [] table_ids = [] max_sequence_lengths = [] # We have to supply a empty/zero tensor in a list position where we don't # have data (e.g. indices for standard Tensor input, weight when no weight # is specified). We create one op here per call, so that we reduce the # graph size. int_zeros = array_ops.zeros((0,), dtype=dtypes.int32) float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) # In the following loop we insert casts so that everything is either int32 # or float32. This is because op inputs which are lists of tensors must be # of the same type within the list. Moreover the CPU implementions of these # ops cast to these types anyway, so we don't lose any data by casting # early. 
for inp, weight, (path, feature) in zip( flat_inputs, flat_weights, flat_features): table_ids.append(table_to_id[feature.table]) max_sequence_lengths.append(feature.max_sequence_length) if isinstance(inp, ops.Tensor): self._add_data_for_tensor(inp, weight, indices, values, weights, int_zeros, float_zeros, path) elif isinstance(inp, sparse_tensor.SparseTensor): self._add_data_for_sparse_tensor(inp, weight, indices, values, weights, int_zeros, float_zeros, path) elif isinstance(inp, ragged_tensor.RaggedTensor): self._add_data_for_ragged_tensor(inp, weight, indices, values, weights, int_zeros, float_zeros, path) else: raise ValueError("Input {} is of unknown type {}. Please only pass " "Tensor, SparseTensor or RaggedTensor as input to " "enqueue.".format(path, type(inp))) if ragged: return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( sample_splits=indices, embedding_indices=values, aggregation_weights=weights, mode_override=mode_override, device_ordinal=device_ordinal, combiners=combiners, table_ids=table_ids, max_sequence_lengths=max_sequence_lengths) return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( sample_indices=indices, embedding_indices=values, aggregation_weights=weights, mode_override=mode_override, device_ordinal=device_ordinal, combiners=combiners, table_ids=table_ids, max_sequence_lengths=max_sequence_lengths) def _raise_error_for_incorrect_control_flow_context(self): """Raises an error if we are not in the TPUReplicateContext.""" # Do not allow any XLA control flow (i.e. control flow in between a # TPUStrategy's run call and the call to this function), as we can't # extract the enqueue from the head when in XLA control flow. graph = ops.get_default_graph() in_tpu_ctx = False while graph is not None: ctx = graph._get_control_flow_context() # pylint: disable=protected-access while ctx is not None: if isinstance(ctx, tpu.TPUReplicateContext): in_tpu_ctx = True break ctx = ctx.outer_context if in_tpu_ctx: break graph = getattr(graph, "outer_graph", None) if graph != ops.get_default_graph() and in_tpu_ctx: raise RuntimeError( "Current graph {} does not match graph which contains " "TPUReplicateContext {}. This is most likely due to the fact that " "enqueueing embedding data is called inside control flow or a " "nested function inside `strategy.run`. This is not supported " "because outside compilation fails to extract the enqueue ops as " "head of computation.".format(ops.get_default_graph(), graph)) return in_tpu_ctx def _raise_error_for_non_direct_inputs(self, features): """Checks all tensors in features to see if they are a direct input.""" # expand_composites here is important: as composite tensors pass through # tpu.replicate, they get 'flattened' into their component tensors and then # repacked before being passed to the tpu function. In means that it is the # component tensors which are produced by an op with the # "_tpu_input_identity" attribute. for path, input_tensor in nest.flatten_with_joined_string_paths( features, expand_composites=True): if input_tensor.op.type == "Placeholder": continue try: is_input = input_tensor.op.get_attr("_tpu_input_identity") except ValueError: is_input = False if not is_input: raise ValueError( "Received input tensor {} which is the output of op {} (type {}) " "which does not have the `_tpu_input_identity` attr. Please " "ensure that the inputs to this layer are taken directly from " "the arguments of the function called by " "strategy.run. 
Two possible causes are: dynamic batch size " "support or you are using a keras layer and are not passing " "tensors which match the dtype of the `tf.keras.Input`s." "If you are triggering dynamic batch size support, you can " "disable it by passing tf.distribute.RunOptions(" "experimental_enable_dynamic_batch_size=False) to the options " "argument of strategy.run().".format(path, input_tensor.op.name, input_tensor.op.type)) def _raise_error_for_inputs_not_on_cpu(self, features): """Checks all tensors in features to see are placed on the CPU.""" def check_device(path, device_string): spec = tf_device.DeviceSpec.from_string(device_string) if spec.device_type == "TPU": raise ValueError( "Received input tensor {} which is on a TPU input device {}. Input " "tensors for TPU embeddings must be placed on the CPU. Please " "ensure that your dataset is prefetching tensors to the host by " "setting the 'experimental_prefetch_to_device' option of the " "dataset distribution function. See the documentation of the " "enqueue method for an example.".format( path, device_string)) # expand_composites here is important, we need to check the device of each # underlying tensor. for path, input_tensor in nest.flatten_with_joined_string_paths( features, expand_composites=True): if (input_tensor.op.type == "Identity" and input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"): for tensor in input_tensor.op.inputs[0].op.inputs: check_device(path, tensor.device) else: check_device(path, input_tensor.device) def enqueue( self, features, weights=None, training: bool = True, name: Optional[Text] = None): """Enqueues id tensors for embedding lookup. This function enqueues a structure of features to be looked up in the embedding tables. We expect that the batch size of each of the tensors in features matches the per core batch size. This will automatically happen if your input dataset is batched to the global batch size and you use `tf.distribute.TPUStrategy`'s `experimental_distribute_dataset` or if you use `distribute_datasets_from_function` and batch to the per core batch size computed by the context passed to your input function. ```python strategy = tf.distribute.TPUStrategy(...) with strategy.scope(): embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) distributed_dataset = ( strategy.distribute_datasets_from_function( dataset_fn=..., options=tf.distribute.InputOptions( experimental_prefetch_to_device=False)) dataset_iterator = iter(distributed_dataset) @tf.function def training_step(): def tpu_step(tpu_features): with tf.GradientTape() as tape: activations = embedding.dequeue() tape.watch(activations) loss = ... # some computation involving activations embedding_gradients = tape.gradient(loss, activations) embedding.apply_gradients(embedding_gradients) embedding_features, tpu_features = next(dataset_iterator) embedding.enqueue(embedding_features, training=True) strategy.run(tpu_step, args=(embedding_features,)) training_step() ``` NOTE: You should specify `training=True` when using `embedding.apply_gradients` as above and `training=False` when not using `embedding.apply_gradients` (e.g. for frozen embeddings or when doing evaluation). Args: features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` or `tf.RaggedTensor` is supported per call. 
weights: If not `None`, a nested structure of `tf.Tensor`s, `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except that the tensors should be of float type (and they will be downcast to `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the same for the parallel entries from `features` and similarly for `tf.RaggedTensor`s we assume the row_splits are the same. training: Defaults to `True`. If `False`, enqueue the batch as inference batch (forward pass only). Do not call `apply_gradients` when this is `False` as this may lead to a deadlock. name: A name for the underlying op. Raises: ValueError: When called inside a strategy.run call and input is not directly taken from the args of the `strategy.run` call. Also if the size of any sequence in `features` does not match corresponding sequence in `feature_config`. Similarly for `weights`, if not `None`. If batch size of features is unequal or different from a previous call. RuntimeError: When called inside a strategy.run call and inside XLA control flow. If batch_size is not able to be determined and build was not called. TypeError: If the type of any sequence in `features` does not match corresponding sequence in `feature_config`. Similarly for `weights`, if not `None`. """ if not self._using_tpu: raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " "created under a TPUStrategy.") in_tpu_context = self._raise_error_for_incorrect_control_flow_context() # Should we also get batch_size from weights if they exist? # Since features is assumed to be batched at the per replica batch size # the returned batch size here is per replica an not global. batch_size = self._get_batch_size(features, in_tpu_context) if batch_size is None and not self._built: raise RuntimeError("Unable to determine batch size from input features." "Please call build() with global batch size to " "initialize the TPU for embeddings.") if batch_size is not None: self._maybe_build(batch_size) if self._batch_size != batch_size: raise ValueError("Multiple calls to enqueue with different batch sizes " "{} and {}.".format(self._batch_size, batch_size)) nest.assert_same_structure(self._feature_config, features) flat_inputs = nest.flatten(features) flat_weights = [None] * len(flat_inputs) if weights is not None: nest.assert_same_structure(self._feature_config, weights) flat_weights = nest.flatten(weights) flat_features = nest.flatten_with_joined_string_paths(self._feature_config) self._raise_error_for_inputs_not_on_cpu(features) # If we are in a tpu_context, automatically apply outside compilation. if in_tpu_context: self._raise_error_for_non_direct_inputs(features) def generate_enqueue_ops(): """Generate enqueue ops for outside compilation.""" # Note that we put array_ops.where_v2 rather than a python if so that # the op is explicitly create and the constant ops are both in the graph # even though we don't expect training to be a tensor (and thus generate # control flow automatically). This need to make it easier to re-write # the graph later if we need to fix which mode needs to be used. mode_override = array_ops.where_v2(training, constant_op.constant("train"), constant_op.constant("inference")) # Device ordinal is -1 here, a later rewrite will fix this once the op # is expanded by outside compilation. enqueue_op = self._generate_enqueue_op( flat_inputs, flat_weights, flat_features, device_ordinal=-1, mode_override=mode_override) # Apply the name tag to the op. 
if name is not None: _add_key_attr(enqueue_op, name) # Ensure that this op has outbound control flow, otherwise it won't be # executed. ops.get_default_graph().control_outputs.append(enqueue_op) tpu.outside_compilation(generate_enqueue_ops) else: mode_override = "train" if training else "inference" # We generate enqueue ops per device, so we need to gather the all # features for a single device in to a dict. # We rely here on the fact that the devices in the PerReplica value occur # in the same (standard) order as self._strategy.extended.worker_devices. enqueue_ops = [] for replica_id in range(self._strategy.num_replicas_in_sync): replica_inputs = distribute_utils.select_replica(replica_id, flat_inputs) replica_weights = distribute_utils.select_replica(replica_id, flat_weights) tpu_device = self._strategy.extended.worker_devices[replica_id] # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0 # the device ordinal is the last number device_ordinal = int(tpu_device.rsplit(":", 1)[1]) with ops.device(device_util.get_host_for_device(tpu_device)): enqueue_op = self._generate_enqueue_op( replica_inputs, replica_weights, flat_features, device_ordinal=device_ordinal, mode_override=mode_override) # Apply the name tag to the op. if name is not None: _add_key_attr(enqueue_op, name) enqueue_ops.append(enqueue_op) ops.get_default_graph().control_outputs.extend(enqueue_ops) def _get_batch_size(self, tensors, in_tpu_context: bool): """Gets the batch size from a nested structure of features.""" batch_size = None for path, maybe_tensor in nest.flatten_with_joined_string_paths(tensors): tensor_list = [] if not in_tpu_context: # if we are not in a context, then this is PerReplica and we need to # check each replica's batch size. for replica_id in range(self._strategy.num_replicas_in_sync): tensor_list.append(distribute_utils.select_replica(replica_id, maybe_tensor)) else: tensor_list = [maybe_tensor] for tensor in tensor_list: if tensor.shape.rank < 1: raise ValueError( "Input {} has rank 0, rank must be at least 1.".format(path)) shape = tensor.shape.as_list() if shape[0] is not None: if batch_size is None: batch_size = shape[0] elif batch_size != shape[0]: raise ValueError("Found multiple batch sizes {} and {}. All inputs " "must have the same batch dimensions size.".format( batch_size, shape[0])) return batch_size @def_function.function def _load_variables_impl( config: Text, hosts: List[Tuple[int, Text]], variables: Dict[Text, Dict[Text, tf_variables.Variable]], table_config: tpu_embedding_v2_utils.TableConfig): """Load embedding tables to onto TPU for each table and host. Args: config: A serialized TPUEmbeddingConfiguration proto. hosts: A list of CPU devices, on per host. variables: A dictionary of dictionaries of TPUShardedVariables. First key is the table name, second key is 'parameters' or the optimizer slot name. table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. """ def select_fn(host_id): def select_or_zeros(x): if host_id >= len(x.variables): # In the edge case where we have more hosts than variables, due to using # a small number of rows, we load zeros for the later hosts. We copy # the shape of the first host's variables, which we assume is defined # because TableConfig guarantees at least one row. 
return array_ops.zeros_like(x.variables[0]) return x.variables[host_id] return select_or_zeros for host_id, host in enumerate(hosts): with ops.device(host): host_variables = nest.map_structure(select_fn(host_id), variables) for table in table_config: table.optimizer._load()( # pylint: disable=protected-access table_name=table.name, num_shards=len(hosts), shard_id=host_id, config=config, **host_variables[table.name]) # Ensure that only the first table/first host gets a config so that we # don't bloat graph by attaching this large string to each op. # We have num tables * num hosts of these so for models with a large # number of tables training on a large slice, this can be an issue. config = None @def_function.function def _retrieve_variables_impl( config: Text, hosts: List[Tuple[int, Text]], variables: Dict[Text, Dict[Text, tf_variables.Variable]], table_config: tpu_embedding_v2_utils.TableConfig): """Retrieve embedding tables from TPU to host memory. Args: config: A serialized TPUEmbeddingConfiguration proto. hosts: A list of all the host CPU devices. variables: A dictionary of dictionaries of TPUShardedVariables. First key is the table name, second key is 'parameters' or the optimizer slot name. table_config: A list of tf.tpu.experimental.embedding.TableConfig objects. """ for host_id, host in enumerate(hosts): with ops.device(host): for table in table_config: retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access table_name=table.name, num_shards=len(hosts), shard_id=host_id, config=config) # When there are no slot variables (e.g with SGD) this returns a # single tensor rather than a tuple. In this case we put the tensor in # a list to make the following code easier to write. if not isinstance(retrieved, tuple): retrieved = (retrieved,) for i, slot in enumerate(["parameters"] + table.optimizer._slot_names()): # pylint: disable=protected-access # We must assign the CPU variables the values of tensors that were # returned from the TPU. sharded_var = variables[table.name][slot] if host_id < len(sharded_var.variables): # In the edge case where we have more hosts than variables, due to # using a small number of rows, we skip the later hosts. sharded_var.variables[host_id].assign(retrieved[i]) # Ensure that only the first table/first host gets a config so that we # don't bloat graph by attaching this large string to each op. # We have num tables * num hosts of these so for models with a large # number of tables training on a large slice, this can be an issue. config = None class TPUEmbeddingSaveable(saveable_hook.SaveableHook): """Save/Restore hook to Retrieve/Load TPUEmbedding variables.""" def __init__( self, name: Text, load: Callable[[], Any], retrieve: Callable[[], Any]): self._load = load self._retrieve = retrieve super(TPUEmbeddingSaveable, self).__init__(name=name) def before_save(self): if self._retrieve is not None: self._retrieve() def after_restore(self): if self._load is not None: self._load() def _ragged_embedding_lookup_with_reduce( table: tf_variables.Variable, ragged: ragged_tensor.RaggedTensor, weights: ragged_tensor.RaggedTensor, combiner: Text) -> core.Tensor: """Compute a ragged lookup followed by a reduce on axis 1. Args: table: The embedding table. ragged: A RaggedTensor of ids to look up. weights: A RaggedTensor of weights (or None). combiner: One of "mean", "sum", "sqrtn". Returns: A Tensor. 
""" if weights is None: weights = array_ops.ones_like(ragged, dtype=table.dtype) weights = array_ops.expand_dims(weights, axis=2) ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged) ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1) if combiner == "mean": ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1) elif combiner == "sqrtn": ragged_result = ragged_result, math_ops.sqrt(math_ops.reduce_sum( weights*weights, axis=1)) return ragged_result @tf_export("tpu.experimental.embedding.serving_embedding_lookup") def cpu_embedding_lookup(inputs, weights, tables, feature_config): """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs. This function is a utility which allows using the `tf.tpu.experimental.embedding` config objects with standard lookup functions. This can be used when exporting a model which uses `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In particular `tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups on TPUs and should not be part of your serving graph. Note that TPU specific options (such as `max_sequence_length`) in the configuration objects will be ignored. In the following example we take take a trained model (see the documentation for `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a saved model with a serving function that will perform the embedding lookup and pass the results to your model: ```python model = model_fn(...) embedding = tf.tpu.experimental.embedding.TPUEmbedding( feature_config=feature_config, batch_size=1024, optimizer=tf.tpu.experimental.embedding.SGD(0.1)) checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) checkpoint.restore(...) @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...), 'feature_two': tf.TensorSpec(...), 'feature_three': tf.TensorSpec(...)}]) def serve_tensors(embedding_featurese): embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup( embedding_features, None, embedding.embedding_tables, feature_config) return model(embedded_features) model.embedding_api = embedding tf.saved_model.save(model, export_dir=..., signatures={'serving_default': serve_tensors}) ``` NOTE: Its important to assign the embedding api object to a member of your model as `tf.saved_model.save` only supports saving variables one `Trackable` object. Since the model's weights are in `model` and the embedding table are managed by `embedding`, we assign `embedding` to and attribute of `model` so that tf.saved_model.save can find the embedding variables. NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will work directly from training. Args: inputs: a nested structure of Tensors, SparseTensors or RaggedTensors. weights: a nested structure of Tensors, SparseTensors or RaggedTensors or None for no weights. If not None, structure must match that of inputs, but entries are allowed to be None. tables: a dict of mapping TableConfig objects to Variables. feature_config: a nested structure of FeatureConfig objects with the same structure as inputs. Returns: A nested structure of Tensors with the same structure as inputs. 
""" nest.assert_same_structure(inputs, feature_config) flat_inputs = nest.flatten(inputs) flat_weights = [None] * len(flat_inputs) if weights is not None: nest.assert_same_structure(inputs, weights) flat_weights = nest.flatten(weights) flat_features = nest.flatten_with_joined_string_paths(feature_config) outputs = [] for inp, weight, (path, feature) in zip( flat_inputs, flat_weights, flat_features): table = tables[feature.table] if feature.max_sequence_length > 0: raise ValueError("Sequence features unsupported at this time.") if weight is not None: if isinstance(inp, ops.Tensor): raise ValueError( "Weight specified for {}, but input is dense.".format(path)) elif type(weight) is not type(inp): raise ValueError( "Weight for {} is of type {} but it does not match type of the " "input which is {}.".format(path, type(weight), type(inp))) if isinstance(inp, ops.Tensor): outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) elif isinstance(inp, sparse_tensor.SparseTensor): outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2( table, inp, sparse_weights=weight, combiner=feature.table.combiner)) elif isinstance(inp, ragged_tensor.RaggedTensor): outputs.append(_ragged_embedding_lookup_with_reduce( table, inp, weight, feature.table.combiner)) else: raise ValueError("Input {} is type {}. Tensor, SparseTensor or " "RaggedTensor expected.".format(path, type(inp))) return nest.pack_sequence_as(feature_config, outputs) def get_list_of_hosts(strategy: tpu_strategy.TPUStrategy) -> List[Text]: """Returns a sorted list of CPU devices for the remote jobs. Args: strategy: A TPUStrategy object. Returns: A sort list of device strings. """ list_of_hosts = [] # Assume this is sorted by task for tpu_device in strategy.extended.worker_devices: host = device_util.get_host_for_device(tpu_device) if host not in list_of_hosts: list_of_hosts.append(host) assert len(list_of_hosts) == strategy.extended.num_hosts return list_of_hosts def extract_variable_info( kwargs) -> Tuple[Text, Tuple[int, ...], dtypes.DType, Callable[[], Any]]: """Extracts the variable creation attributes from the kwargs. Args: kwargs: a dict of keyword arguments that were passed to a variable creator scope. Returns: A tuple of variable name, shape, dtype, initialization function. """ if (isinstance(kwargs["initial_value"], functools.partial) and ( "shape" in kwargs["initial_value"].keywords or kwargs["initial_value"].args)): # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. if "shape" in kwargs["initial_value"].keywords: shape = kwargs["initial_value"].keywords["shape"] else: shape = kwargs["initial_value"].args[0] return (kwargs["name"], shape, kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), kwargs["initial_value"].func) elif "shape" not in kwargs or kwargs["shape"] is None or not callable( kwargs["initial_value"]): raise ValueError( "Unable to extract initializer function and shape from {}. Please " "either pass a function that expects a shape and dtype as the " "initial value for your variable or functools.partial object with " "the shape and dtype kwargs set. This is needed so that we can " "initialize the shards of the ShardedVariable locally.".format( kwargs["initial_value"])) else: return (kwargs["name"], kwargs["shape"], kwargs["dtype"], kwargs["initial_value"]) def make_sharded_variable_creator( hosts: List[Text]) -> Callable[..., TPUShardedVariable]: """Makes a sharded variable creator given a list of hosts. Args: hosts: a list of tensorflow devices on which to shard the tensors. 
Returns: A variable creator function. """ def sharded_variable_creator( next_creator: Callable[..., tf_variables.Variable], *args, **kwargs): """The sharded variable creator.""" kwargs["skip_mirrored_creator"] = True num_hosts = len(hosts) name, shape, dtype, unwrapped_initial_value = extract_variable_info(kwargs) initial_value = kwargs["initial_value"] rows = shape[0] cols = shape[1] partial_partition = rows % num_hosts full_rows_per_host = rows // num_hosts # We partition as if we were using MOD sharding: at least # `full_rows_per_host` rows to `num_hosts` hosts, where the first # `partial_partition` hosts get an additional row when the number of rows # is not cleanly divisible. Note that `full_rows_per_host` may be zero. partitions = ( [full_rows_per_host + 1] * partial_partition + [full_rows_per_host] * (num_hosts - partial_partition)) variables = [] sharding_aware = "shard_info" in tf_inspect.getargspec(initial_value).args # Keep track of offset for sharding aware initializers. offset = 0 kwargs["dtype"] = dtype for i, p in enumerate(partitions): if p == 0: # Skip variable creation for empty partitions, resulting from the edge # case of 'rows < num_hosts'. This is safe because both load/restore # can handle the missing values. continue with ops.device(hosts[i]): kwargs["name"] = "{}_{}".format(name, i) kwargs["shape"] = (p, cols) if sharding_aware: shard_info = base.ShardInfo(kwargs["shape"], (offset, 0)) kwargs["initial_value"] = functools.partial( initial_value, shard_info=shard_info) offset += p else: kwargs["initial_value"] = functools.partial( unwrapped_initial_value, kwargs["shape"], dtype=dtype) variables.append(next_creator(*args, **kwargs)) return TPUShardedVariable(variables, name=name) return sharded_variable_creator
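

# The partitioning logic in `make_sharded_variable_creator` above splits the
# vocabulary rows as evenly as possible across hosts. As an illustrative
# example (hypothetical numbers, not taken from any real configuration), a
# table with 10 rows sharded over 4 hosts is partitioned as [3, 3, 2, 2] rows:
#
#   rows, num_hosts = 10, 4
#   partial_partition = rows % num_hosts    # 2: hosts that get one extra row
#   full_rows_per_host = rows // num_hosts  # 2: rows every host gets at least
#   partitions = ([full_rows_per_host + 1] * partial_partition +
#                 [full_rows_per_host] * (num_hosts - partial_partition))
#   # partitions == [3, 3, 2, 2]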