Scaling in Dropout and Transformer Attention Block

Intro

Dropout is a regularization method that randomly sets some of the outputs of a layer of neurons to zero during the training stage, but acts like an identity layer during the inference (evaluation) stage. For more details, please refer to the original paper.

I have known this concept for years, but never paid much attention to its details (shame on me).

When I came across the Dropout documentation in PyTorch, this sentence caught my eye:

Furthermore, the outputs are scaled by a factor of $\frac{1}{1-p}$ during training. This means that during evaluation the module simply computes an identity function.

To verify this, I ran

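import torch
import torch.nn as nn

foo = torch.ones((2, 3, 5))
do = nn.Dropout(0.5)  # p = 0.5, so surviving elements are scaled by 1 / (1 - 0.5) = 2
bar = do(foo)  # a freshly constructed module is in training mode
print(foo)
print(bar)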

Result:

tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

tensor([[[2., 0., 2., 0., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 0., 0., 2.]],

        [[0., 0., 2., 2., 2.],
         [0., 0., 2., 0., 2.],
         [0., 0., 2., 0., 2.]]])

Note the 2s in the bar tensor: with $p = 0.5$, every surviving element is scaled by $\frac{1}{1-p} = 2$.

I started wondering why we need this scaling factor, and the answer came quickly. It is all about keeping the scale of the output stable.
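The two halves of the quoted sentence can be checked separately. The snippet below is a minimal sketch: the tensor size, the seed, and the variable names are my own choices, not anything from the documentation. It switches the same module between training and evaluation mode and looks at what comes out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

x = torch.ones(1000, 1000)
drop = nn.Dropout(p=0.5)

# Training mode: elements are zeroed at random, survivors are scaled by 1 / (1 - p).
drop.train()
y_train = drop(x)
train_values = sorted(y_train.unique().tolist())  # distinct values in the output
train_mean = y_train.mean().item()                # should stay close to the input mean of 1

# Evaluation mode: the module passes its input through unchanged.
drop.eval()
y_eval = drop(x)
eval_is_identity = torch.equal(y_eval, x)

print("values seen in training mode:", train_values)
print("mean in training mode:", train_mean)
print("eval output equals input:", eval_is_identity)
```

With a large enough tensor, the training-mode mean sits near 1 even though half of the entries are zero, which is the point of the scaling.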

Some Math

Assume $X$ is the output of a single neuron, treated as a random variable, and $p$ is the probability that it is dropped. If we apply the dropout without the scaling, then

\[E(\text{Dropout}(X)) = p \times 0 + (1-p)E(X) = (1-p)E(X)\]

The expectation of the dropout output shrinks to $(1-p)E(X)$. This is not good for 2 reasons.

  1. Suppose you train with a dropout of 0.5 and your training output matches the label Y nicely, where Y has a mean of 1. During the testing stage the dropout is removed, and your predicted output suddenly starts floating around 2. Then you get very confused about why the training loss is so perfect while the validation loss looks terrible.
  2. Dropout is sometimes applied in multiple layers. You don’t want this cascade of scaling down or up, which could make the gradients vanish or explode quickly (exponentially in the number of dropout layers if they are cascaded). Note that floats have a bounded range, and you don’t want your computation to happen too close to those bounds. Trust me.
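The cascade in point 2 can be sketched in a few lines. This is a toy setup under my own assumptions: `unscaled_dropout` is a hypothetical helper written for this post (it is not a PyTorch function), and the layers are stacked directly with no weights in between, which a real network would of course have.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

def unscaled_dropout(x, p):
    """Zero each element with probability p and do NOT rescale the survivors."""
    keep_mask = (torch.rand_like(x) >= p).to(x.dtype)  # 1 with probability 1 - p
    return x * keep_mask

p = 0.5
num_layers = 5
x = torch.ones(1_000_000)

means = []
for _ in range(num_layers):
    x = unscaled_dropout(x, p)
    means.append(x.mean().item())

# Each layer multiplies the mean by (1 - p), so the decay is geometric.
expected = [(1 - p) ** (i + 1) for i in range(num_layers)]

for layer, (m, e) in enumerate(zip(means, expected), start=1):
    print(f"after layer {layer}: mean = {m:.4f}, (1 - p)^n = {e:.4f}")
```

The measured means should track $(1-p)^n$ closely, which is the exponential shrinkage described above.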

Applying the scaling factor keeps the mean unchanged.

\[E(\text{Dropout}(X)) = p \times 0 + (1-p)E(\frac{1}{1-p}X) = E(X)\]
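Both expectations can be checked empirically. The sketch below uses a hand-written `manual_dropout` helper (hypothetical, written for illustration, and not how PyTorch implements it internally) with a `rescale` flag, and an input whose mean is deliberately nonzero so the shrinkage is visible. The choices of $p = 0.3$ and $E(X) \approx 2$ are arbitrary.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

def manual_dropout(x, p, rescale):
    """Drop each element with probability p; optionally scale survivors by 1 / (1 - p)."""
    keep_mask = (torch.rand_like(x) >= p).to(x.dtype)  # 1 with probability 1 - p
    out = x * keep_mask
    if rescale:
        out = out / (1 - p)
    return out

p = 0.3
x = torch.randn(1_000_000) + 2.0  # samples of X with E(X) close to 2

mean_x = x.mean().item()
mean_unscaled = manual_dropout(x, p, rescale=False).mean().item()  # expect about (1 - p) * E(X)
mean_scaled = manual_dropout(x, p, rescale=True).mean().item()     # expect about E(X)

print(f"E(X)              = {mean_x:.3f}")
print(f"without scaling   = {mean_unscaled:.3f}")
print(f"with 1/(1-p) scale = {mean_scaled:.3f}")
```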

Similar Scaling in the Attention block of Transformers

In the Scaled Dot-Product Attention proposed in the famous transformer paper, the dot product attention is scaled by $\frac{1}{\sqrt{d_k}}$.

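Why do we need this scaling factor? For a similar reason: to keep the dot product from growing in magnitude as $d_k$ increases. The paper's stated concern is that large dot products push the softmax into regions where its gradients are extremely small.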

Here is the math. The dot product of a query and a key can be written as $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$. Let’s assume the components $q_i$ and $k_i$ are independent random variables with mean 0 and variance 1. What are the mean and variance of the dot product?

To answer the question, we need a small exercise.

Expectation is trivial. Since $q_i$ and $k_i$ are independent, $E(q_i k_i) = E(q_i) E(k_i) = 0$. Hence $E(\sum_{i=1}^{d_k} q_i k_i) = \sum_{i=1}^{d_k}E(q_i k_i)=0$.

Variance needs a bit of work. First let’s find out $Var(XY)$ if $X$ and $Y$ are independent, and both have mean = 0, variance = 1.

\[Var(XY) = E[(XY)^2] - [E(XY)]^2\]

Since $X$ and $Y$ are independent, $E[XY] = E[X]E[Y]$, and $X^2$ and $Y^2$ are independent as well. Hence $E[(XY)^2] = E[X^2 Y^2] = E[X^2]E[Y^2]$. Because both means are 0, $E(X^2) = Var(X)$ and $E(Y^2) = Var(Y)$. Therefore,

\[Var(XY) = E(X^2)E(Y^2) - (E(X)E(Y))^2 = E(X^2)E(Y^2) - 0 = Var(X)Var(Y)\]

Then let’s find out $Var(X + Y)$.

\[Var(X + Y) = E[(X + Y)^2] - [E(X+Y)]^2\]
\[= E(X^2 + 2XY + Y^2) - [E(X) + E(Y)]^2 = E(X^2) + 2E(X)E(Y) + E(Y^2) - 0 = E(X^2) + E(Y^2) = Var(X) + Var(Y)\]
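A quick numerical sanity check of the two properties, $Var(XY) = Var(X)Var(Y)$ and $Var(X+Y) = Var(X) + Var(Y)$. This is only a sketch: I draw $X$ and $Y$ from a standard normal, which is one convenient choice that satisfies the mean 0, variance 1, and independence assumptions. The derivation itself does not require normality.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

n = 1_000_000
x = torch.randn(n)  # mean 0, variance 1
y = torch.randn(n)  # drawn independently of x

var_x = x.var().item()
var_y = y.var().item()
var_product = (x * y).var().item()  # expect about Var(X) * Var(Y) = 1
var_sum = (x + y).var().item()      # expect about Var(X) + Var(Y) = 2

print(f"Var(X) = {var_x:.3f}, Var(Y) = {var_y:.3f}")
print(f"Var(XY)    = {var_product:.3f}")
print(f"Var(X + Y) = {var_sum:.3f}")
```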

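Now let’s use these 2 properties. The terms $q_i k_i$ are independent of each other, so the variance of the sum is the sum of the variances:

\[Var(\sum_{i=1}^{d_k} q_i k_i) = \sum_{i=1}^{d_k} Var(q_i k_i) = \sum_{i=1}^{d_k} Var(q_i) Var(k_i) = d_k\]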

After applying the scaling factor $\frac{1}{\sqrt{d_k}}$, the variance goes back to 1.

\[Var(\frac{1}{\sqrt{d_k}} \sum_{i=1}^{d_k} q_i k_i) = \frac{1}{d_k} \sum_{i=1}^{d_k} Var(q_i) Var(k_i) = 1\]
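The same result can be seen empirically. The sketch below samples many random query/key pairs and measures the variance of their dot products with and without the $\frac{1}{\sqrt{d_k}}$ factor. The values $d_k = 64$ and 100,000 samples, and the standard normal distribution for the components, are my own assumptions for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

d_k = 64
n = 100_000  # number of (query, key) pairs to sample

q = torch.randn(n, d_k)  # each component has mean 0 and variance 1
k = torch.randn(n, d_k)

scores = (q * k).sum(dim=-1)      # one dot product per pair
scaled_scores = scores / d_k ** 0.5

mean_scores = scores.mean().item()       # expect about 0
var_scores = scores.var().item()         # expect about d_k
var_scaled = scaled_scores.var().item()  # expect about 1

print(f"mean of q.k              = {mean_scores:.3f}")
print(f"variance of q.k          = {var_scores:.2f} (d_k = {d_k})")
print(f"variance after scaling   = {var_scaled:.3f}")
```

Trying a few different values of `d_k` makes the pattern clear: the unscaled variance grows linearly with `d_k`, while the scaled one stays near 1.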

Final thoughts

In Deep Learning practice, we should keep an eye on the stability of the output distribution. Keep asking ourselves: did I change the mean and variance significantly (by a scale factor) or not? If so, can we do some simple scaling to keep its statistical properties unchanged, or at least stabilized, as the number of layers grows?

This post is licensed under CC BY 4.0 by the author.