# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Shared utils among inference plugins.""" from __future__ import division from __future__ import print_function import collections import copy import json import math from absl import logging import numpy as np import tensorflow as tf from google.protobuf import json_format from six import binary_type, string_types, integer_types from six import iteritems from six.moves import zip # pylint: disable=redefined-builtin from inspect import signature from tensorboard_plugin_wit._utils import common_utils from tensorboard_plugin_wit._utils import platform_utils from tensorboard_plugin_wit._vendor.tensorflow_serving.apis import classification_pb2 from tensorboard_plugin_wit._vendor.tensorflow_serving.apis import inference_pb2 from tensorboard_plugin_wit._vendor.tensorflow_serving.apis import regression_pb2 class VizParams(object): """Light-weight class for holding UI state. Attributes: x_min: The minimum value to use to generate mutants for the feature (as specified the user on the UI). x_max: The maximum value to use to generate mutants for the feature (as specified the user on the UI). examples: A list of examples to scan in order to generate statistics for mutants. num_mutants: Int number of mutants to generate per chart. feature_index_pattern: String that specifies a restricted set of indices of the feature to generate mutants for (useful for features that is a long repeated field. See `convert_pattern_to_indices` for more details. """ def __init__(self, x_min, x_max, examples, num_mutants, feature_index_pattern): """Inits VizParams may raise InvalidUserInputError for bad user inputs.""" def to_float_or_none(x): try: return float(x) except (ValueError, TypeError): return None def to_int(x): try: return int(x) except (ValueError, TypeError) as e: raise common_utils.InvalidUserInputError(e) def convert_pattern_to_indices(pattern): """Converts a printer-page-style pattern and returns a list of indices. Args: pattern: A printer-page-style pattern with only numeric characters, commas, dashes, and optionally spaces. For example, a pattern of '0,2,4-6' would yield [0, 2, 4, 5, 6]. Returns: A list of indices represented by the pattern. """ pieces = [token.strip() for token in pattern.split(',')] indices = [] for piece in pieces: if '-' in piece: lower, upper = [int(x.strip()) for x in piece.split('-', 1)] indices.extend(range(lower, upper + 1)) else: indices.append(int(piece.strip())) return sorted(indices) self.x_min = to_float_or_none(x_min) self.x_max = to_float_or_none(x_max) self.examples = examples self.num_mutants = to_int(num_mutants) # By default, there are no specific user-requested feature indices. self.feature_indices = [] if feature_index_pattern: try: self.feature_indices = convert_pattern_to_indices( feature_index_pattern) except ValueError as e: # If the user-requested range is invalid, use the default range. pass class OriginalFeatureList(object): """Light-weight class for holding the original values in the example. Should not be created by hand, but rather generated via `parse_original_feature_from_example`. Just used to hold inferred info about the example. Attributes: feature_name: String name of the feature. original_value: The value of the feature in the original example. feature_type: One of ['int64_list', 'float_list']. Raises: ValueError: If OriginalFeatureList fails init validation. """ def __init__(self, feature_name, original_value, feature_type): """Inits OriginalFeatureList.""" self.feature_name = feature_name self.original_value = [ ensure_not_binary(value) for value in original_value] self.feature_type = feature_type # Derived attributes. self.length = sum(1 for _ in original_value) class MutantFeatureValue(object): """Light-weight class for holding mutated values in the example. Should not be created by hand but rather generated via `make_mutant_features`. Used to represent a "mutant example": an example that is mostly identical to the user-provided original example, but has one feature that is different. Attributes: original_feature: An `OriginalFeatureList` object representing the feature to create mutants for. index: The index of the feature to create mutants for. The feature can be a repeated field, and we want to plot mutations of its various indices. mutant_value: The proposed mutant value for the given index. Raises: ValueError: If MutantFeatureValue fails init validation. """ def __init__(self, original_feature, index, mutant_value): """Inits MutantFeatureValue.""" if not isinstance(original_feature, OriginalFeatureList): raise ValueError( 'original_feature should be `OriginalFeatureList`, but had ' 'unexpected type: {}'.format(type(original_feature))) self.original_feature = original_feature if index is not None and not isinstance(index, integer_types): raise ValueError( 'index should be None or int, but had unexpected type: {}'.format( type(index))) self.index = index self.mutant_value = (mutant_value.encode() if isinstance(mutant_value, string_types) else mutant_value) class ServingBundle(object): """Light-weight class for holding info to make the inference request. Attributes: inference_address: An address (such as "hostname:port") to send inference requests to. model_name: The Servo model name. model_type: One of ['classification', 'regression']. model_version: The version number of the model as a string. If set to an empty string, the latest model will be used. signature: The signature of the model to infer. If set to an empty string, the default signuature will be used. use_predict: If true then use the servo Predict API as opposed to Classification or Regression. predict_input_tensor: The name of the input tensor to parse when using the Predict API. predict_output_tensor: The name of the output tensor to parse when using the Predict API. estimator: An estimator to use instead of calling an external model. feature_spec: A feature spec for use with the estimator. custom_predict_fn: A custom prediction function. Raises: ValueError: If ServingBundle fails init validation. """ def __init__(self, inference_address, model_name, model_type, model_version, signature, use_predict, predict_input_tensor, predict_output_tensor, estimator=None, feature_spec=None, custom_predict_fn=None): """Inits ServingBundle.""" if not isinstance(inference_address, string_types): raise ValueError('Invalid inference_address has type: {}'.format( type(inference_address))) # Clean the inference_address so that SmartStub likes it. self.inference_address = inference_address.replace('http://', '').replace( 'https://', '') if not isinstance(model_name, string_types): raise ValueError('Invalid model_name has type: {}'.format( type(model_name))) self.model_name = model_name if model_type not in ['classification', 'regression']: raise ValueError('Invalid model_type: {}'.format(model_type)) self.model_type = model_type self.model_version = int(model_version) if model_version else None self.signature = signature if signature else None self.use_predict = use_predict self.predict_input_tensor = predict_input_tensor self.predict_output_tensor = predict_output_tensor self.estimator = estimator self.feature_spec = feature_spec self.custom_predict_fn = custom_predict_fn def ensure_not_binary(value): """Return non-binary version of value.""" try: return value.decode() if isinstance(value, binary_type) else value except UnicodeDecodeError: # If the value cannot be decoded as a string (such as an encoded image), # then just return the value. return value def proto_value_for_feature(example, feature_name): """Get the value of a feature from Example regardless of feature type.""" feature = get_example_features(example)[feature_name] if feature is None: raise ValueError('Feature {} is not on example proto.'.format(feature_name)) feature_type = feature.WhichOneof('kind') if feature_type is None: raise ValueError('Feature {} on example proto has no declared type.'.format( feature_name)) return getattr(feature, feature_type).value def parse_original_feature_from_example(example, feature_name): """Returns an `OriginalFeatureList` for the specified feature_name. Args: example: An example. feature_name: A string feature name. Returns: A filled in `OriginalFeatureList` object representing the feature. """ feature = get_example_features(example)[feature_name] feature_type = feature.WhichOneof('kind') original_value = proto_value_for_feature(example, feature_name) return OriginalFeatureList(feature_name, original_value, feature_type) def wrap_inference_results(inference_result_proto): """Returns packaged inference results from the provided proto. Args: inference_result_proto: The classification or regression response proto. Returns: An InferenceResult proto with the result from the response. """ inference_proto = inference_pb2.InferenceResult() if isinstance(inference_result_proto, classification_pb2.ClassificationResponse): inference_proto.classification_result.CopyFrom( inference_result_proto.result) elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): inference_proto.regression_result.CopyFrom(inference_result_proto.result) return inference_proto def get_numeric_feature_names(example): """Returns a list of feature names for float and int64 type features. Args: example: An example. Returns: A list of strings of the names of numeric features. """ numeric_features = ('float_list', 'int64_list') features = get_example_features(example) return sorted([ feature_name for feature_name in features if features[feature_name].WhichOneof('kind') in numeric_features ]) def get_categorical_feature_names(example): """Returns a list of feature names for byte type features. Args: example: An example. Returns: A list of categorical feature names (e.g. ['education', 'marital_status'] ) """ features = get_example_features(example) return sorted([ feature_name for feature_name in features if features[feature_name].WhichOneof('kind') == 'bytes_list' ]) def get_numeric_features_to_observed_range(examples): """Returns numerical features and their observed ranges. Args: examples: Examples to read to get ranges. Returns: A dict mapping feature_name -> {'observedMin': 'observedMax': } dicts, with a key for each numerical feature. """ observed_features = collections.defaultdict(list) # name -> [value, ] for example in examples: for feature_name in get_numeric_feature_names(example): original_feature = parse_original_feature_from_example( example, feature_name) observed_features[feature_name].extend(original_feature.original_value) return { feature_name: { 'observedMin': min(feature_values), 'observedMax': max(feature_values), } for feature_name, feature_values in iteritems(observed_features) } def get_categorical_features_to_sampling(examples, top_k): """Returns categorical features and a sampling of their most-common values. The results of this slow function are used by the visualization repeatedly, so the results are cached. Args: examples: Examples to read to get feature samples. top_k: Max number of samples to return per feature. Returns: A dict of feature_name -> {'samples': ['Married-civ-spouse', 'Never-married', 'Divorced']}. There is one key for each categorical feature. Currently, the inner dict just has one key, but this structure leaves room for further expansion, and mirrors the structure used by `get_numeric_features_to_observed_range`. """ observed_features = collections.defaultdict(list) # name -> [value, ] for example in examples: for feature_name in get_categorical_feature_names(example): original_feature = parse_original_feature_from_example( example, feature_name) observed_features[feature_name].extend(original_feature.original_value) result = {} for feature_name, feature_values in sorted(iteritems(observed_features)): samples = [ word for word, count in collections.Counter(feature_values).most_common( top_k) if count > 1 ] if samples: result[feature_name] = {'samples': samples} return result def make_mutant_features(original_feature, index_to_mutate, viz_params): """Return a list of `MutantFeatureValue`s that are variants of original.""" lower = viz_params.x_min upper = viz_params.x_max examples = viz_params.examples num_mutants = viz_params.num_mutants if original_feature.feature_type == 'float_list': return [ MutantFeatureValue(original_feature, index_to_mutate, value) for value in np.linspace(lower, upper, num_mutants) ] elif original_feature.feature_type == 'int64_list': mutant_values = np.linspace(int(lower), int(upper), num_mutants).astype(int).tolist() # Remove duplicates that can occur due to integer constraint. mutant_values = sorted(set(mutant_values)) return [ MutantFeatureValue(original_feature, index_to_mutate, value) for value in mutant_values ] elif original_feature.feature_type == 'bytes_list': feature_to_samples = get_categorical_features_to_sampling( examples, num_mutants) # `mutant_values` looks like: # [['Married-civ-spouse'], ['Never-married'], ['Divorced'], ['Separated']] mutant_values = feature_to_samples[original_feature.feature_name]['samples'] return [ MutantFeatureValue(original_feature, None, value) for value in mutant_values ] else: raise ValueError('Malformed original feature had type of: ' + original_feature.feature_type) def make_mutant_tuples(example_protos, original_feature, index_to_mutate, viz_params): """Return a list of `MutantFeatureValue`s and a list of mutant Examples. Args: example_protos: The examples to mutate. original_feature: A `OriginalFeatureList` that encapsulates the feature to mutate. index_to_mutate: The index of the int64_list or float_list to mutate. viz_params: A `VizParams` object that contains the UI state of the request. Returns: A list of `MutantFeatureValue`s and a list of mutant examples. """ mutant_features = make_mutant_features(original_feature, index_to_mutate, viz_params) mutant_examples = [] for example_proto in example_protos: for mutant_feature in mutant_features: copied_example = copy.deepcopy(example_proto) feature_name = mutant_feature.original_feature.feature_name try: feature_list = proto_value_for_feature(copied_example, feature_name) if index_to_mutate is None: new_values = mutant_feature.mutant_value else: new_values = list(feature_list) new_values[index_to_mutate] = mutant_feature.mutant_value del feature_list[:] feature_list.extend(new_values) mutant_examples.append(copied_example) except (ValueError, IndexError): # If the mutant value can't be set, still add the example to the # mutant_example even though no change was made. This is necessary to # allow for computation of global PD plots when not all examples have # the same number of feature values for a feature. mutant_examples.append(copied_example) return mutant_features, mutant_examples def mutant_charts_for_feature(example_protos, feature_name, serving_bundles, viz_params): """Returns JSON formatted for rendering all charts for a feature. Args: example_proto: The example protos to mutate. feature_name: The string feature name to mutate. serving_bundles: One `ServingBundle` object per model, that contains the information to make the serving request. viz_params: A `VizParams` object that contains the UI state of the request. Raises: InvalidUserInputError if `viz_params.feature_index_pattern` requests out of range indices for `feature_name` within `example_proto`. Returns: A JSON-able dict for rendering a single mutant chart. parsed in `tf-inference-dashboard.html`. { 'chartType': 'numeric', # oneof('numeric', 'categorical') 'data': [A list of data] # parseable by vz-line-chart or vz-bar-chart } """ def chart_for_index(index_to_mutate): mutant_features, mutant_examples = make_mutant_tuples( example_protos, original_feature, index_to_mutate, viz_params) charts = [] for serving_bundle in serving_bundles: (inference_result_proto, _) = run_inference( mutant_examples, serving_bundle) charts.append(make_json_formatted_for_single_chart( mutant_features, inference_result_proto, index_to_mutate)) return charts try: original_feature = parse_original_feature_from_example( example_protos[0], feature_name) except ValueError as e: return { 'chartType': 'categorical', 'data': [] } indices_to_mutate = viz_params.feature_indices or range( original_feature.length) chart_type = ('categorical' if original_feature.feature_type == 'bytes_list' else 'numeric') try: return { 'chartType': chart_type, 'data': [ chart_for_index(index_to_mutate) for index_to_mutate in indices_to_mutate ] } except IndexError as e: raise common_utils.InvalidUserInputError(e) def make_json_formatted_for_single_chart(mutant_features, inference_result_proto, index_to_mutate): """Returns JSON formatted for a single mutant chart. Args: mutant_features: An iterable of `MutantFeatureValue`s representing the X-axis. inference_result_proto: A ClassificationResponse or RegressionResponse returned by Servo, representing the Y-axis. It contains one 'classification' or 'regression' for every Example that was sent for inference. The length of that field should be the same length of mutant_features. index_to_mutate: The index of the feature being mutated for this chart. Returns: A JSON-able dict for rendering a single mutant chart, parseable by `vz-line-chart` or `vz-bar-chart`. """ x_label = 'step' y_label = 'scalar' if isinstance(inference_result_proto, classification_pb2.ClassificationResponse): # classification_label -> [{x_label: y_label:}] series = {} # ClassificationResponse has a separate probability for each label for idx, classification in enumerate( inference_result_proto.result.classifications): # For each example to use for mutant inference, we create a copied example # with the feature in question changed to each possible mutant value. So # when we get the inferences back, we get num_examples*num_mutants # results. So, modding by len(mutant_features) allows us to correctly # lookup the mutant value for each inference. mutant_feature = mutant_features[idx % len(mutant_features)] for class_index, classification_class in enumerate( classification.classes): # Fill in class index when labels are missing if classification_class.label == '': classification_class.label = str(class_index) # Special case to not include the "0" class in binary classification. # Since that just results in a chart that is symmetric around 0.5. if len( classification.classes) == 2 and classification_class.label == '0': continue key = classification_class.label if index_to_mutate: key += ' (index %d)' % index_to_mutate if not key in series: series[key] = {} mutant_val = ensure_not_binary(mutant_feature.mutant_value) if not mutant_val in series[key]: series[key][mutant_val] = [] series[key][mutant_val].append( classification_class.score) # Post-process points to have separate list for each class return_series = collections.defaultdict(list) for key, mutant_values in iteritems(series): for value, y_list in iteritems(mutant_values): return_series[key].append({ x_label: value, y_label: sum(y_list) / float(len(y_list)) }) return_series[key].sort(key=lambda p: p[x_label]) return return_series elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): points = {} for idx, regression in enumerate(inference_result_proto.result.regressions): # For each example to use for mutant inference, we create a copied example # with the feature in question changed to each possible mutant value. So # when we get the inferences back, we get num_examples*num_mutants # results. So, modding by len(mutant_features) allows us to correctly # lookup the mutant value for each inference. mutant_feature = mutant_features[idx % len(mutant_features)] mutant_val = ensure_not_binary(mutant_feature.mutant_value) if not mutant_val in points: points[mutant_val] = [] points[mutant_val].append(regression.value) key = 'value' if (index_to_mutate != 0): key += ' (index %d)' % index_to_mutate list_of_points = [] for value, y_list in iteritems(points): list_of_points.append({ x_label: value, y_label: sum(y_list) / float(len(y_list)) }) list_of_points.sort(key=lambda p: p[x_label]) return {key: list_of_points} else: raise NotImplementedError('Only classification and regression implemented.') def get_example_features(example): """Returns the non-sequence features from the provided example.""" return (example.features.feature if isinstance(example, tf.train.Example) else example.context.feature) def run_inference_for_inference_results(examples, serving_bundle): """Calls servo and wraps the inference results.""" (inference_result_proto, extra_results) = run_inference( examples, serving_bundle) inferences = wrap_inference_results(inference_result_proto) infer_json = json_format.MessageToJson( inferences, including_default_value_fields=True) return json.loads(infer_json), extra_results def get_eligible_features(examples, num_mutants): """Returns a list of JSON objects for each feature in the examples. This list is used to drive partial dependence plots in the plugin. Args: examples: Examples to examine to determine the eligible features. num_mutants: The number of mutations to make over each feature. Returns: A list with a JSON object for each feature. Numeric features are represented as {name: observedMin: observedMax:}. Categorical features are repesented as {name: samples:[]}. """ features_dict = ( get_numeric_features_to_observed_range( examples)) features_dict.update( get_categorical_features_to_sampling( examples, num_mutants)) # Massage the features_dict into a sorted list before returning because # Polymer dom-repeat needs a list. features_list = [] for k, v in sorted(features_dict.items()): v['name'] = k features_list.append(v) return features_list def sort_eligible_features(features_list, chart_data): """Returns a sorted list of objects representing each feature. The list is sorted by interestingness in terms of the resulting change in inference values across feature values, for partial dependence plots. Args: features_list: A list of eligible features in the format of the return from the get_eligible_features function. chart_data: A dict of feature names to chart data, formatted as the output from the mutant_charts_for_feature function. Returns: A sorted list of the inputted features_list, with the addition of an 'interestingness' key with a calculated number for feature feature. The list is sorted with the feature with highest interestingness first. """ sorted_features_list = copy.deepcopy(features_list) for feature in sorted_features_list: name = feature['name'] charts = chart_data[name] max_measure = 0 is_numeric = charts['chartType'] == 'numeric' for models in charts['data']: for chart in models: for series in chart.values(): if is_numeric: # For numeric features, interestingness is the total Y distance # traveled across the line chart. measure = 0 for i in range(len(series) - 1): measure += abs(series[i]['scalar'] - series[i + 1]['scalar']) else: # For categorical features, interestingness is the difference # between the min and max Y values in the chart, as interestingness # for categorical charts shouldn't depend on the order of items # being charted. min_y = float("inf") max_y = float("-inf") for i in range(len(series)): val = series[i]['scalar'] if val < min_y: min_y = val if val > max_y: max_y = val measure = max_y - min_y if measure > max_measure: max_measure = measure feature['interestingness'] = max_measure return sorted( sorted_features_list, key=lambda x: x['interestingness'], reverse=True) def get_label_vocab(vocab_path): """Returns a list of label strings loaded from the provided path.""" if vocab_path: try: with tf.io.gfile.GFile(vocab_path, 'r') as f: return [line.rstrip('\n') for line in f] except tf.errors.NotFoundError as err: logging.error('error reading vocab file: %s', err) return [] def create_sprite_image(examples): """Returns an encoded sprite image for use in Facets Dive. Args: examples: A list of serialized example protos to get images for. Returns: An encoded PNG. """ def generate_image_from_thubnails(thumbnails, thumbnail_dims): """Generates a sprite atlas image from a set of thumbnails.""" num_thumbnails = tf.shape(thumbnails)[0].eval() images_per_row = int(math.ceil(math.sqrt(num_thumbnails))) thumb_height = thumbnail_dims[0] thumb_width = thumbnail_dims[1] master_height = images_per_row * thumb_height master_width = images_per_row * thumb_width num_channels = 3 master = np.zeros([master_height, master_width, num_channels]) for idx, image in enumerate(thumbnails.eval()): left_idx = idx % images_per_row top_idx = int(math.floor(idx / images_per_row)) left_start = left_idx * thumb_width left_end = left_start + thumb_width top_start = top_idx * thumb_height top_end = top_start + thumb_height master[top_start:top_end, left_start:left_end, :] = image return tf.image.encode_png(tf.cast(master, dtype=tf.uint8)) image_feature_name = 'image/encoded' sprite_thumbnail_dim_px = 32 with tf.compat.v1.Session(): keys_to_features = { image_feature_name: tf.io.FixedLenFeature((), tf.string, default_value=''), } parsed = tf.io.parse_example(examples, keys_to_features) images = tf.zeros([1, 1, 1, 1], tf.float32) i = tf.constant(0) thumbnail_dims = (sprite_thumbnail_dim_px, sprite_thumbnail_dim_px) num_examples = tf.constant(len(examples)) encoded_images = parsed[image_feature_name] # Loop over all examples, decoding the image feature value, resizing # and appending to a list of all images. def loop_body(i, encoded_images, images): encoded_image = encoded_images[i] image = tf.image.decode_jpeg(encoded_image, channels=3) resized_image = tf.image.resize(image, thumbnail_dims) expanded_image = tf.expand_dims(resized_image, 0) images = tf.cond( tf.equal(i, 0), lambda: expanded_image, lambda: tf.concat([images, expanded_image], 0)) return i + 1, encoded_images, images loop_out = tf.while_loop( lambda i, encoded_images, images: tf.less(i, num_examples), loop_body, [i, encoded_images, images], shape_invariants=[ i.get_shape(), encoded_images.get_shape(), tf.TensorShape(None) ]) # Create the single sprite atlas image from these thumbnails. sprite = generate_image_from_thubnails(loop_out[2], thumbnail_dims) return sprite.eval() def run_inference(examples, serving_bundle): """Run inference on examples given model information Args: examples: A list of examples that matches the model spec. serving_bundle: A `ServingBundle` object that contains the information to make the inference request. Returns: A tuple with the first entry being the ClassificationResponse or RegressionResponse proto and the second entry being a dictionary of extra data for each example, such as attributions, or None if no data exists. """ batch_size = 64 if serving_bundle.estimator and serving_bundle.feature_spec: # If provided an estimator and feature spec then run inference locally. preds = serving_bundle.estimator.predict( lambda: tf.data.Dataset.from_tensor_slices( tf.io.parse_example([ex.SerializeToString() for ex in examples], serving_bundle.feature_spec)).batch(batch_size)) # Use the specified key if one is provided. key_to_use = (serving_bundle.predict_output_tensor if serving_bundle.use_predict else None) values = [] for pred in preds: if key_to_use is None: # If the prediction dictionary only contains one key, use it. returned_keys = list(pred.keys()) if len(returned_keys) == 1: key_to_use = returned_keys[0] # Use default keys if necessary. elif serving_bundle.model_type == 'classification': key_to_use = 'probabilities' else: key_to_use = 'predictions' if key_to_use not in pred: raise KeyError( '"%s" not found in model predictions dictionary' % key_to_use) values.append(pred[key_to_use]) return (common_utils.convert_prediction_values(values, serving_bundle), None) elif serving_bundle.custom_predict_fn: # If custom_predict_fn is provided, pass examples directly for local # inference. sig = signature(serving_bundle.custom_predict_fn) params = sig.parameters # The custom_predict_fn for colab/jupyter accepts one parameter. # While the custom_predict_fn for non-colab usage have two. if len(params) == 1: values = serving_bundle.custom_predict_fn(examples) if len(params) == 2: values = serving_bundle.custom_predict_fn(examples, serving_bundle) extra_results = None # If the custom prediction function returned a dict, then parse out the # prediction scores. If it is just a list, then the results are the # prediction results without attributions or other data. if isinstance(values, dict): preds = values.pop('predictions') extra_results = values else: preds = values return (common_utils.convert_prediction_values(preds, serving_bundle), extra_results) else: return (platform_utils.call_servo(examples, serving_bundle), None)