# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ====================================== """Operations for handling session logging and shutdown notifications.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import threading import time from google.protobuf import text_format from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util _WATCHDOG = None class CoordinatorResetError(errors.AbortedError): """Raised when the monitored session should reset.""" def __init__(self): errors.AbortedError.__init__( self, None, None, 'Resetting session loop due to worker shutdown.') def _clone_session(session, graph=None): return session_lib.Session( target=session.sess_str, config=session._config, # pylint: disable=protected-access graph=graph if graph else session.graph) class WorkerHeartbeatManager(object): """Manages the status/heartbeat monitor for a set of workers.""" def __init__(self, session, devices, heartbeat_ops, request_placeholder): """Construct a new WorkerHeartbeatManager. (Prefer using `WorkerHeartbeatManager.from_devices` when possible.) Args: session: `tf.compat.v1.Session`, session to use for heartbeat operations. devices: `list[string]` Set of devices to connect to. heartbeat_ops: `list[tf.Operation]` Heartbeat operations. request_placeholder: `tf.Placeholder[String]` Placeholder used to specify the WorkerHeartbeatRequest protocol buffer. """ self._session = session self._devices = devices self._ops = heartbeat_ops self._request_placeholder = request_placeholder @staticmethod def from_devices(session, devices): """Construct a heartbeat manager for the given devices.""" if not devices: logging.error('Trying to create heartbeat manager with no devices?') logging.info('Creating heartbeat manager for %s', devices) request_placeholder = array_ops.placeholder( name='worker_heartbeat_request', dtype=dtypes.string) heartbeat_ops = [] for device in devices: with ops.device(device): heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder)) return WorkerHeartbeatManager(session, devices, heartbeat_ops, request_placeholder) def num_workers(self): return len(self._devices) def configure(self, message): """Configure heartbeat manager for all devices. Args: message: `event_pb2.WorkerHeartbeatRequest` Returns: `None` """ logging.info('Configuring worker heartbeat: %s', text_format.MessageToString(message)) self._session.run(self._ops, {self._request_placeholder: message.SerializeToString()}) def ping(self, request=None, timeout_in_ms=5000): """Ping all workers, returning the parsed status results.""" if request is None: request = event_pb2.WorkerHeartbeatRequest() options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) results = self._session.run( self._ops, feed_dict={self._request_placeholder: request.SerializeToString()}, options=options) parsed_results = [ event_pb2.WorkerHeartbeatResponse.FromString(res_pb) for res_pb in results ] logging.debug('Ping results: %s', parsed_results) return parsed_results def lame_workers(self): """Ping all workers, returning manager containing lame workers (or None).""" ping_results = self.ping() lame_workers = [] for ping_response, device, op in zip(ping_results, self._devices, self._ops): if ping_response.health_status != event_pb2.OK: lame_workers.append((device, op)) if not lame_workers: return None bad_devices, bad_ops = zip(*lame_workers) return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, self._request_placeholder) def __repr__(self): return 'HeartbeatManager(%s)' % ','.join(self._devices) # Default timeout is set to allow other shutdown triggered operations (log # flushing etc) to finish before terminating the worker. def shutdown(self, wait_time_in_ms=60000, exit_code=0): """Shutdown all workers after `shutdown_timeout_secs`.""" logging.info('Shutting down %s.', self) req = event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig(timeout_ms=wait_time_in_ms), shutdown_mode=event_pb2.SHUTDOWN_AFTER_TIMEOUT, exit_code=event_pb2.RequestedExitCode(exit_code=exit_code)) self.configure(req) # Wait for workers to shutdown. sleep_sec = 10.0 + wait_time_in_ms / 1000 logging.info('Waiting %.2f seconds for worker shutdown.', sleep_sec) time.sleep(sleep_sec) def all_worker_devices(session): """Return a list of devices for each worker in the system.""" devices = session.list_devices() devices_that_support_heartbeats = [] for device in devices: name = device.name # Pick devices that have a TPU but target the attached CPU if ':TPU:0' in name and 'coordinator' not in name: devices_that_support_heartbeats.append(name.replace('TPU', 'CPU')) return devices_that_support_heartbeats class WatchdogManager(threading.Thread): """Configures worker watchdog timer and handles periodic pings. Usage: # Ping workers every minute, shutting down workers if they haven't received # a ping after 1 hour. watchdog_manager = WatchdogManager( ping_interval=60, shutdown_timeout=3600 ) # Use as a context manager, resetting watchdog on context exit: with watchdog_manager: session.run(...) # Or setup globally; watchdog will remain active until program exit. watchdog_manager.configure_and_run() """ def __init__(self, session, devices=None, ping_interval=60, shutdown_timeout=2 * 3600): """Initialize a watchdog manager. Args: session: Session connected to worker devices. A cloned session and graph will be created for managing worker pings. devices: Set of devices to monitor. If none, all workers will be monitored. ping_interval: Time, in seconds, between watchdog pings. shutdown_timeout: Time, in seconds, before watchdog timeout. """ threading.Thread.__init__(self) self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices self._graph = None self._session = None self._worker_manager = None def _reset_manager(self, stopping=False): """Reset the graph, session and worker manager.""" self._graph = ops.Graph() self._session = session_lib.Session( target=self._target, graph=self._graph, config=self._config, ) if self._devices is None: self._devices = all_worker_devices(self._session) with self._graph.as_default(): self._worker_manager = WorkerHeartbeatManager.from_devices( self._session, self._devices) if stopping: timeout_ms = -1 shutdown_mode = event_pb2.NOT_CONFIGURED else: timeout_ms = self.shutdown_timeout * 1000 shutdown_mode = event_pb2.WAIT_FOR_COORDINATOR self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms), shutdown_mode=shutdown_mode)) def configure_and_run(self): logging.info( 'Enabling watchdog timer with %d second timeout ' 'and %d second ping interval.', self.shutdown_timeout, self.ping_interval) self._reset_manager() self._running = True self.start() def stop(self): logging.info('Stopping worker watchdog.') self._reset_manager(stopping=True) self._running = False self.join() def __enter__(self): self.configure_and_run() def __exit__(self, exc_type, exc_val, exc_tb): self.stop() def run(self): # Don't fetch logs or adjust timing: just ping the watchdog. # # If we hit an exception, reset our session as it is likely broken. while self._running: try: self._worker_manager.ping(request=None) time.sleep(self.ping_interval) except errors.OpError as e: # Catch any TF errors that occur so we don't stop sending heartbeats logging.debug('Caught error while sending heartbeat: %s', e) self._reset_manager() def start_worker_watchdog(session, devices=None, ping_interval=60, shutdown_timeout=3600): """Start global worker watchdog to shutdown workers on coordinator exit.""" global _WATCHDOG if _WATCHDOG is None: # Ensure we can send a few pings before we timeout! ping_interval = min(shutdown_timeout / 10., ping_interval) _WATCHDOG = WatchdogManager(session, devices, ping_interval, shutdown_timeout) _WATCHDOG.configure_and_run() def stop_worker_watchdog(): """Stop global worker watchdog.""" global _WATCHDOG if _WATCHDOG is not None: _WATCHDOG.stop() _WATCHDOG = None class GracefulShutdownHook(session_run_hook.SessionRunHook): """Session hook that watches for shutdown events. If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a SystemShutdown exception is raised to terminate the main session. If `saver` is None the `SAVERS` collection will be read to find a saver. `on_shutdown_hooks` is an optional list of functions that should be called after checkpointing. The function is called with (`run_context`, `all_workers`, `lame_workers`). If `heartbeat_group` is not specified, it will default to all CPU workers in the system. """ def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None): self._saver = saver self._checkpoint_prefix = checkpoint_prefix self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else [] # Worker heartbeats are managed independently of the main training graph. self._graph = ops.Graph() self._workers = None self._session = None self._heartbeat_supported = False def after_create_session(self, training_session, coord): # pylint: disable=unused-argument # N.B. We have to pull the global step here to avoid it being unavailable # at checkpoint time; the graph has been frozen at that point. if training_util.get_global_step() is None and self.saver() is not None: raise ValueError( 'Saver defined but no global step. Run `get_or_create_global_step()`' ' in your model definition to allow checkpointing.') with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 if self._heartbeat_supported: try: self._workers.configure( event_pb2.WorkerHeartbeatRequest( shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) except errors.InvalidArgumentError: logging.warn( 'TPU device does not support heartbeats. Failure ' 'handling will be disabled.') self._heartbeat_supported = False else: logging.warn( 'No workers support heartbeats. Failure handling will be disabled.') def saver(self): if self._saver: return self._saver savers = ops.get_collection(ops.GraphKeys.SAVERS) if not savers: return None if not isinstance(savers, list): return savers if len(savers) > 1: logging.error( 'Multiple savers in the SAVERS collection. On-demand checkpointing ' 'will be disabled. Pass an explicit `saver` to the constructor to ' 'override this behavior.') return None return savers[0] def after_run(self, run_context, run_values): del run_values if not self._heartbeat_supported: return lame_workers = self._workers.lame_workers() if lame_workers: logging.info('ShutdownHook: lame workers found: %s', lame_workers) if self.saver(): logging.info('ShutdownHook: saving checkpoint to %s', self._checkpoint_prefix) self.saver().save( run_context.session, self._checkpoint_prefix, global_step=training_util.get_global_step(), write_state=True, ) else: logging.info('ShutdownHook: no Saver defined.') for fn in self._on_shutdown_hooks: fn(run_context, self._workers, lame_workers) class ResetComputation(object): """Hook to reset a TPUEstimator computation loop. This hook shuts down all workers and resets the monitored session loop by throwing a CoordinatorResetError. """ def __init__(self): pass def __call__(self, run_context, all_workers, lame_workers): del run_context, lame_workers all_workers.shutdown(exit_code=42) logging.info('Resetting coordinator.') raise CoordinatorResetError() class ShutdownLameWorkers(object): """Shutdown lamed workers. Processing will continue normally (typically by waiting for the down workers to be restarted). """ def __init__(self): pass def __call__(self, run_context, all_workers, lame_workers): lame_workers.shutdown(exit_code=42) class ShutdownAllWorkers(object): """Shutdown all workers. Processing will continue normally (typically by waiting for the down workers to be restarted). """ def __init__(self): pass def __call__(self, run_context, all_workers, lame_workers): all_workers.shutdown(exit_code=42)