# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Layer serialization/deserialization functions. """ # pylint: disable=wildcard-import # pylint: disable=unused-import from __future__ import absolute_import from __future__ import division from __future__ import print_function import threading from tensorflow.python import tf2 from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.layers import advanced_activations from tensorflow.python.keras.layers import convolutional from tensorflow.python.keras.layers import convolutional_recurrent from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import cudnn_recurrent from tensorflow.python.keras.layers import dense_attention from tensorflow.python.keras.layers import einsum_dense from tensorflow.python.keras.layers import embeddings from tensorflow.python.keras.layers import local from tensorflow.python.keras.layers import merge from tensorflow.python.keras.layers import multi_head_attention from tensorflow.python.keras.layers import noise from tensorflow.python.keras.layers import normalization from tensorflow.python.keras.layers import normalization_v2 from tensorflow.python.keras.layers import pooling from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 from tensorflow.python.keras.layers import wrappers from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding_v1 from tensorflow.python.keras.layers.preprocessing import discretization from tensorflow.python.keras.layers.preprocessing import hashing from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.layers.preprocessing import integer_lookup as preprocessing_integer_lookup from tensorflow.python.keras.layers.preprocessing import integer_lookup_v1 as preprocessing_integer_lookup_v1 from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1 from tensorflow.python.keras.layers.preprocessing import string_lookup as preprocessing_string_lookup from tensorflow.python.keras.layers.preprocessing import string_lookup_v1 as preprocessing_string_lookup_v1 from tensorflow.python.keras.layers.preprocessing import text_vectorization as preprocessing_text_vectorization from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1 as preprocessing_text_vectorization_v1 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, convolutional_recurrent, core, cudnn_recurrent, dense_attention, embeddings, einsum_dense, local, merge, noise, normalization, pooling, image_preprocessing, preprocessing_integer_lookup_v1, preprocessing_normalization_v1, preprocessing_string_lookup_v1, preprocessing_text_vectorization_v1, recurrent, wrappers, hashing, category_crossing, category_encoding_v1, discretization, multi_head_attention) ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2, preprocessing_integer_lookup, preprocessing_normalization, preprocessing_string_lookup, preprocessing_text_vectorization, category_encoding) # ALL_OBJECTS is meant to be a global mutable. Hence we need to make it # thread-local to avoid concurrent mutations. LOCAL = threading.local() def populate_deserializable_objects(): """Populates dict ALL_OBJECTS with every built-in layer. """ global LOCAL if not hasattr(LOCAL, 'ALL_OBJECTS'): LOCAL.ALL_OBJECTS = {} LOCAL.GENERATED_WITH_V2 = None if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled(): # Objects dict is already generated for the proper TF version: # do nothing. return LOCAL.ALL_OBJECTS = {} LOCAL.GENERATED_WITH_V2 = tf2.enabled() base_cls = base_layer.Layer generic_utils.populate_dict_with_module_objects( LOCAL.ALL_OBJECTS, ALL_MODULES, obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) # Overwrite certain V1 objects with V2 versions if tf2.enabled(): generic_utils.populate_dict_with_module_objects( LOCAL.ALL_OBJECTS, ALL_V2_MODULES, obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) # These deserialization aliases are added for backward compatibility, # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2" # were used as class name for v1 and v2 version of BatchNormalization, # respectively. Here we explicitly convert them to their canonical names. LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization LOCAL.ALL_OBJECTS[ 'BatchNormalizationV2'] = normalization_v2.BatchNormalization # Prevent circular dependencies. from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top from tensorflow.python.keras.feature_column.sequence_feature_column import SequenceFeatures # pylint: disable=g-import-not-at-top LOCAL.ALL_OBJECTS['Input'] = input_layer.Input LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec LOCAL.ALL_OBJECTS['Functional'] = models.Functional LOCAL.ALL_OBJECTS['Model'] = models.Model LOCAL.ALL_OBJECTS['SequenceFeatures'] = SequenceFeatures LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel if tf2.enabled(): from tensorflow.python.keras.feature_column.dense_features_v2 import DenseFeatures # pylint: disable=g-import-not-at-top LOCAL.ALL_OBJECTS['DenseFeatures'] = DenseFeatures else: from tensorflow.python.keras.feature_column.dense_features import DenseFeatures # pylint: disable=g-import-not-at-top LOCAL.ALL_OBJECTS['DenseFeatures'] = DenseFeatures # Merge layers, function versions. LOCAL.ALL_OBJECTS['add'] = merge.add LOCAL.ALL_OBJECTS['subtract'] = merge.subtract LOCAL.ALL_OBJECTS['multiply'] = merge.multiply LOCAL.ALL_OBJECTS['average'] = merge.average LOCAL.ALL_OBJECTS['maximum'] = merge.maximum LOCAL.ALL_OBJECTS['minimum'] = merge.minimum LOCAL.ALL_OBJECTS['concatenate'] = merge.concatenate LOCAL.ALL_OBJECTS['dot'] = merge.dot @keras_export('keras.layers.serialize') def serialize(layer): return generic_utils.serialize_keras_object(layer) @keras_export('keras.layers.deserialize') def deserialize(config, custom_objects=None): """Instantiates a layer from a config dictionary. Arguments: config: dict of the form {'class_name': str, 'config': dict} custom_objects: dict mapping class names (or function names) of custom (non-Keras) objects to class/functions Returns: Layer instance (may be Model, Sequential, Network, Layer...) """ populate_deserializable_objects() return generic_utils.deserialize_keras_object( config, module_objects=LOCAL.ALL_OBJECTS, custom_objects=custom_objects, printable_module_name='layer')