# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Python API for save and loading a dataset.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import multiprocessing from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util.tf_export import tf_export COMPRESSION_GZIP = "GZIP" COMPRESSION_SNAPPY = "NONE" @tf_export("data.experimental.save", v1=[]) def save(dataset, path, compression=None, shard_func=None): """Saves the content of the given dataset. Example usage: >>> import tempfile >>> path = os.path.join(tempfile.gettempdir(), "saved_data") >>> # Save a dataset >>> dataset = tf.data.Dataset.range(2) >>> tf.data.experimental.save(dataset, path) >>> new_dataset = tf.data.experimental.load(path, ... tf.TensorSpec(shape=(), dtype=tf.int64)) >>> for elem in new_dataset: ... print(elem) tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64) The saved dataset is saved in multiple file "shards". By default, the dataset output is divided to shards in a round-robin fashion but custom sharding can be specified via the `shard_func` function. For example, you can save the dataset to using a single shard as follows: ```python dataset = make_dataset() def custom_shard_func(element): return 0 dataset = tf.data.experimental.save( path="/path/to/data", ..., shard_func=custom_shard_func) ``` NOTE: The directory layout and file format used for saving the dataset is considered an implementation detail and may change. For this reason, datasets saved through `tf.data.experimental.save` should only be consumed through `tf.data.experimental.load`, which is guaranteed to be backwards compatible. Args: dataset: The dataset to save. path: Required. A directory to use for saving the dataset. compression: Optional. The algorithm to use to compress data when writing it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. shard_func: Optional. A function to control the mapping of dataset elements to file shards. The function is expected to map elements of the input dataset to int64 shard IDs. If present, the function will be traced and executed as graph computation. """ if shard_func is None: use_shard_func = False shard_func = lambda *x: None # a dummy function that will not be used else: use_shard_func = True wrapped_func = dataset_ops.StructuredFunctionWrapper( shard_func, "save()", input_structure=dataset.element_spec, add_to_graph=False) path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path") shard_func = wrapped_func.function shard_func.add_to_graph(ops.get_default_graph()) # pylint: disable=protected-access dataset = dataset._apply_options() gen_experimental_dataset_ops.save_dataset( dataset._variant_tensor, path=path, shard_func_other_args=shard_func.captured_inputs, compression=compression, shard_func=shard_func, use_shard_func=use_shard_func) class _LoadDataset(dataset_ops.DatasetSource): """A dataset that loads previously saved dataset.""" def __init__(self, path, element_spec, compression=None, reader_func=None): if reader_func is None: reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda lambda x: x, cycle_length=multiprocessing.cpu_count(), num_parallel_calls=dataset_ops.AUTOTUNE) self._path = path self._element_spec = element_spec self._compression = compression self._reader_func = dataset_ops.StructuredFunctionWrapper( reader_func, "load()", # Dataset of datasets of input elements input_structure=dataset_ops.DatasetSpec( dataset_ops.DatasetSpec(element_spec))) variant_tensor = gen_experimental_dataset_ops.load_dataset( path, reader_func_other_args=self._reader_func.function.captured_inputs, compression=compression, reader_func=self._reader_func.function, **self._flat_structure) super(_LoadDataset, self).__init__(variant_tensor) def _functions(self): return [self._reader_func] @property def element_spec(self): return self._element_spec @tf_export("data.experimental.load", v1=[]) def load(path, element_spec, compression=None, reader_func=None): """Loads a previously saved dataset. Example usage: >>> import tempfile >>> path = os.path.join(tempfile.gettempdir(), "saved_data") >>> # Save a dataset >>> dataset = tf.data.Dataset.range(2) >>> tf.data.experimental.save(dataset, path) >>> new_dataset = tf.data.experimental.load(path, ... tf.TensorSpec(shape=(), dtype=tf.int64)) >>> for elem in new_dataset: ... print(elem) tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64) Note that to load a previously saved dataset, you need to specify `element_spec` -- a type signature of the elements of the saved dataset, which can be obtained via `tf.data.Dataset.element_spec`. This requirement exists so that shape inference of the loaded dataset does not need to perform I/O. If the default option of sharding the saved dataset was used, the element order of the saved dataset will be preserved when loading it. The `reader_func` argument can be used to specify a custom order in which elements should be loaded from the individual shards. The `reader_func` is expected to take a single argument -- a dataset of datasets, each containing elements of one of the shards -- and return a dataset of elements. For example, the order of shards can be shuffled when loading them as follows: ```python def custom_reader_func(datasets): datasets = datasets.shuffle(NUM_SHARDS) return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) dataset = tf.data.experimental.load( path="/path/to/data", ..., reader_func=custom_reader_func) ``` Args: path: Required. A path pointing to a previously saved dataset. element_spec: Required. A nested structure of `tf.TypeSpec` objects matching the structure of an element of the saved dataset and specifying the type of individual element components. compression: Optional. The algorithm to use to decompress the data when reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. reader_func: Optional. A function to control how to read data from shards. If present, the function will be traced and executed as graph computation. Returns: A `tf.data.Dataset` instance. """ return _LoadDataset( path=path, element_spec=element_spec, compression=compression, reader_func=reader_func)