
I understand how causal masking in the self-attention layer of the decoder works and why we use it during training. What I want to ask is: should we use causal masking during inference?

Consider a machine translation task where you need to translate the sentence
["I", "am", "going", "to", "the", "cinema"]
from English into German. During inference the encoder encodes the input sentence and the decoder generates the output sentence token by token. Let's say the following has been generated so far:
["<START>", "Ich", "gehe", "ins"]
and you have to generate the next token. To do this, you forward the currently generated sequence through the decoder, and it outputs a probability distribution for the next token. The question is: do we need to use causal masking here?
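
To make the setup concrete, here is a minimal greedy-decoding sketch of what I mean; the `model.encode` / `model.decode` interface, token ids, and shapes are just assumptions for illustration, not taken from any particular library:

```python
import torch

def greedy_decode(model, src_ids, start_id, end_id, max_len=50):
    memory = model.encode(src_ids)                  # encoder output, computed once
    generated = [start_id]                          # ["<START>"]
    for _ in range(max_len):
        tgt = torch.tensor(generated).unsqueeze(0)  # shape (1, current_length)
        logits = model.decode(tgt, memory)          # (1, current_length, vocab_size)
        next_id = logits[0, -1].argmax().item()     # distribution for the *next* token
        generated.append(next_id)
        if next_id == end_id:
            break
    return generated
```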

Using the causal mask: $$ \text{mask} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 \end{pmatrix} $$ in the self-attention layer of the decoder would force each of the generated tokens to attend only to previous tokens. However, in my opinion, there is no need to use any masking here. The tokens that are already generated could simply attend to each other in order to better predict the next token.
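
For reference, this is roughly how such a lower-triangular mask is applied inside scaled dot-product self-attention (a generic sketch, not tied to any specific implementation): masked positions receive a score of $-\infty$ before the softmax, so each position only attends to itself and earlier positions.

```python
import math
import torch
import torch.nn.functional as F

def causal_self_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_k)
    seq_len, d_k = q.size(1), q.size(2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)              # (batch, seq_len, seq_len)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))            # block attention to future positions
    attn = F.softmax(scores, dim=-1)
    return attn @ v                                                # (batch, seq_len, d_k)
```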

However, the reference implementations that I have been looking at continue to use causal masking during inference. See for example:

Is there a reason for using causal masking during inference?
Any thoughts on the matter would be appreciated. If you know of any research papers that discuss this topic, or if you have seen an implementation somewhere that does not use causal masking during inference, please share a link.

pi-tau

1 Answer


I don't see why you should use a causal mask during inference.

The whole point of the mask is to "avoid cheating" when you already have the answer (during training), so that training can be parallelized instead of done step by step.
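
As a rough sketch (the `decoder` call signature here is hypothetical), the causal mask is what lets a single teacher-forced forward pass produce the loss for every position at once:

```python
import torch
import torch.nn.functional as F

def training_step(decoder, memory, tgt_ids):
    # tgt_ids: (batch, seq_len) ground-truth target tokens, including <START>
    seq_len = tgt_ids.size(1)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # One forward pass over the whole target; the mask keeps position i
    # from seeing the tokens it is supposed to predict.
    logits = decoder(tgt_ids[:, :-1], memory, mask=causal_mask[:-1, :-1])
    # Position i predicts token i+1.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           tgt_ids[:, 1:].reshape(-1))
    return loss
```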

During inference there is no concept of "cheating": you generate a token, append it to the already generated sentence, and then feed this new sequence to the decoder to generate the next token.

In my opinion, the only reasons for keeping the causal mask at inference are code reuse, so that you only have to maintain a single transformer decoder implementation, or something like padding if you want to do batched inference (though I am not sure whether that applies to the link you provided).
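
If batched inference is the reason, a padding mask would typically be combined with a causal mask along these lines (the shapes and the `pad_id` convention here are assumptions for illustration):

```python
import torch

def build_masks(tgt_ids, pad_id=0):
    # tgt_ids: (batch, seq_len) right-padded generated sequences
    batch, seq_len = tgt_ids.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (seq_len, seq_len)
    padding = (tgt_ids != pad_id)                                        # (batch, seq_len)
    # Broadcast to (batch, seq_len, seq_len): a position may attend to a key
    # only if the key is not in the future and is not a pad token.
    return causal.unsqueeze(0) & padding.unsqueeze(1)
```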

Alberto