
My understanding is that masked self-attention is necessary during training of GPT-2, as otherwise it would be able to directly see the correct next output at each iteration. My question is whether the attention mask is necessary, or even possible, during inference. As GPT-2 will only be producing one token at a time, it doesn't make sense to mask out future tokens that haven't been inferred yet.

orome

3 Answers


Answer to Q1) If sampling the next token, do you need to apply the mask during inference?

Yes, you do! The model's ability to transfer information across positions was trained in this manner, and changing it up will have unpredictable consequences. Let me try to give an example:

Tokens: 1:sally, 2:sold, 3:seashells, 4:on, 5:the, 6:____
In the above, you are trying to predict token 6 from tokens {1:5}.

Denote by $n^{(m)}$ the set of tokens that the $n^{\text{th}}$ position's embedding has information from at the $m^{\text{th}}$ layer.

In both cases we have $n^{(0)} = \{n\}$ for all $n$. With a mask, $n^{(i)} = \{k\}_{k\leq n}$ for all $n$ and all $i \geq 1$, but without it, $n^{(i)} = \{k\}_{k \in [1:N]}$ for all $n$. This difference means the embeddings going into the final layer will differ completely, and unless we train for such an approach it will cause errors.
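
To make the difference concrete, here is a minimal sketch (my own illustration, not the model's actual code) of single-head attention over the 5 tokens above, with and without a causal mask; shapes and values are toy:

import torch

seq_len, d = 5, 8                      # 5 tokens ("sally sold seashells on the"), toy embedding size
q = torch.randn(seq_len, d)            # queries
k = torch.randn(seq_len, d)            # keys
v = torch.randn(seq_len, d)            # values

scores = q @ k.T / d ** 0.5            # (5, 5) raw attention scores

# causal mask: position n may only attend to positions k <= n
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~causal, float("-inf"))

with_mask = torch.softmax(masked_scores, dim=-1) @ v    # row n mixes only positions 1..n
without_mask = torch.softmax(scores, dim=-1) @ v        # every row mixes all 5 positions

# Every row except the last differs between the two outputs, so the embeddings
# fed to the next layer already disagree after a single attention step.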

Answer to Q2) What is the sample dimension?

It took me a couple of reads to understand what you're asking, but I think I do now. The sample at each step is drawn from a distribution whose logits are a linear function of a single embedding of dimension $d_{\text{model}}$; therefore that is our upper bound, $\dim(\text{sample}) \leq d_{\text{model}}$, which in the example you gave is 768.
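
For illustration, here is a hedged sketch of where that single embedding enters the picture (names follow the Hugging Face transformers API and assume the standard "gpt2" checkpoint): the logits the next token is sampled from are a linear map, the LM head, of one final-position embedding of size $d_{\text{model}} = 768$.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tok("sally sold seashells on the", return_tensors="pt").input_ids
with torch.no_grad():
    hidden = model.transformer(ids).last_hidden_state   # (1, seq_len, 768)
    logits = model.lm_head(hidden[:, -1, :])             # linear in a single 768-dim embedding
probs = torch.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)        # sample one next token
print(tok.decode(next_id[0]))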

mshlis

Yes, it is still needed. In fact, @mshlis has already given a clear reason in the answer to Q1. As I am new to the community and do not have enough reputation to comment directly, let me supplement his/her answer by opening my own.

Empirical Verification

First, in response to @pi-tau: yes, according to my observations the model performs noticeably worse without causal masking than with it during inference. For example, assuming you are using Python, you can comment out the relevant lines in the installed package (typically located under a path like anaconda3/lib/python3.7/site-packages/transformers/models/gpt2/modeling_gpt2.py; the exact lines, similar to those shown below, depend on your environment and transformers/torch version) to experiment without causal masking.

if not self.is_cross_attention:
    # if only "normal" attention layer implements causal mask
    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
    attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

One may also refer to the official implementation here.
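
As a hedged sketch of one way to run such an experiment (again assuming the Hugging Face transformers API and the standard "gpt2" checkpoint): generate a continuation with the stock library, then repeat with the lines above commented out and compare the outputs.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tok("sally sold seashells on the", return_tensors="pt").input_ids
# greedy decoding keeps the comparison deterministic
out = model.generate(ids, max_length=30, do_sample=False)
print(tok.decode(out[0]))
# re-run after commenting out the causal-mask lines: the continuation degrades noticeably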

Intuition

Second, regarding the confusion of @dingx: yes, there was something wrong. Let me explain with a concrete example, reusing the notation introduced in @mshlis's answer. We will look at the case where GPT-2 takes 5 tokens ("sally sold seashells on the") as input to predict the 6th token (assuming each word is one token).

Without causal masking, for each layer, the information that each embedding has can be listed as follows:

  • 1st layer: $1^{(1)} = \{1^{(0)}, 2^{(0)}, 3^{(0)}, 4^{(0)}, 5^{(0)}\}$, $2^{(1)} = \{1^{(0)}, 2^{(0)}, 3^{(0)}, 4^{(0)}, 5^{(0)}\}$, ..., $5^{(1)} = \{1^{(0)}, 2^{(0)}, 3^{(0)}, 4^{(0)}, 5^{(0)}\}$
  • 2nd layer: $1^{(2)} = \{1^{(1)}, 2^{(1)}, 3^{(1)}, 4^{(1)}, 5^{(1)}\}$, $2^{(2)} = \{1^{(1)}, 2^{(1)}, 3^{(1)}, 4^{(1)}, 5^{(1)}\}$, ..., $5^{(2)} = \{1^{(1)}, 2^{(1)}, 3^{(1)}, 4^{(1)}, 5^{(1)}\}$
  • ...
  • 6th layer: $1^{(6)} = \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\}$, $2^{(6)} = \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\}$, ..., $5^{(6)} = \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\}$
  • Predicting the 6th token uses

\begin{align*} 5^{(6)} &= \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\} \\ &= \{ \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\} \} \\ &= \dots \end{align*}

With causal masking, things are quite different:

  • 1st layer: $1^{(1)} = \{1^{(0)}\}$, $2^{(1)} = \{1^{(0)}, 2^{(0)}\}$, ..., $5^{(1)} = \{1^{(0)}, 2^{(0)}, 3^{(0)}, 4^{(0)}, 5^{(0)}\}$
  • 2nd layer: $1^{(2)} = \{1^{(1)}\}$, $2^{(2)} = \{1^{(1)}, 2^{(1)}\}$, ..., $5^{(2)} = \{1^{(1)}, 2^{(1)}, 3^{(1)}, 4^{(1)}, 5^{(1)}\}$
  • ...
  • 6th layer: $1^{(6)} = \{1^{(5)}\}$, $2^{(6)} = \{1^{(5)}, 2^{(5)}\}$, ..., $5^{(6)} = \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\}$
  • Predicting the 6th token uses

\begin{align*} 5^{(6)} &= \{1^{(5)}, 2^{(5)}, 3^{(5)}, 4^{(5)}, 5^{(5)}\} \\ &= \{ \{1^{(4)}\}, \{1^{(4)}, 2^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\} \} \\ &= \dots \end{align*}

While neither expansion is carried out to completion, one should first be clear (1) that there is indeed a difference in the prediction between the two methods, and (2) that this difference has nothing to do with whether GPT-2 can see the so-called "future tokens" that have not been generated yet (i.e., $6^{(m)}$ never appears in the equations above for any $m$).
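
For completeness, here is a small sketch (my own, purely illustrative) that computes the sets $n^{(m)}$ from the recursion above and reproduces both listings:

def receptive_sets(num_tokens=5, num_layers=6, causal=True):
    # layer 0: each position only carries its own token
    sets = [{n} for n in range(1, num_tokens + 1)]
    for _ in range(num_layers):
        if causal:
            # position n aggregates positions k <= n
            sets = [set().union(*sets[: n + 1]) for n in range(num_tokens)]
        else:
            # every position aggregates every position
            everything = set().union(*sets)
            sets = [set(everything) for _ in range(num_tokens)]
    return sets

print(receptive_sets(causal=True))   # [{1}, {1, 2}, ..., {1, 2, 3, 4, 5}]
print(receptive_sets(causal=False))  # [{1, 2, 3, 4, 5}] repeated five times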

But why is the result inferred with causal masking preferred over the one without? Because that is exactly how the shipped version of GPT-2 was pre-trained. In other words, the model is more comfortable predicting the 6th token with

$\{ \{1^{(4)}\}, \{1^{(4)}, 2^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\} \},$

instead of

$\{ \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\}, \{1^{(4)}, 2^{(4)}, 3^{(4)}, 4^{(4)}, 5^{(4)}\} \}.$

Finally, one might wonder why GPT-2 has to be trained in this way at all. It is about efficiency. For prediction in the former way, all of the $1^{(4)}$'s, $2^{(4)}$'s, $3^{(4)}$'s, and $4^{(4)}$'s are identical to what was computed in previous iterations and can therefore be cached and reused. This is not the case for the latter way: there, all of the $1^{(4)}$'s, $2^{(4)}$'s, $3^{(4)}$'s, and $4^{(4)}$'s contain information about the 5th token, which could not have been computed in previous iterations.
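
As a hedged sketch of that caching argument (again using the Hugging Face transformers API; exact keyword names may vary across versions): with the causal mask, the key/value tensors of earlier positions never change, so generation can feed only the newest token together with the cached past.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tok("sally sold seashells on the", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(ids, use_cache=True)                # run the full prompt once
    past = out.past_key_values                      # cached K/V for positions 1..5
    next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    # later steps feed only the single new token plus the cache
    out = model(next_id, past_key_values=past, use_cache=True)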

In my understanding, causal masking is indeed for hiding future tokens. However, "future tokens" do not refer narrowly to what has not been generated yet. Instead, for a particular token, they refer to all succeeding tokens, including both those that have already been generated and those that have not.


Following the above comment by @Zhifeng: I don't get why you are assuming that causal masking will make the first token not see the previous prompt tokens. How would it generate anything then?

And if it's not the prompt tokens, then those tokens have already been generated, i.e. the query will still see them; that's how a decoder-only model was trained in the first place. I don't get why you have put only one token in layer 1.