# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Resampling dataset transformations.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.experimental.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.util.tf_export import tf_export @tf_export("data.experimental.rejection_resample") def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): """A transformation that resamples a dataset to achieve a target distribution. **NOTE** Resampling is performed via rejection sampling; some fraction of the input values will be dropped. Args: class_func: A function mapping an element of the input dataset to a scalar `tf.int32` tensor. Values should be in `[0, num_classes)`. target_dist: A floating point type tensor, shaped `[num_classes]`. initial_dist: (Optional.) A floating point type tensor, shaped `[num_classes]`. If not provided, the true class distribution is estimated live in a streaming fashion. seed: (Optional.) Python integer seed for the resampler. Returns: A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") # Get initial distribution. if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") acceptance_dist, prob_of_original = ( _calculate_acceptance_probs_with_mixing(initial_dist_t, target_dist_t)) initial_dist_ds = dataset_ops.Dataset.from_tensors( initial_dist_t).repeat() acceptance_dist_ds = dataset_ops.Dataset.from_tensors( acceptance_dist).repeat() prob_of_original_ds = dataset_ops.Dataset.from_tensors( prob_of_original).repeat() else: initial_dist_ds = _estimate_initial_dist_ds(target_dist_t, dataset.map(class_func)) acceptance_and_original_prob_ds = initial_dist_ds.map( lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda initial, target_dist_t)) acceptance_dist_ds = acceptance_and_original_prob_ds.map( lambda accept_prob, _: accept_prob) prob_of_original_ds = acceptance_and_original_prob_ds.map( lambda _, prob_original: prob_original) filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed) # Prefetch filtered dataset for speed. filtered_ds = filtered_ds.prefetch(3) prob_original_static = _get_prob_original_static( initial_dist_t, target_dist_t) if initial_dist is not None else None def add_class_value(*x): if len(x) == 1: return class_func(*x), x[0] else: return class_func(*x), x if prob_original_static == 1: return dataset.map(add_class_value) elif prob_original_static == 0: return filtered_ds else: return interleave_ops.sample_from_datasets( [dataset.map(add_class_value), filtered_ds], weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), seed=seed) return _apply_fn def _get_prob_original_static(initial_dist_t, target_dist_t): """Returns the static probability of sampling from the original. `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters an Op that it isn't defined for. We have some custom logic to avoid this. Args: initial_dist_t: A tensor of the initial distribution. target_dist_t: A tensor of the target distribution. Returns: The probability of sampling from the original distribution as a constant, if it is a constant, or `None`. """ init_static = tensor_util.constant_value(initial_dist_t) target_static = tensor_util.constant_value(target_dist_t) if init_static is None or target_static is None: return None else: return np.min(target_static / init_static) def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed): """Filters a dataset based on per-class acceptance probabilities. Args: dataset: The dataset to be filtered. acceptance_dist_ds: A dataset of acceptance probabilities. initial_dist_ds: A dataset of the initial probability distribution, given or estimated. class_func: A function mapping an element of the input dataset to a scalar `tf.int32` tensor. Values should be in `[0, num_classes)`. seed: (Optional.) Python integer seed for the resampler. Returns: A dataset of (class value, data) after filtering. """ def maybe_warn_on_large_rejection(accept_dist, initial_dist): proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) return control_flow_ops.cond( math_ops.less(proportion_rejected, .5), lambda: accept_dist, lambda: logging_ops.Print( # pylint: disable=g-long-lambda accept_dist, [proportion_rejected, initial_dist, accept_dist], message="Proportion of examples rejected by sampler is high: ", summarize=100, first_n=10)) acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, initial_dist_ds)) .map(maybe_warn_on_large_rejection)) def _gather_and_copy(acceptance_prob, data): if isinstance(data, tuple): class_val = class_func(*data) else: class_val = class_func(data) return class_val, array_ops.gather(acceptance_prob, class_val), data current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( (acceptance_dist_ds, dataset)).map(_gather_and_copy) filtered_ds = ( current_probabilities_and_class_and_data_ds.filter( lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) return filtered_ds.map(lambda class_value, _, data: (class_value, data)) def _estimate_initial_dist_ds( target_dist_t, class_values_ds, dist_estimation_batch_size=32, smoothing_constant=10): num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0]) initial_examples_per_class_seen = array_ops.fill( [num_classes], np.int64(smoothing_constant)) def update_estimate_and_tile(num_examples_per_class_seen, c): updated_examples_per_class_seen, dist = _estimate_data_distribution( c, num_examples_per_class_seen) tiled_dist = array_ops.tile( array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) return updated_examples_per_class_seen, tiled_dist initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) .apply(scan_ops.scan(initial_examples_per_class_seen, update_estimate_and_tile)) .unbatch()) return initial_dist_ds def _get_target_to_initial_ratio(initial_probs, target_probs): # Add tiny to initial_probs to avoid divide by zero. denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) return target_probs / denom def _estimate_data_distribution(c, num_examples_per_class_seen): """Estimate data distribution as labels are seen. Args: c: The class labels. Type `int32`, shape `[batch_size]`. num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing counts. Returns: num_examples_per_lass_seen: Updated counts. Type `int64`, shape `[num_classes]`. dist: The updated distribution. Type `float32`, shape `[num_classes]`. """ num_classes = num_examples_per_class_seen.get_shape()[0] # Update the class-count based on what labels are seen in batch. num_examples_per_class_seen = math_ops.add( num_examples_per_class_seen, math_ops.reduce_sum( array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) init_prob_estimate = math_ops.truediv( num_examples_per_class_seen, math_ops.reduce_sum(num_examples_per_class_seen)) dist = math_ops.cast(init_prob_estimate, dtypes.float32) return num_examples_per_class_seen, dist def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): """Calculates the acceptance probabilities and mixing ratio. In this case, we assume that we can *either* sample from the original data distribution with probability `m`, or sample from a reshaped distribution that comes from rejection sampling on the original distribution. This rejection sampling is done on a per-class basis, with `a_i` representing the probability of accepting data from class `i`. This method is based on solving the following analysis for the reshaped distribution: Let F be the probability of a rejection (on any example). Let p_i be the proportion of examples in the data in class i (init_probs) Let a_i is the rate the rejection sampler should *accept* class i Let t_i is the target proportion in the minibatches for class i (target_probs) ``` F = sum_i(p_i * (1-a_i)) = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 ``` An example with class `i` will be accepted if `k` rejections occur, then an example with class `i` is seen by the rejector, and it is accepted. This can be written as follows: ``` t_i = sum_k=0^inf(F^k * p_i * a_i) = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 = p_i * a_i / sum_j(p_j * a_j) using F from above ``` Note that the following constraints hold: ``` 0 <= p_i <= 1, sum_i(p_i) = 1 0 <= a_i <= 1 0 <= t_i <= 1, sum_i(t_i) = 1 ``` A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` If we try to minimize the amount of data rejected, we get the following: M_max = max_i [ t_i / p_i ] M_min = min_i [ t_i / p_i ] The desired probability of accepting data if it comes from class `i`: a_i = (t_i/p_i - m) / (M_max - m) The desired probability of pulling a data element from the original dataset, rather than the filtered one: m = M_min Args: initial_probs: A Tensor of the initial probability distribution, given or estimated. target_probs: A Tensor of the corresponding classes. Returns: (A 1D Tensor with the per-class acceptance probabilities, the desired probability of pull from the original distribution.) """ ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) max_ratio = math_ops.reduce_max(ratio_l) min_ratio = math_ops.reduce_min(ratio_l) # Target prob to sample from original distribution. m = min_ratio # TODO(joelshor): Simplify fraction, if possible. a_i = (ratio_l - m) / (max_ratio - m) return a_i, m