# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extending CheckpointReader for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# Translation table consulted by `error_translator`, in priority order: the
# first row with any substring present in the C++ error text decides which
# exception class is raised. Text matching no row becomes an UNKNOWN OpError.
_ERROR_TRANSLATIONS = (
    (('not found in checkpoint', 'Failed to find any matching files for'),
     errors_impl.NotFoundError),
    (('Sliced checkpoints are not supported', 'Data type not supported'),
     errors_impl.UnimplementedError),
    (('Failed to get matching files on',), errors_impl.InvalidArgumentError),
    (('Unable to open table file',), errors_impl.DataLossError),
    (('Failed to find the saved tensor slices',), errors_impl.InternalError),
)


def error_translator(e):
  """Translate the tensor_slice_reader.cc errors.

  Re-raises an error that came out of the C++ checkpoint reader as the
  matching `errors_impl` exception. The class is chosen by searching `str(e)`
  for known substrings. This function always raises and never returns.

  Args:
    e: The error caught from the C++ reader. Only its string form is used.

  Raises:
    errors_impl.NotFoundError: the message contains 'not found in checkpoint'
      or 'Failed to find any matching files for'.
    errors_impl.UnimplementedError: the message contains 'Sliced checkpoints
      are not supported' or 'Data type not supported'.
    errors_impl.InvalidArgumentError: the message contains 'Failed to get
      matching files on'.
    errors_impl.DataLossError: the message contains 'Unable to open table
      file'.
    errors_impl.InternalError: the message contains 'Failed to find the saved
      tensor slices'.
    errors_impl.OpError: any other message, with code `errors_impl.UNKNOWN`.
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  message = str(e)
  for needles, error_cls in _ERROR_TRANSLATIONS:
    if any(needle in message for needle in needles):
      raise error_cls(None, None, message)
  raise errors_impl.OpError(None, None, message, errors_impl.UNKNOWN)


def get_variable_to_dtype_map(self):
  """Returns the reader's variable-to-data-type map with `DType` values.

  The wrapped `_GetVariableToDataTypeMap()` yields raw type enums as values.
  This method wraps each value in `dtypes.DType` and keeps the keys as they
  are.
  """
  enum_by_name = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  return {var_name: dtypes.DType(enum_value)
          for var_name, enum_value in enum_by_name.items()}


# Attach as a method of the pybind-generated CheckpointReader class.
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map


def has_tensor(self, tensor_str):
  """Checks whether the checkpoint holds a tensor named `tensor_str`.

  Args:
    tensor_str: Tensor name as `str` or `bytes`. It is normalized with
      `compat.as_bytes` before the C++ lookup.

  Returns:
    The result of the wrapped `_HasTensor` call (presumably a bool).
  """
  name_bytes = compat.as_bytes(tensor_str)
  return self._HasTensor(name_bytes)  # pylint: disable=protected-access


# Attach as a method of the pybind-generated CheckpointReader class.
CheckpointReader.has_tensor = has_tensor


def get_tensor(self, tensor_str):
  """Get the tensor from the Checkpoint object.

  Args:
    tensor_str: Name of the tensor to read, as `str` or `bytes`.

  Returns:
    Whatever `CheckpointReader_GetTensor` produces for that name. This is
    presumably a NumPy array; confirm against the pybind wrapper.

  Raises:
    errors_impl.OpError: the subclass chosen by `error_translator` when the
      C++ reader fails with a `RuntimeError`.
  """
  try:
    name_bytes = compat.as_bytes(tensor_str)
    return CheckpointReader.CheckpointReader_GetTensor(self, name_bytes)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as err:
    error_translator(err)


# Attach as a method of the pybind-generated CheckpointReader class.
CheckpointReader.get_tensor = get_tensor


# Disable invalid name to keep backwards compatibility with that function.
# It was previously exported from py_checkpoint_reader.i, which did not conform
# to pylint checks.
# pylint: disable=invalid-name


@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """A function that returns a CheckpointReader.

  Args:
    filepattern: The checkpoint filename (or file pattern), as `str` or
      `bytes`.

  Returns:
    A CheckpointReader object.

  Raises:
    errors_impl.OpError: the subclass chosen by `error_translator` when
      constructing the C++ reader fails with a `RuntimeError`. For example, a
      'Failed to find any matching files for' message becomes `NotFoundError`.
  """
  try:
    pattern_bytes = compat.as_bytes(filepattern)
    return CheckpointReader(pattern_bytes)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as err:
    error_translator(err)