Batch Normalization
Batch normalization was introduced and widely popularized by the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. In a deep neural network, the activations between layers depend heavily on the parameter initialization, which in turn affects how gradients are backpropagated through each layer during training. Poor initialization can greatly affect how well a network trains and how fast it trains. Networks train best when each layer has a unit Gaussian distribution for its activations. So if you really want unit Gaussian activations, you can make them so by applying batch normalization to every layer.
Basically, batch normalization is a powerful technique for decoupling the weight updates from the parameter initialization. As the paper puts it, batch normalization allows us to use much higher learning rates and be less careful about initialization. Let's consider a batch of activations at some layer; we can make each dimension (denoted by $k$) unit Gaussian by applying:

$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathrm{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}}$$
Each batch of training examples has dimension $D$. We compute the empirical mean and variance independently for each dimension, using all the training examples in the batch. Batch normalization is usually inserted after fully connected or convolutional layers and before the nonlinearity is applied. For a convolutional layer, we basically have one mean and one standard deviation per activation map, and we normalize across all of the examples in the batch of data.
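As a rough numpy sketch of the convolutional case (the array names, toy shapes, and NCHW layout here are assumptions for illustration):

```python
import numpy as np

# Assumed toy shapes: N examples, C activation maps (channels), H x W spatial size.
N, C, H, W = 16, 32, 28, 28
x = np.random.randn(N, C, H, W)

# One mean and one variance per activation map, computed across every example
# in the batch and every spatial location.
mu = x.mean(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1)
var = x.var(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1)

x_hat = (x - mu) / np.sqrt(var + 1e-5)        # normalized per channel
```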
If we have a tanh layer, we don't really want to constrain it to the linear regime. The act of normalization might force the activations to stay within the center of tanh, which is its linear regime. We want flexibility, so ideally batch normalization should be learned as a parameter of the network. In other words, we should insert learnable parameters that can effectively cancel out batch normalization if the network sees fit.
We will apply the following operation to each normalized vector:

$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

Such that the network can learn

$$\gamma^{(k)} = \sqrt{\mathrm{Var}[x^{(k)}]}, \qquad \beta^{(k)} = \mathrm{E}[x^{(k)}]$$

And effectively recover the identity mapping, as if you didn't have batch normalization, i.e. cancel out the batch normalization if the network sees fit.
Find mini-batch mean: $\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i$

Find mini-batch variance: $\sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2$

Normalize: $\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$

Scale and shift: $y_i = \gamma \hat{x}_i + \beta$
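Putting the four steps together, here is a minimal numpy sketch of the forward transform (the function name `batchnorm_forward` and the cached tuple are my own choices, not from the paper):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """x: (N, D) mini-batch; gamma, beta: (D,) learnable scale and shift."""
    mu = x.mean(axis=0)                      # mini-batch mean, shape (D,)
    var = x.var(axis=0)                      # mini-batch variance, shape (D,)
    x_hat = (x - mu) / np.sqrt(var + eps)    # normalize
    out = gamma * x_hat + beta               # scale and shift
    cache = (x, mu, var, x_hat, gamma, eps)  # saved for the backward pass
    return out, cache
```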
Batch normalization has several benefits:

Improves gradient flow through the network
Allows higher learning rates
Reduces the strong dependence on initialization
Acts as a form of regularization in a funny way, and slightly reduces the need for dropout
Here comes the derivation; much of it comes from the paper itself and from Kevin Zakka's blog post on GitHub.
BN stands for batch normalization
The forward pass is very easy, both intuitively and mathematically.
First we find the mean across a mini-batch of training examples:

$$\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i$$

Then we find the variance across the same mini-batch of training examples:

$$\sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2$$

And then apply normalization:

$$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$$

Finally, apply a linear transformation with learned parameters $\gamma$ and $\beta$ so that the network is able to recover the identity mapping:

$$y_i = \gamma \hat{x}_i + \beta$$

In case we wonder why we need to do this:
Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity. To address this, we make sure that the transformation inserted in the network can represent the identity transform.
As we can see, the normalized output has a mean centered at zero.

And a variance of one across all examples.
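A quick numeric check of this claim, reusing the `batchnorm_forward` sketch from above (the toy shapes are arbitrary assumptions):

```python
x = np.random.randn(64, 10) * 3.0 + 5.0   # toy mini-batch, 64 examples x 10 features
out, cache = batchnorm_forward(x, gamma=np.ones(10), beta=np.zeros(10))
x_hat = cache[3]

print(x_hat.mean(axis=0))   # approximately 0 for every feature
print(x_hat.var(axis=0))    # approximately 1 for every feature
```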
Now here comes the hard part. We are given an upstream gradient, i.e. the gradient of the loss function with respect to the output of the batch normalization layer, $\frac{\partial L}{\partial y}$.

We need to find $\frac{\partial L}{\partial \gamma}$, $\frac{\partial L}{\partial \beta}$, and $\frac{\partial L}{\partial x}$.
The derivative of $y$ with respect to $\hat{x}$ is simple:

$$\frac{\partial y}{\partial \hat{x}} = \gamma$$

Thus,

$$\frac{\partial L}{\partial \hat{x}} = \frac{\partial L}{\partial y} \cdot \gamma$$
In Python,
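A sketch of this step, reusing the values cached by the `batchnorm_forward` sketch above; `dout` stands in for the upstream gradient $\frac{\partial L}{\partial y}$ (these names are assumptions):

```python
x, mu, var, x_hat, gamma, eps = cache
dout = np.random.randn(*x_hat.shape)   # stand-in upstream gradient, shape (N, M)

dxhat = dout * gamma                   # dL/dx_hat, shape (N, M)
```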
Thus, for the scale parameter $\gamma$:

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \cdot \hat{x}_i$$

We need to perform a sum across all training examples in the mini-batch and squash the shape (N, M) to (M,).
In Python,
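Assuming this step computes the gradient with respect to the scale parameter $\gamma$:

```python
# Sum over the batch axis: (N, M) -> (M,)
dgamma = np.sum(dout * x_hat, axis=0)
```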
Thus, for the shift parameter $\beta$:

$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}$$

We need to perform a sum across all training examples in the mini-batch and squash the shape (N, M) to (M,).
In Python,
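Assuming this step computes the gradient with respect to the shift parameter $\beta$:

```python
# Sum over the batch axis: (N, M) -> (M,)
dbeta = np.sum(dout, axis=0)
```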
Thus, for the mini-batch variance:

$$\frac{\partial L}{\partial \sigma_{\mathcal{B}}^2} = \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} \cdot (x_i - \mu_{\mathcal{B}}) \cdot \left(-\frac{1}{2}\right) (\sigma_{\mathcal{B}}^2 + \epsilon)^{-3/2}$$

We need to perform a sum across all training examples in the mini-batch and squash the shape (N, M) to (M,).
In Python,
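Assuming this step computes the gradient with respect to the mini-batch variance, using `x`, `mu`, `var`, and `eps` from the cached forward pass:

```python
# dL/dvar, shape (M,): each term is dL/dx_hat_i * d(x_hat_i)/d(var)
dvar = np.sum(dxhat * (x - mu) * -0.5 * (var + eps) ** -1.5, axis=0)
```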
We are going to use the chain rule to solve for the gradient with respect to the mini-batch mean, which receives gradient both directly through $\hat{x}$ and indirectly through $\sigma_{\mathcal{B}}^2$:

$$\frac{\partial L}{\partial \mu_{\mathcal{B}}} = \frac{\partial L}{\partial \hat{x}} \cdot \frac{\partial \hat{x}}{\partial \mu_{\mathcal{B}}} + \frac{\partial L}{\partial \sigma_{\mathcal{B}}^2} \cdot \frac{\partial \sigma_{\mathcal{B}}^2}{\partial \mu_{\mathcal{B}}}$$

Thus,

$$\frac{\partial L}{\partial \mu_{\mathcal{B}}} = \left( \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \right) + \frac{\partial L}{\partial \sigma_{\mathcal{B}}^2} \cdot \frac{1}{N} \sum_{i=1}^{N} -2 (x_i - \mu_{\mathcal{B}})$$
In Python,
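A sketch of the gradient with respect to the mini-batch mean, combining the direct path through $\hat{x}$ and the indirect path through the variance:

```python
# dL/dmu, shape (M,)
dmu = np.sum(dxhat * -1.0 / np.sqrt(var + eps), axis=0) \
      + dvar * np.mean(-2.0 * (x - mu), axis=0)
```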
Use the chain rule again to solve for the final gradient, the gradient of the loss with respect to the input $x$:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial \sigma_{\mathcal{B}}^2} \cdot \frac{\partial \sigma_{\mathcal{B}}^2}{\partial x_i} + \frac{\partial L}{\partial \mu_{\mathcal{B}}} \cdot \frac{\partial \mu_{\mathcal{B}}}{\partial x_i}$$

Now fill in the missing pieces:

$$\frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad \frac{\partial \sigma_{\mathcal{B}}^2}{\partial x_i} = \frac{2 (x_i - \mu_{\mathcal{B}})}{N}, \qquad \frac{\partial \mu_{\mathcal{B}}}{\partial x_i} = \frac{1}{N}$$

Now we just plug and chug:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_{\mathcal{B}}^2} \cdot \frac{2 (x_i - \mu_{\mathcal{B}})}{N} + \frac{\partial L}{\partial \mu_{\mathcal{B}}} \cdot \frac{1}{N}$$
In Python,
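Finally, a sketch of the gradient with respect to the input `x`, which plugs the three pieces together:

```python
N = x.shape[0]
dx = dxhat / np.sqrt(var + eps) \
     + dvar * 2.0 * (x - mu) / N \
     + dmu / N
```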
Work on this later...
Inputs: Values of $x$ over a mini-batch: $\mathcal{B} = \{x_{1 \ldots m}\}$; parameters to be learned: $\gamma$, $\beta$

Outputs: $\{y_i = \mathrm{BN}_{\gamma, \beta}(x_i)\}$
$x$ is the input matrix/vector to the BN layer

$\mu_{\mathcal{B}}$ is the batch mean

$\sigma_{\mathcal{B}}^2$ is the batch variance

$\epsilon$ is a small constant added to avoid dividing by zero

$\hat{x}$ is the normalized input matrix/vector

$y = \gamma \hat{x} + \beta$ is the linear transformation which scales $\hat{x}$ by $\gamma$ and shifts it by $\beta$
$f$ represents the next layer after the BN layer, if we assume a forward pass ordering
If $\gamma$ is 1 and $\beta$ is 0, then the linear transformation is an identity transformation.
The derivative of $y$ with respect to $\hat{x}$ is: $\frac{\partial y}{\partial \hat{x}} = \gamma$

The derivative of $y$ with respect to $\gamma$ is: $\frac{\partial y}{\partial \gamma} = \hat{x}$

The derivative of $y$ with respect to $\beta$ is: $\frac{\partial y}{\partial \beta} = 1$