
I've been reading that transformer decoders use masked self-attention so that the decoder can't cheat by looking ahead. For example, when predicting the 6th token in the sequence, we shouldn't have access to the 7th token.

However, why can't the decoder perform full self-attention over all previously predicted tokens? When predicting the 6th token, why can't the 3rd token's embedding have access to the 5th token? Wouldn't this representation offer richer context? Some explanations I have seen online state that this would violate the autoregressive nature of token generation. However, we still aren't looking at the 7th token, or anything after it, to predict the 6th token; we are only allowing the already-predicted tokens to attend to each other. Every token in a generated sequence is still only the result of everything that came before it, which sounds very autoregressive.

In a previous post, "What if we drop the causal mask in auto-regressive Transformer?", the answer mentions: "Allowing available tokens to attend to each other would violate the autoregressive property and potentially introduce information leakage from future tokens, leading to incorrect predictions."

I'm not sure what this really means or where exactly the information leakage would come from, since the 6th token would have no information about the 7th. I know that doing self-attention this way increases the complexity, but are there any actual accuracy or quality reasons why we don't do it?

Answer


It's all about speed.


During training:
You use teacher forcing and feed the entire target sequence $Y$ (say, of length $N$) to the decoder. You want the decoder to attend only to y[:i] when predicting y[i], so you mask out y[i:] and compute the embeddings in each of the transformer decoder layers.

Now, when predicting y[i+1], you are allowed to attend to y[:i+1]. What you do is reuse the embeddings from the previous step and only compute the embedding of y[i], letting it attend to everything in y[:i+1]. This means that the embedding of y[i-1], for example, is computed without attending to y[i], even though it would be allowed to.

We do this because it allows the decoding of the entire sequence to be computed much more efficiently, in a single forward pass. Assume that $Y$ has shape Y.shape = (N, D). The embeddings produced by the first self-attention layer are given by:

$$ Z = \operatorname{masked\_softmax}\left( \frac{(YQ)(YK)^\top}{\sqrt{D}} \right) YV, $$ where $Q, K, V$ are the query, key, and value weight matrices of your layer and Z.shape = (N, D). These embeddings are then forwarded to the next self-attention layer, and so on.
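A minimal NumPy sketch of the equation above, assuming a single attention head, square weight matrices $Q, K, V$ of shape (D, D), and random inputs. The names `masked_softmax` and `causal_self_attention` are illustrative, not from any library. The final loop checks the reuse property described for training: because row i of $Z$ only depends on Y[:i+1], computing $Z$ for a prefix gives exactly the first rows of $Z$ for the full sequence.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Row-wise softmax; entries where mask is False get zero weight."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

def causal_self_attention(Y, Q, K, V):
    """Z = masked_softmax((YQ)(YK)^T / sqrt(D)) YV with a lower-triangular mask."""
    N, D = Y.shape
    scores = (Y @ Q) @ (Y @ K).T / np.sqrt(D)      # (N, N) score matrix: O(N^2) entries
    mask = np.tril(np.ones((N, N), dtype=bool))     # row i may attend to columns 0..i
    return masked_softmax(scores, mask) @ (Y @ V)   # (N, D)

rng = np.random.default_rng(0)
N, D = 7, 4
Y = rng.normal(size=(N, D))
Q, K, V = (rng.normal(size=(D, D)) for _ in range(3))

Z = causal_self_attention(Y, Q, K, V)
print(Z.shape)  # (7, 4)

# Row i of Z depends only on Y[:i+1], so the embeddings of a prefix do not change
# when more elements are appended: one pass over Y yields all N rows at once.
for k in range(1, N + 1):
    assert np.allclose(causal_self_attention(Y[:k], Q, K, V), Z[:k])
```

One masked (N, N) product gives every position its single embedding, which is why training with teacher forcing needs only one pass per layer.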

Obviously, you have only one embedding for each element of the sequence, and this embedding is computed by attending only to the previous elements (and itself). The cost of computing the embeddings is $\mathcal{O}(N^2)$.

Now, if you want the 3rd element to have access to the 5th element when predicting the 6th element, then you need a different embedding of the 3rd element specifically for this computation. That is, the output $Z$ of the self-attention layer has to have shape Z.shape = (N, N, D).

So (with 0-based indices) the embedding of the 3rd element when predicting the 4th element will be $Z[2, 3]$, computed by attending to elements 0, 1, 2, and the embedding of the 3rd element when predicting the 6th element will be $Z[2, 5]$, computed by attending to elements 0, 1, 2, 3, 4. In this case, the cost of computing the embeddings would be $\mathcal{O}(N^3)$, since there are $\mathcal{O}(N^2)$ (element, prefix) pairs and each of them attends to up to $N$ elements.
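A NumPy sketch of the hypothetical (N, N, D) variant, assuming a single layer, random weights, and the 0-based indexing used above. `prefixwise_attention` is an illustrative name, not an existing API. It runs one separate, unmasked attention per prefix Y[:t], fills Z[j, t] for j < t (other entries stay NaN), checks that the last element of each prefix gets the same embedding as in the ordinary causal layer, and counts the attention-score entries each scheme computes.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(Y, Q, K, V):
    """Standard causal layer: one embedding per element, output shape (N, D)."""
    N, D = Y.shape
    scores = (Y @ Q) @ (Y @ K).T / np.sqrt(D)
    scores = np.where(np.tril(np.ones((N, N), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ (Y @ V)

def prefixwise_attention(Y, Q, K, V):
    """Hypothetical variant: Z[j, t] is the embedding of element j used when
    predicting element t, computed with *unmasked* attention over Y[:t].
    Entries with j >= t are unused and left as NaN."""
    N, D = Y.shape
    Z = np.full((N, N, D), np.nan)
    n_scores = 0
    for t in range(1, N):                          # one separate attention per prefix
        P = Y[:t]
        scores = (P @ Q) @ (P @ K).T / np.sqrt(D)  # (t, t), no causal mask
        n_scores += scores.size
        Z[:t, t] = softmax(scores) @ (P @ V)
    return Z, n_scores

rng = np.random.default_rng(0)
N, D = 7, 4
Y = rng.normal(size=(N, D))
Q, K, V = (rng.normal(size=(D, D)) for _ in range(3))

Z_causal = causal_attention(Y, Q, K, V)                # shape (N, D)
Z_prefix, n_scores = prefixwise_attention(Y, Q, K, V)  # shape (N, N, D)

# Z_prefix[2, 3]: 3rd element when predicting the 4th (attends to 0, 1, 2).
# Z_prefix[2, 5]: 3rd element when predicting the 6th (attends to 0, ..., 4).
# In general these two vectors differ, so both must be stored.

# The last element of each prefix sees the same keys in both schemes.
assert all(np.allclose(Z_prefix[t - 1, t], Z_causal[t - 1]) for t in range(1, N))

# Score entries: sum_{t=1}^{N-1} t^2 = (N-1) N (2N-1) / 6 = O(N^3), versus N^2 for one causal pass.
print(n_scores, (N - 1) * N * (2 * N - 1) // 6, N * N)  # 91 91 49
```

With more than one layer, the loop would have to run the whole stack once per prefix, because every later layer needs a different set of lower-layer embeddings for each t.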


During inference:
You actually run the decoder sequentially, in a loop. With every generated element, you forward the entire sequence through the decoder only to get the embedding of the newly generated element. Here is a snippet of a simple greedy decoding procedure that I wrote:
https://pi-tau.github.io/posts/transformer/#inference
One could argue that at every step you could actually recompute the embeddings of all elements (not only the new one), so that they attend to the newly generated element. There would probably be no computational overhead, because the entire computation is batched in a single matrix-matrix multiplication.
However, note that your model was not trained in this setting, so it is not clear whether this would improve or worsen performance.
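A toy greedy decoding loop that follows the procedure described above; it is not the code at the linked post. Everything here is an assumption for illustration: a two-layer, single-head decoder with residual connections, random untrained weights, a 10-token vocabulary, and a hypothetical start token `BOS = 0`, so the generated ids themselves are meaningless. With `causal=True` each step forwards the whole sequence and keeps only the last row of logits; with `causal=False` each step recomputes every embedding with full attention over the generated prefix, which is the setting the model was never trained in.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, D, N_LAYERS = 10, 8, 2
BOS = 0                                    # hypothetical start-of-sequence id
E = rng.normal(size=(VOCAB, D))            # token embedding table (random, untrained)
LAYERS = [tuple(rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
          for _ in range(N_LAYERS)]        # (Q, K, V) weights per layer
W_out = rng.normal(size=(D, VOCAB))        # projection to vocabulary logits

def attention(X, Q, K, V, causal):
    n, d = X.shape
    scores = (X @ Q) @ (X @ K).T / np.sqrt(d)
    if causal:                             # row i only sees columns 0..i
        scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=-1, keepdims=True)) @ (X @ V)

def decoder(tokens, causal=True):
    X = E[tokens]                          # (len(tokens), D)
    for Q, K, V in LAYERS:
        X = X + attention(X, Q, K, V, causal)  # residual connection
    return X @ W_out                       # (len(tokens), VOCAB) logits

def greedy_decode(max_len, causal=True):
    tokens = [BOS]
    for _ in range(max_len):
        logits = decoder(np.array(tokens), causal)  # forward the entire sequence...
        tokens.append(int(logits[-1].argmax()))     # ...but only use the newest position
    return tokens

print(greedy_decode(6, causal=True))   # how the model would be trained to run
print(greedy_decode(6, causal=False))  # all generated tokens attend to each other
```

Both modes cost one batched forward pass per step here. With a single layer they would pick the same tokens, since the newest position sees the whole prefix either way; the difference only appears from the second layer on, where the newest token attends to earlier embeddings that now also depend on later tokens.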

pi-tau