# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Keras integration for TensorBoard hparams. Most users should use `tensorboard.plugins.hparams.api` to access this module's contents. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from tensorboard.plugins.hparams import api_pb2 from tensorboard.plugins.hparams import summary from tensorboard.plugins.hparams import summary_v2 class Callback(tf.keras.callbacks.Callback): """Callback for logging hyperparameters to TensorBoard. NOTE: This callback only works in TensorFlow eager mode. """ def __init__(self, writer, hparams, trial_id=None): """Create a callback for logging hyperparameters to TensorBoard. As with the standard `tf.keras.callbacks.TensorBoard` class, each callback object is valid for only one call to `model.fit`. Args: writer: The `SummaryWriter` object to which hparams should be written, or a logdir (as a `str`) to be passed to `tf.summary.create_file_writer` to create such a writer. hparams: A `dict` mapping hyperparameters to the values used in this session. Keys should be the names of `HParam` objects used in an experiment, or the `HParam` objects themselves. Values should be Python `bool`, `int`, `float`, or `string` values, depending on the type of the hyperparameter. trial_id: An optional `str` ID for the set of hyperparameter values used in this trial. Defaults to a hash of the hyperparameters. Raises: ValueError: If two entries in `hparams` share the same hyperparameter name. """ # Defer creating the actual summary until we write it, so that the # timestamp is correct. But create a "dry-run" first to fail fast in # case the `hparams` are invalid. self._hparams = dict(hparams) self._trial_id = trial_id summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id) if writer is None: raise TypeError( "writer must be a `SummaryWriter` or `str`, not None" ) elif isinstance(writer, str): self._writer = tf.compat.v2.summary.create_file_writer(writer) else: self._writer = writer def _get_writer(self): if self._writer is None: raise RuntimeError( "hparams Keras callback cannot be reused across training sessions" ) if not tf.executing_eagerly(): raise RuntimeError( "hparams Keras callback only supported in TensorFlow eager mode" ) return self._writer def on_train_begin(self, logs=None): del logs # unused with self._get_writer().as_default(): summary_v2.hparams(self._hparams, trial_id=self._trial_id) def on_train_end(self, logs=None): del logs # unused with self._get_writer().as_default(): pb = summary.session_end_pb(api_pb2.STATUS_SUCCESS) raw_pb = pb.SerializeToString() tf.compat.v2.summary.experimental.write_raw_pb(raw_pb, step=0) self._writer = None