This is a continuation of my backpropagation post. This time, I will recap how I implemented backpropagation for batch normalization.
Batch Normalization
Batch normalization is a technique used to reduce the dependency of the training results on the weight initialization. With a BN layer, it is possible to ensure roughly unit gaussian distributed inputs to a nonlinear activation function such as ReLU. Or at least, that’s the case if BN is applied before ReLU; I have also seen people apply BN after ReLU, in which case the activations themselves are roughly unit gaussian. It would be interesting to know in which situations the placement of BN makes a difference and why. For now, let’s return to the topic of backpropagation.
Batch normalization takes an input consisting of samples \(x_i \in \mathbb{R}^D, i\in\{1,\dots,N\}\), and normalizes every feature column by
\[x^{\text{bn}}_{ik} = \gamma_k \cdot \frac{x_{ik}-\overline{x_k}}{\sqrt{\sigma^2_k + \epsilon}} + \beta_k \quad , \quad k\in\{1,\dots,D\}\]
where \(\overline{x_k}\) and \(\sigma^2_k\) refer to the mean and variance of the \(k\)-th column. The parameters \(\gamma_k\) and \(\beta_k\) are learned by backpropagation.
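For concreteness, the forward pass looks as follows in NumPy (a minimal sketch with variable names of my own choosing; the cache simply stores what the backward pass further down will need):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x has shape (N, D); gamma and beta have shape (D,)
    mean = x.mean(axis=0)                    # column means, shape (D,)
    var = x.var(axis=0)                      # column variances (1/N version), shape (D,)
    x_hat = (x - mean) / np.sqrt(var + eps)  # normalized columns
    out = gamma * x_hat + beta               # scale and shift
    cache = (x, x_hat, mean, var, gamma, eps)
    return out, cache
```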
Our first goal is to calculate the derivative
\[\frac{dx^{\text{bn}}_{ik}}{dx_{lk}} \quad , \quad l\in\{1,\dots,N\} .\]
Both the column mean and variance are functions of \(x_{lk}\), so we need to derive \(d\overline{x_k}/dx_{lk}\) and \(d\sigma^2_k/dx_{lk}\) first. For the column mean, this is straightforward:
\[\frac{d\overline{x_k}}{dx_{lk}} = \frac{1}{N} \quad .\]
For the variance, a bit more effort is required:
\begin{align}
\frac{d\sigma_{k}^2}{dx_{lk}} &= \frac{d}{dx_{lk}} \frac{\sum_{i=1}^N (x_{ik}-\overline{x_k})^2}{N} \newline
&= \frac{2 \sum_{i=1}^N (x_{ik}-\overline{x_k}) \cdot (\delta(i=l)-\frac{1}{N})}{N} \newline
&= \frac{2 \big((x_{lk}-\overline{x_k})-\frac{1}{N} \sum_{i=1}^N (x_{ik}-\overline{x_k}) \big)}{N} \newline
&= \frac{2(x_{lk}-\overline{x_k})}{N}
\end{align}
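(The last step uses the fact that the deviations \(x_{ik}-\overline{x_k}\) sum to zero.) The result is easy to verify with a quick finite-difference check, here as a throwaway sketch with made-up shapes and indices:

```python
import numpy as np

np.random.seed(0)
N, D = 5, 3
x = np.random.randn(N, D)
l, k = 2, 1   # perturb x[l, k] and look at the variance of column k
h = 1e-5

x_plus, x_minus = x.copy(), x.copy()
x_plus[l, k] += h
x_minus[l, k] -= h

numeric = (x_plus[:, k].var() - x_minus[:, k].var()) / (2 * h)
analytic = 2 * (x[l, k] - x[:, k].mean()) / N
print(numeric, analytic)  # the two values should agree up to numerical error
```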
We can now use the above expressions to calculate the complete derivative with the quotient rule:
\begin{align}
\frac{dx_{ik}^{\text{bn}}}{dx_{lk}} &= \gamma_k \cdot \Bigg( \frac{\delta(i=l)-1/N}{\sqrt{\sigma^2_k + \epsilon} } - \frac{(x_{ik}-\overline{x_k})\cdot (x_{lk}-\overline{x_k})}{N\sqrt{\sigma^2_k + \epsilon}^3}\Bigg)
\end{align}
The above expression describes the inner gradient, which still needs to be chained to the upstream gradient \(dout/dx^{\text{bn}}_{jk}\) for \(j\in\{1,\dots,N\}, k\in\{1,\dots,D\}\):
\begin{align}
\frac{dout}{dx_{lk}} &= \sum_{j=1}^N \Big( \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \frac{dx_{jk}^{\text{bn}}}{dx_{lk}} \Big) \newline
&= \sum_{j=1}^N \Bigg( \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \gamma_k \cdot\Bigg( \frac{\delta(j=l)-1/N}{\sqrt{\sigma^2_k + \epsilon} } - \frac{(x_{jk}-\overline{x_k})\cdot (x_{lk}-\overline{x_k})}{N\sqrt{\sigma^2_k + \epsilon}^3} \Bigg) \Bigg)\newline
&= \frac{\gamma_k}{\sqrt{\sigma^2_k + \epsilon}} \cdot \Bigg( \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \big( \delta(j=l)-1/N \big) - \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \frac{(x_{jk}-\overline{x_k})\cdot (x_{lk}-\overline{x_k})}{N (\sigma^2_k + \epsilon)} \Bigg) \newline
&= \frac{\gamma_k}{\sqrt{\sigma^2_k + \epsilon}} \cdot \Bigg( \frac{dout}{dx^{\text{bn}}_{lk}} - \frac{1}{N} \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} - \frac{1}{N} \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \frac{(x_{jk}-\overline{x_k})\cdot (x_{lk}-\overline{x_k})}{\sigma^2_k + \epsilon} \Bigg)
\end{align}
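For the implementation it helps to note that, writing \(\hat{x}_{jk} = (x_{jk}-\overline{x_k})/\sqrt{\sigma^2_k+\epsilon}\) for the normalized values (a shorthand I introduce here for convenience), the product in the last term factors as \((x_{jk}-\overline{x_k})(x_{lk}-\overline{x_k})/(\sigma^2_k+\epsilon) = \hat{x}_{jk}\,\hat{x}_{lk}\), so that

\[\frac{dout}{dx_{lk}} = \frac{\gamma_k}{\sqrt{\sigma^2_k + \epsilon}} \cdot \Bigg( \frac{dout}{dx^{\text{bn}}_{lk}} - \frac{1}{N} \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} - \frac{\hat{x}_{lk}}{N} \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \hat{x}_{jk} \Bigg) \quad ,\]

which only involves sums over the batch dimension and therefore vectorizes nicely.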
The gradient with respect to \(\gamma_k\) is given by
\begin{align}
\frac{dout}{d\gamma_k} &= \sum_{j=1}^N \Big( \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \frac{dx_{jk}^{\text{bn}}}{d\gamma_{k}} \Big) \newline
&= \sum_{j=1}^N \Big( \frac{dout}{dx^{\text{bn}}_{jk}} \cdot \frac{x_{jk}-\overline{x_k}}{\sqrt{\sigma_k^2+\epsilon}} \Big) \quad .
\end{align}
The gradient with respect to \(\beta_k\) is given by
\begin{align}
\frac{dout}{d\beta_k} = \sum_{j=1}^N \frac{dout}{dx^{\text{bn}}_{jk}} \quad .
\end{align}
Plugging everything in, the backward pass of a BN layer looks something like this in Python:
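(A minimal NumPy sketch: the cache layout matches the forward-pass snippet above, and the three returned gradients correspond to the expressions derived for \(x_{lk}\), \(\gamma_k\) and \(\beta_k\).)

```python
import numpy as np

def batchnorm_backward(dout, cache):
    # dout has shape (N, D) and holds the upstream gradient dout/dx_bn
    x, x_hat, mean, var, gamma, eps = cache
    N = x.shape[0]

    # Gradients of the learnable parameters: sums over the batch axis
    dgamma = np.sum(dout * x_hat, axis=0)   # sum_j dout_jk * x_hat_jk
    dbeta = np.sum(dout, axis=0)            # sum_j dout_jk

    # Gradient with respect to the input, vectorizing the expression
    # dx_lk = gamma_k / sqrt(var_k + eps)
    #         * ( dout_lk - mean_j(dout_jk) - x_hat_lk * mean_j(dout_jk * x_hat_jk) )
    dx = (gamma / np.sqrt(var + eps)) * (
        dout
        - dout.mean(axis=0)
        - x_hat * (dout * x_hat).mean(axis=0)
    )
    return dx, dgamma, dbeta
```

A numerical gradient check against finite differences, as done above for the column variance, is a cheap way to confirm that such an implementation matches the derived expressions.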
Further Reading
[1] Deriving Batch-Norm Backprop Equations by Chris Yeh presents an alternative derivation of the vectorized equations.