# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite OpHint-wrapped dynamic RNN helpers.

Provides `dynamic_rnn` and `bidirectional_dynamic_rnn` variants whose inputs
and outputs are tagged with a "TfLiteDynamicRnn" `OpHint`.

TODO(renjieliu): Find a better home for this one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.lite.python.op_hint import OpHint
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
from tensorflow.python.ops.rnn import _should_cache
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["lite.experimental.nn.dynamic_rnn"])
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs` and tags the input and output
  tensors with a "TfLiteDynamicRnn" `OpHint`.

  Only time-major inputs are currently supported: `time_major` must be true
  and `inputs` must be shaped `[max_time, batch_size, ...]`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'input_data' is time-major: [max_time, batch_size, depth]
  # 'outputs' is a tensor of shape [max_time, batch_size, cell_state_size]
  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.compat.v1.lite.experimental.nn.dynamic_rnn(
      rnn_cell, input_data,
      initial_state=initial_state,
      dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [max_time, batch_size, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.nn.rnn_cell.LSTMStateTuple for each cell
  outputs, state = tf.compat.v1.lite.experimental.nn.dynamic_rnn(
      cell=multi_rnn_cell,
      inputs=data,
      dtype=tf.float32)
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs. Because only `time_major == True` is supported,
      this must be a `Tensor` of shape: `[max_time, batch_size, ...]`, or a
      nested tuple of such elements. This may also be a (possibly nested) tuple
      of Tensors satisfying this property.  The first two dimensions must match
      across all the inputs, but otherwise the ranks and other shape components
      may differ. In this case, input to `cell` at each time-step will
      replicate the structure of these tuples, except for the time dimension
      (from which the time is taken). The input to `cell` at each time step
      will be a `Tensor` or (possibly nested) tuple of Tensors each with
      dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. Must be
      true (the default); these `Tensors` must be shaped `[max_time,
      batch_size, depth]`. Batch-major input (`time_major == False`) is not
      supported yet and is rejected with a `ValueError`.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`, shaped
      `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `time_major` is false, if `sequence_length` is not a
      vector, or if neither `initial_state` nor `dtype` is provided.
    RuntimeError: If not using control flow v2.
  """
  # Currently only support time_major == True case. This is an explicit raise
  # rather than an `assert` so the check still runs under `python -O`, where
  # assert statements are stripped.
  if not time_major:
    raise ValueError(
        "OpHint dynamic rnn currently only supports time_major == True.")

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

  # Index mappings handed to the level-2 OpHint via `children_inputs_mappings`.
  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # Internal calculations operate on [time, batch, depth] tensors.
    flat_input = nest.flatten(inputs)

    # NOTE: unreachable while batch-major input is rejected above; kept so the
    # transposition is already in place once time_major == False is supported.
    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.shape.rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.shape)
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      # Prefer the cell's `get_initial_state` when it provides one; otherwise
      # fall back to `zero_state`.
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      """Returns an Assert op checking that the shape of `x` equals `shape`."""
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    # NOTE: unreachable while batch-major input is rejected above.
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state


def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Takes input and builds independent forward and backward RNNs. The input_size
  of forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not
  given.

  NOTE: both directions are built with `dynamic_rnn` from this module, which
  currently only supports `time_major == True`. The default here is
  `time_major=False`, so callers must pass `time_major=True` explicitly;
  otherwise a `ValueError` is raised.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs. With `time_major == True` (the only supported
      mode), this must be a tensor of shape: `[max_time, batch_size, ...]`, or
      a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. The
      default is false (batch-major, `[batch_size, max_time, depth]`), but that
      mode is rejected by the underlying `dynamic_rnn`, so pass
      `time_major=True`.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        With `time_major == True`, output_fw will be a `Tensor` shaped:
        `[max_time, batch_size, cell_fw.output_size]`
        and output_bw will be a `Tensor` shaped:
        `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
    ValueError: If `time_major` is false (raised by `dynamic_rnn`).
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      """Reverses `input_` along `seq_axis`, per-sequence if lengths are given."""
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        """Reverses one input tensor along the time axis."""
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      # The backward RNN consumes the time-reversed inputs.
      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  # Reverse the backward outputs again so they line up with the forward ones.
  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)