
The paper Attention Is All You Need describes the Transformer architecture, which defines attention as a function of the queries $Q = x W^Q$, keys $K = x W^K$, and values $V = x W^V$:

$\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V = \text{softmax}\left( \frac{x W^Q (W^K)^T x^T}{\sqrt{d_k}} \right) x W^V$
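For concreteness, here is a minimal single-head NumPy sketch of that forward pass (my own illustration: no bias terms, `x` has one token per row, and the shape names `seq_len`, `d_k`, `d_v` are assumptions, not from the paper):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, W_Q, W_K, W_V):
    """Single-head scaled dot-product attention, matching the formula above."""
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V      # each row corresponds to one token
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)          # (seq_len, seq_len)
    A = softmax(scores, axis=-1)             # attention weights, each row sums to 1
    return A @ V                             # (seq_len, d_v)
```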

In the Transformer, there are 3 different flavors of attention:

  1. Self-attention in the Encoder, where the queries, keys, and values all come from the input to the Encoder.
  2. Encoder-Decoder attention in the Decoder, where the queries come from the input to the Decoder, and the keys and values come from the output of the Encoder.
  3. Masked self-attention in the Decoder, where the queries, keys, and values all come from the input to the Decoder, and, for each token, the $\text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)$ weights are masked out (zeroed out) for all tokens to the right of that token, to prevent look-ahead, which would be cheating during training (a minimal sketch of this mask follows the list).
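Here is a minimal sketch of that third flavor, assuming the same single-head, no-bias setup as the NumPy snippet above; the mask sets the scores of future positions to $-\infty$ before the softmax, which is what makes their attention weights come out as zero:

```python
import numpy as np

def masked_self_attention(x, W_Q, W_K, W_V):
    """Decoder self-attention: token i may only attend to tokens 0..i."""
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    # Future positions get -inf before the softmax, so their weights come out as 0.
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```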

What is the gradient (i.e. the partial derivatives of the loss function w.r.t. $x$, $W^Q$, $W^K$, $W^V$, and any bias term(s)) of each of these attention units? I am having a difficult time wrapping my head around deriving the gradient equations because I'm not sure how the softmax function interacts with the partial derivatives, and, for the Encoder-Decoder attention in the Decoder, I'm not clear on how to incorporate the Encoder output into the equation.

user3667125

1 Answer


I have written a blog post to answer this question; please see https://say-hello2y.github.io/2022-09-07/attention-gradient

I use matrix calculus to solve this question; here is the final result for the gradient of an attention unit:
$$ \frac{\partial f(X)}{\partial W_h^K }=\gamma K^T\mathbb{P}_h\frac{\partial f(X)}{\partial A }^TQW_h^Q $$
$$ \frac{\partial f(X)}{\partial W_h^Q }=\gamma Q^T\frac{\partial f(X)}{\partial A }\mathbb{P}_{h}^{T}KW_h^K $$
$$ \frac{\partial f(X)}{\partial W^V }=A^T\frac{\partial f(X)}{\partial X }(W^O)^T $$
$$ \frac{\partial f(X)}{\partial W^O }=\mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_H})^T\frac{\partial f(X)}{\partial X } $$
For more detail, please see https://say-hello2y.github.io/2022-09-07/attention-gradient. If you have any questions, feel free to contact me; my email is longxhe@gmail.com.
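The part the chain rule has to pass through is the row-wise softmax, whose Jacobian for a row $s = \mathrm{softmax}(z)$ is the standard $\partial s_i / \partial z_j = s_i(\delta_{ij} - s_j)$. As a sanity check on a closed-form result like the one above, one can compare it against automatic differentiation; the snippet below is a minimal sketch of that idea (using PyTorch, a single head, and an arbitrary stand-in scalar loss, not the derivation from the blog):

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_k = 4, 8, 8
x = torch.randn(seq_len, d_model)
W_Q = torch.randn(d_model, d_k, requires_grad=True)
W_K = torch.randn(d_model, d_k, requires_grad=True)
W_V = torch.randn(d_model, d_k, requires_grad=True)

def attention(x, W_Q, W_K, W_V):
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    A = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
    return A @ V

loss = attention(x, W_Q, W_K, W_V).sum()    # stand-in scalar loss
dW_Q, dW_K, dW_V = torch.autograd.grad(loss, (W_Q, W_K, W_V))
print(dW_Q.shape, dW_K.shape, dW_V.shape)   # each gradient matches its weight's shape
```

The same check works for the masked and Encoder-Decoder variants by masking the scores or feeding the Encoder output into the key/value projections before comparing against a hand-derived expression.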