# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import six from werkzeug import wrappers from tensorboard import plugin_util from tensorboard.data import provider from tensorboard.backend import http_util from tensorboard.plugins import base_plugin from tensorboard.plugins.pr_curve import metadata _DEFAULT_DOWNSAMPLING = 100 # PR curves per time series class PrCurvesPlugin(base_plugin.TBPlugin): """A plugin that serves PR curves for individual classes.""" plugin_name = metadata.PLUGIN_NAME def __init__(self, context): """Instantiates a PrCurvesPlugin. Args: context: A base_plugin.TBContext instance. A magic container that TensorBoard uses to make objects available to the plugin. """ self._data_provider = context.data_provider self._downsample_to = (context.sampling_hints or {}).get( metadata.PLUGIN_NAME, _DEFAULT_DOWNSAMPLING ) @wrappers.Request.application def pr_curves_route(self, request): """A route that returns a JSON mapping between runs and PR curve data. Returns: Given a tag and a comma-separated list of runs (both stored within GET parameters), fetches a JSON object that maps between run name and objects containing data required for PR curves for that run. Runs that either cannot be found or that lack tags will be excluded from the response. """ ctx = plugin_util.context(request.environ) experiment = plugin_util.experiment_id(request.environ) runs = request.args.getlist("run") if not runs: return http_util.Respond( request, "No runs provided when fetching PR curve data", 400 ) tag = request.args.get("tag") if not tag: return http_util.Respond( request, "No tag provided when fetching PR curve data", 400 ) try: response = http_util.Respond( request, self.pr_curves_impl(ctx, experiment, runs, tag), "application/json", ) except ValueError as e: return http_util.Respond(request, str(e), "text/plain", 400) return response def pr_curves_impl(self, ctx, experiment, runs, tag): """Creates the JSON object for the PR curves response for a run-tag combo. Arguments: runs: A list of runs to fetch the curves for. tag: The tag to fetch the curves for. Raises: ValueError: If no PR curves could be fetched for a run and tag. Returns: The JSON object for the PR curves route response. """ response_mapping = {} rtf = provider.RunTagFilter(runs, [tag]) read_result = self._data_provider.read_tensors( ctx, experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, run_tag_filter=rtf, downsample=self._downsample_to, ) for run in runs: data = read_result.get(run, {}).get(tag) if data is None: raise ValueError( "No PR curves could be found for run %r and tag %r" % (run, tag) ) response_mapping[run] = [self._process_datum(d) for d in data] return response_mapping @wrappers.Request.application def tags_route(self, request): """A route (HTTP handler) that returns a response with tags. Returns: A response that contains a JSON object. The keys of the object are all the runs. Each run is mapped to a (potentially empty) dictionary whose keys are tags associated with run and whose values are metadata (dictionaries). The metadata dictionaries contain 2 keys: - displayName: For the display name used atop visualizations in TensorBoard. - description: The description that appears near visualizations upon the user hovering over a certain icon. """ ctx = plugin_util.context(request.environ) experiment = plugin_util.experiment_id(request.environ) return http_util.Respond( request, self.tags_impl(ctx, experiment), "application/json" ) def tags_impl(self, ctx, experiment): """Creates the JSON object for the tags route response. Returns: The JSON object for the tags route response. """ mapping = self._data_provider.list_tensors( ctx, experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME ) result = {run: {} for run in mapping} for (run, tag_to_time_series) in six.iteritems(mapping): for (tag, time_series) in tag_to_time_series.items(): result[run][tag] = { "displayName": time_series.display_name, "description": plugin_util.markdown_to_safe_html( time_series.description ), } return result def get_plugin_apps(self): """Gets all routes offered by the plugin. Returns: A dictionary mapping URL path to route that handles it. """ return { "/tags": self.tags_route, "/pr_curves": self.pr_curves_route, } def is_active(self): return False # `list_plugins` as called by TB core suffices def frontend_metadata(self): return base_plugin.FrontendMetadata( element_name="tf-pr-curve-dashboard", tab_name="PR Curves", ) def _process_datum(self, datum): """Converts a TensorDatum into a dict that encapsulates information on it. Args: datum: The TensorDatum to convert. Returns: A JSON-able dictionary of PR curve data for 1 step. """ return self._make_pr_entry(datum.step, datum.wall_time, datum.numpy) def _make_pr_entry(self, step, wall_time, data_array): """Creates an entry for PR curve data. Each entry corresponds to 1 step. Args: step: The step. wall_time: The wall time. data_array: A numpy array of PR curve data stored in the summary format. Returns: A PR curve entry. """ # Trim entries for which TP + FP = 0 (precision is undefined) at the tail of # the data. true_positives = [ int(v) for v in data_array[metadata.TRUE_POSITIVES_INDEX] ] false_positives = [ int(v) for v in data_array[metadata.FALSE_POSITIVES_INDEX] ] tp_index = metadata.TRUE_POSITIVES_INDEX fp_index = metadata.FALSE_POSITIVES_INDEX positives = data_array[[tp_index, fp_index], :].astype(int).sum(axis=0) end_index_inclusive = len(positives) - 1 while end_index_inclusive > 0 and positives[end_index_inclusive] == 0: end_index_inclusive -= 1 end_index = end_index_inclusive + 1 num_thresholds = data_array.shape[1] thresholds = (np.arange(1, end_index + 1) / num_thresholds).tolist() return { "wall_time": wall_time, "step": step, "precision": data_array[ metadata.PRECISION_INDEX, :end_index ].tolist(), "recall": data_array[metadata.RECALL_INDEX, :end_index].tolist(), "true_positives": true_positives[:end_index], "false_positives": false_positives[:end_index], "true_negatives": [ int(v) for v in data_array[metadata.TRUE_NEGATIVES_INDEX][:end_index] ], "false_negatives": [ int(v) for v in data_array[metadata.FALSE_NEGATIVES_INDEX][:end_index] ], "thresholds": thresholds, }