# -*- coding: utf-8 -*- """Normalization layers. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from ..engine.base_layer import Layer, InputSpec from .. import initializers from .. import regularizers from .. import constraints from .. import backend as K from ..legacy import interfaces class BatchNormalization(Layer): """Batch normalization layer (Ioffe and Szegedy, 2014). Normalize the activations of the previous layer at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. # Arguments axis: Integer, the axis that should be normalized (typically the features axis). For instance, after a `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. momentum: Momentum for the moving mean and the moving variance. epsilon: Small float added to variance to avoid dividing by zero. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling will be done by the next layer. beta_initializer: Initializer for the beta weight. gamma_initializer: Initializer for the gamma weight. moving_mean_initializer: Initializer for the moving mean. moving_variance_initializer: Initializer for the moving variance. beta_regularizer: Optional regularizer for the beta weight. gamma_regularizer: Optional regularizer for the gamma weight. beta_constraint: Optional constraint for the beta weight. gamma_constraint: Optional constraint for the gamma weight. # Input shape Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. # Output shape Same shape as input. # References - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167) """ @interfaces.legacy_batchnorm_support def __init__(self, axis=-1, momentum=0.99, epsilon=1e-3, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', moving_mean_initializer='zeros', moving_variance_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs): super(BatchNormalization, self).__init__(**kwargs) self.supports_masking = True self.axis = axis self.momentum = momentum self.epsilon = epsilon self.center = center self.scale = scale self.beta_initializer = initializers.get(beta_initializer) self.gamma_initializer = initializers.get(gamma_initializer) self.moving_mean_initializer = initializers.get(moving_mean_initializer) self.moving_variance_initializer = ( initializers.get(moving_variance_initializer)) self.beta_regularizer = regularizers.get(beta_regularizer) self.gamma_regularizer = regularizers.get(gamma_regularizer) self.beta_constraint = constraints.get(beta_constraint) self.gamma_constraint = constraints.get(gamma_constraint) def build(self, input_shape): dim = input_shape[self.axis] if dim is None: raise ValueError('Axis ' + str(self.axis) + ' of ' 'input tensor should have a defined dimension ' 'but the layer received an input with shape ' + str(input_shape) + '.') self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim}) shape = (dim,) if self.scale: self.gamma = self.add_weight(shape=shape, name='gamma', initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, constraint=self.gamma_constraint) else: self.gamma = None if self.center: self.beta = self.add_weight(shape=shape, name='beta', initializer=self.beta_initializer, regularizer=self.beta_regularizer, constraint=self.beta_constraint) else: self.beta = None self.moving_mean = self.add_weight( shape=shape, name='moving_mean', initializer=self.moving_mean_initializer, trainable=False) self.moving_variance = self.add_weight( shape=shape, name='moving_variance', initializer=self.moving_variance_initializer, trainable=False) self.built = True def call(self, inputs, training=None): input_shape = K.int_shape(inputs) # Prepare broadcasting shape. ndim = len(input_shape) reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis] # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1]) def normalize_inference(): if needs_broadcasting: # In this case we must explicitly broadcast all parameters. broadcast_moving_mean = K.reshape(self.moving_mean, broadcast_shape) broadcast_moving_variance = K.reshape(self.moving_variance, broadcast_shape) if self.center: broadcast_beta = K.reshape(self.beta, broadcast_shape) else: broadcast_beta = None if self.scale: broadcast_gamma = K.reshape(self.gamma, broadcast_shape) else: broadcast_gamma = None return K.batch_normalization( inputs, broadcast_moving_mean, broadcast_moving_variance, broadcast_beta, broadcast_gamma, axis=self.axis, epsilon=self.epsilon) else: return K.batch_normalization( inputs, self.moving_mean, self.moving_variance, self.beta, self.gamma, axis=self.axis, epsilon=self.epsilon) # If the learning phase is *static* and set to inference: if training in {0, False}: return normalize_inference() # If the learning is either dynamic, or set to training: normed_training, mean, variance = K.normalize_batch_in_training( inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon) if K.backend() != 'cntk': sample_size = K.prod([K.shape(inputs)[axis] for axis in reduction_axes]) sample_size = K.cast(sample_size, dtype=K.dtype(inputs)) if K.backend() == 'tensorflow' and sample_size.dtype != 'float32': sample_size = K.cast(sample_size, dtype='float32') # sample variance - unbiased estimator of population variance variance *= sample_size / (sample_size - (1.0 + self.epsilon)) self.add_update([K.moving_average_update(self.moving_mean, mean, self.momentum), K.moving_average_update(self.moving_variance, variance, self.momentum)], inputs) # Pick the normalized form corresponding to the training phase. return K.in_train_phase(normed_training, normalize_inference, training=training) def get_config(self): config = { 'axis': self.axis, 'momentum': self.momentum, 'epsilon': self.epsilon, 'center': self.center, 'scale': self.scale, 'beta_initializer': initializers.serialize(self.beta_initializer), 'gamma_initializer': initializers.serialize(self.gamma_initializer), 'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer), 'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer), 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 'beta_constraint': constraints.serialize(self.beta_constraint), 'gamma_constraint': constraints.serialize(self.gamma_constraint) } base_config = super(BatchNormalization, self).get_config() return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): return input_shape