# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ====================================== """Hook for asynchronous checkpointing. This hook dispatches checkpoint writing operations in a separate thread to allow execution to continue on the main thread. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import threading import time from tensorflow.core.util import event_pb2 from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): """Saves checkpoints every N steps or seconds.""" def __init__(self, checkpoint_dir, save_secs=None, save_steps=None, saver=None, checkpoint_basename="model.ckpt", scaffold=None, listeners=None): """Initializes a `CheckpointSaverHook`. Args: checkpoint_dir: `str`, base directory for the checkpoint files. save_secs: `int`, save every N secs. save_steps: `int`, save every N steps. saver: `Saver` object, used for saving. checkpoint_basename: `str`, base name for the checkpoint files. scaffold: `Scaffold`, use to get saver object. listeners: List of `CheckpointSaverListener` subclass instances. Used for callbacks that run immediately before or after this hook saves the checkpoint. Raises: ValueError: One of `save_steps` or `save_secs` should be set. ValueError: At most one of `saver` or `scaffold` should be set. """ save_path = os.path.join(checkpoint_dir, checkpoint_basename) logging.info("Create AsyncCheckpointSaverHook saving to path\n%s", save_path) if listeners: logging.info(" with %d listener(s).", len(listeners)) if saver is not None and scaffold is not None: raise ValueError("You cannot provide both saver and scaffold.") self._saver = saver self._save_thread = None self._write_graph_thread = None self._checkpoint_dir = checkpoint_dir self._save_path = save_path self._scaffold = scaffold self._timer = basic_session_run_hooks.SecondOrStepTimer( every_secs=save_secs, every_steps=save_steps) self._listeners = listeners or [] self._steps_per_run = 1 self._summary_writer = None self._global_step_tensor = None self._last_checkpoint_step = None def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run def begin(self): self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access if self._global_step_tensor is None: raise RuntimeError( "Global step should be created to use CheckpointSaverHook.") for l in self._listeners: l.begin() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) # We do write graph and saver_def at the first call of before_run. # We cannot do this in begin, since we let other hooks to change graph and # add variables in begin. Graph is finalized after all begin calls. def _write_graph_fn(self): training_util.write_graph( ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir, "graph.pbtxt") self._write_graph_thread = threading.Thread(target=_write_graph_fn, args=[self]) self._write_graph_thread.start() saver_def = self._get_saver().saver_def if self._get_saver() else None graph = ops.get_default_graph() meta_graph_def = meta_graph.create_meta_graph_def( graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) self._summary_writer.add_graph(graph) self._summary_writer.add_meta_graph(meta_graph_def) # The checkpoint saved here is the state at step "global_step". self._save(session, global_step) self._timer.update_last_triggered_step(global_step) def before_run(self, run_context): # pylint: disable=unused-argument return SessionRunArgs(self._global_step_tensor) def after_run(self, run_context, run_values): global_step = run_context.session.run(self._global_step_tensor) if self._timer.should_trigger_for_step(global_step): self._timer.update_last_triggered_step(global_step) logging.info("Triggering checkpoint. %s", global_step) if self._save(run_context.session, global_step): run_context.request_stop() def end(self, session): if self._save_thread: logging.info("Waiting for any pending checkpoints to finish.") self._save_thread.join() if self._write_graph_thread: logging.info("Waiting for any pending write_graph to finish.") self._write_graph_thread.join() last_step = session.run(self._global_step_tensor) if self._last_checkpoint_step != last_step: self._save(session, last_step, asynchronous=False) for l in self._listeners: l.end(session, last_step) def _save(self, session, step, asynchronous=True): """Saves the latest checkpoint, returns should_stop.""" def _save_fn(): """Run the saver process.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) start_time = time.time() for l in self._listeners: l.before_save(session, step) self._get_saver().save(session, self._save_path, global_step=step) self._summary_writer.add_session_log( event_pb2.SessionLog( status=event_pb2.SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) for l in self._listeners: l.after_save(session, step) end_time = time.time() logging.info("Checkpoint actual writing time: (%.3f sec)", end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) if not asynchronous: self._last_checkpoint_step = step _save_fn() return if self._save_thread is not None: self._save_thread.join(timeout=0.1) if self._save_thread.is_alive(): logging.info("Saver thread still in progress, skipping checkpoint.") return self._last_checkpoint_step = step self._save_thread = threading.Thread(target=_save_fn) self._save_thread.start() def _get_saver(self): if self._saver is not None: return self._saver elif self._scaffold is not None: return self._scaffold.saver # Get saver from the SAVERS collection if present. collection_key = ops.GraphKeys.SAVERS savers = ops.get_collection(collection_key) if not savers: raise RuntimeError( "No items in collection {}. Please add a saver to the collection " "or provide a saver or scaffold.".format(collection_key)) elif len(savers) > 1: raise RuntimeError( "More than one item in collection {}. " "Please indicate which one to use by passing it to the constructor." .format(collection_key)) self._saver = savers[0] return savers[0]