# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Precision--recall curves and TensorFlow operations to create them.

NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorboard.plugins.pr_curve import metadata

# A value that we use as the minimum value during division of counts to prevent
# division by 0. 1.0 does not work: Certain weights could cause counts below 1.
_MINIMUM_COUNT = 1e-7

# The default number of thresholds.
_DEFAULT_NUM_THRESHOLDS = 201


def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative values for the given `predictions`
    against the ground truth `labels`, against a list of evenly distributed
    threshold values in `[0, 1]` of length `num_thresholds`.

    Each number in `predictions`, a float in `[0, 1]`, is compared with its
    corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn
    value at each threshold. This is then multiplied with `weights` which can
    be used to reweight certain values, or more commonly used for masking
    values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      labels: The ground truth values. A Tensor of `bool` values with
        arbitrary shape.
      predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
        Dimensions must match those of `labels`.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer.
      weights: Optional float32 `Tensor`. Individual counts are multiplied by
        this value. This tensor must be either the same shape as or
        broadcastable to the `labels` tensor.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new summary op
        is added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32 tensor
      produced by the summary operation is of dimension (6, num_thresholds).
      The first dimension (of length 6) is of the order: true positives, false
      positives, true negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # We cast to float to ensure we have 0.0 or 1.0.
        f_labels = tf.cast(labels, dtype)
        # Ensure predictions are all in range [0.0, 1.0].
        predictions = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Get weighted true/false labels.
        true_labels = f_labels * weights
        false_labels = (1.0 - f_labels) * weights

        # Before we begin, flatten predictions.
        predictions = tf.reshape(predictions, [-1])

        # Shape the labels so they are broadcast-able for later multiplication.
        true_labels = tf.reshape(true_labels, [-1, 1])
        false_labels = tf.reshape(false_labels, [-1, 1])

        # To compute TP/FP/TN/FN, we are measuring a binary classifier
        #   C(t) = (predictions >= t)
        # at each threshold 't'. So we have
        #   TP(t) = sum( C(t) * true_labels )
        #   FP(t) = sum( C(t) * false_labels )
        #
        # But, computing C(t) requires computation for each t. To make it fast,
        # observe that C(t) is a cumulative integral, and so if we have
        #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
        # where n = num_thresholds, and if we can compute the bucket function
        #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
        # then we get
        #   C(t_i) = sum( B(j), j >= i )
        # which is the reversed cumulative sum in tf.cumsum().
        #
        # We can compute B(i) efficiently by taking advantage of the fact that
        # our thresholds are evenly distributed, in that
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
        # Given a prediction value p, we can map it to its bucket by
        #   bucket_index(p) = floor( p * (num_thresholds - 1) )
        # so we can fill all buckets in one pass by one-hot encoding the bucket
        # indices and summing over the batch axis.

        # Compute the bucket indices for each prediction value.
        bucket_indices = tf.cast(
            tf.floor(predictions * (num_thresholds - 1)), tf.int32
        )

        # Bucket predictions.
        tp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * true_labels,
            axis=0,
        )
        fp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * false_labels,
            axis=0,
        )

        # Set up the cumulative sums to compute the actual metrics.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # fn = sum(true_labels) - tp
        #    = sum(tp_buckets) - tp
        #    = tp[0] - tp
        # Similarly,
        # tn = fp[0] - fp
        tn = fp[0] - fp
        fn = tp[0] - tp

        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )


def pb(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf.

    Arguments:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      labels: The ground truth values. A bool numpy array.
      predictions: A float32 numpy array whose values are in the range
        `[0, 1]`. Dimensions must match those of `labels`.
      num_thresholds: Optional number of thresholds, evenly distributed in
        `[0, 1]`, to compute PR metrics for. When provided, should be an int of
        value at least 2. Defaults to 201.
      weights: Optional float or float32 numpy array. Individual counts are
        multiplied by this value. Must be either the same shape as or
        broadcastable to the `labels` numpy array.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf  # noqa: F401

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(float)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return raw_data_pb(
        name,
        true_positive_counts=tp,
        false_positive_counts=fp,
        true_negative_counts=tn,
        false_negative_counts=fn,
        precision=precision,
        recall=recall,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )


def streaming_op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    display_name=None,
    description=None,
):
    """Computes a precision-recall curve summary across batches of data.

    This function is similar to op() above, but can be used to compute the PR
    curve across multiple batches of labels and predictions, in the same style
    as the metrics found in tf.metrics.

    This function creates multiple local variables for storing true positives,
    true negatives, etc. accumulated over each batch of data, and uses these
    local variables for computing the final PR curve summary. These variables
    can be updated with the returned update_op.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      labels: The ground truth values, a `Tensor` whose dimensions must match
        `predictions`. Will be cast to `bool`.
      predictions: A floating point `Tensor` of arbitrary shape and whose
        values are in the range `[0, 1]`.
      num_thresholds: The number of evenly spaced thresholds to generate for
        computing the PR curve. Defaults to 201.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `labels`, and must be broadcastable to `labels` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `labels`
        dimension).
      metrics_collections: An optional list of collections that `pr_curve`
        should be added to.
      updates_collections: An optional list of collections that `update_op`
        should be added to.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      pr_curve: A string `Tensor` containing a single value: the serialized
        PR curve Tensor summary. The summary contains a float32 `Tensor` of
        dimension (6, num_thresholds). The first dimension (of length 6) is of
        the order: true positives, false positives, true negatives, false
        negatives, precision, recall.
      update_op: An operation that updates the summary with the latest data.
""" # TODO(nickfelt): remove on-demand imports once dep situation is fixed. import tensorflow.compat.v1 as tf if num_thresholds is None: num_thresholds = _DEFAULT_NUM_THRESHOLDS thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)] with tf.name_scope(name, values=[labels, predictions, weights]): tp, update_tp = tf.metrics.true_positives_at_thresholds( labels=labels, predictions=predictions, thresholds=thresholds, weights=weights, ) fp, update_fp = tf.metrics.false_positives_at_thresholds( labels=labels, predictions=predictions, thresholds=thresholds, weights=weights, ) tn, update_tn = tf.metrics.true_negatives_at_thresholds( labels=labels, predictions=predictions, thresholds=thresholds, weights=weights, ) fn, update_fn = tf.metrics.false_negatives_at_thresholds( labels=labels, predictions=predictions, thresholds=thresholds, weights=weights, ) def compute_summary(tp, fp, tn, fn, collections): precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp) recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) return _create_tensor_summary( name, tp, fp, tn, fn, precision, recall, num_thresholds, display_name, description, collections, ) pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections) update_op = tf.group(update_tp, update_fp, update_tn, update_fn) if updates_collections: for collection in updates_collections: tf.add_to_collection(collection, update_op) return pr_curve, update_op def raw_data_op( name, true_positive_counts, false_positive_counts, true_negative_counts, false_negative_counts, precision, recall, num_thresholds=None, display_name=None, description=None, collections=None, ): """Create an op that collects data for visualizing PR curves. Unlike the op above, this one avoids computing precision, recall, and the intermediate counts. Instead, it accepts those tensors as arguments and relies on the caller to ensure that the calculations are correct (and the counts yield the provided precision and recall values). This op is useful when a caller seeks to compute precision and recall differently but still use the PR curves plugin. Args: name: A tag attached to the summary. Used by TensorBoard for organization. true_positive_counts: A rank-1 tensor of true positive counts. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). false_positive_counts: A rank-1 tensor of false positive counts. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). true_negative_counts: A rank-1 tensor of true negative counts. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). false_negative_counts: A rank-1 tensor of false negative counts. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). precision: A rank-1 tensor of precision values. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). recall: A rank-1 tensor of recall values. Must contain `num_thresholds` elements and be castable to float32. Values correspond to thresholds that increase from left to right (from 0 to 1). num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to compute PR metrics for. Should be `>= 2`. 
        This value should be a constant integer value, not a Tensor that
        stores an integer.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new summary op
        is added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. See docs for the `op`
      method for details on the float32 tensor produced by this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    with tf.name_scope(
        name,
        values=[
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        ],
    ):
        return _create_tensor_summary(
            name,
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )


def raw_data_pb(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf from raw data values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      true_positive_counts: A rank-1 numpy array of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_positive_counts: A rank-1 numpy array of false positive counts.
        Must contain `num_thresholds` elements and be castable to float32.
      true_negative_counts: A rank-1 numpy array of true negative counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_negative_counts: A rank-1 numpy array of false negative counts.
        Must contain `num_thresholds` elements and be castable to float32.
      precision: A rank-1 numpy array of precision values. Must contain
        `num_thresholds` elements and be castable to float32.
      recall: A rank-1 numpy array of recall values. Must contain
        `num_thresholds` elements and be castable to float32.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be an int `>= 2`.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.

    Returns:
      A summary protobuf (`tf.Summary`) containing the PR curve data. See docs
      for the `op` method for details on the float32 tensor it stores.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description or "",
        num_thresholds=num_thresholds,
    )
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )
    summary = tf.Summary()
    data = np.stack(
        (
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        )
    )
    tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32)
    summary.value.add(
        tag="%s/pr_curves" % name,
        metadata=tf_summary_metadata,
        tensor=tensor,
    )
    return summary


def _create_tensor_summary(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """A private helper method for generating a tensor summary.

    We use a helper method instead of having `op` directly call `raw_data_op`
    to prevent the scope of `raw_data_op` from being embedded within `op`.

    Arguments are the same as for raw_data_op.

    Returns:
      A tensor summary that collects data for PR curves.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # Store the number of thresholds within the summary metadata because
    # that value is constant for all pr curve summaries with the same tag.
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name if display_name is not None else name,
        description=description or "",
        num_thresholds=num_thresholds,
    )

    # Store values within a tensor. We store them in the order:
    # true positives, false positives, true negatives, false
    # negatives, precision, and recall.
    combined_data = tf.stack(
        [
            tf.cast(true_positive_counts, tf.float32),
            tf.cast(false_positive_counts, tf.float32),
            tf.cast(true_negative_counts, tf.float32),
            tf.cast(false_negative_counts, tf.float32),
            tf.cast(precision, tf.float32),
            tf.cast(recall, tf.float32),
        ]
    )

    return tf.summary.tensor_summary(
        name="pr_curves",
        tensor=combined_data,
        collections=collections,
        summary_metadata=summary_metadata,
    )
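

# The block below is an illustrative usage sketch, not part of the module's
# public API. It assumes TF1-style graph execution (via tensorflow.compat.v1)
# and a hypothetical log directory "/tmp/pr_curve_demo"; both the tag name
# "demo" and the directory are placeholders to adapt. It writes one PR curve
# summary that the PR curves plugin can render in TensorBoard.
if __name__ == "__main__":
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()  # Run the summary op in graph mode.

    # Toy labels and prediction scores for a single binary classifier.
    labels = tf.constant([True, False, True, False, True])
    predictions = tf.constant([0.9, 0.8, 0.6, 0.3, 0.1])
    pr_summary = op(
        name="demo",  # Hypothetical tag name.
        labels=labels,
        predictions=predictions,
        num_thresholds=11,
    )

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("/tmp/pr_curve_demo", sess.graph)
        # The summary op evaluates to a serialized Summary proto, which the
        # FileWriter appends to the event file for step 0.
        writer.add_summary(sess.run(pr_summary), global_step=0)
        writer.close()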