
For a more concrete discussion, if we focus on the GPT-2 model, which is an auto-regressive model, I fully understand why we need masking for training. However, I need clarification on why we need to mask during inference. During inference, the model uses all available information to predict the next token, so we should not need masking. Now, if we do not need masking for inference, I don't understand the restriction of having a context window in LLMs. The only place the context window appears when configuring the model's architecture is in defining the triangular masking matrix of size context window × context window. But this masking matrix could be dynamically replaced with a matrix of ones for inference, because we don't need masking. Based on this argument, an LLM, aside from memory limitations, should be able to handle any context length. Please clarify where this argument falls short. In short, the reason I ask this question is twofold:

  1. I see in the public code written for GPT that, even at inference time, this masking matrix is defined, and if, say, the length of the input sequence is ten and we only need to predict the 11th token, the probability distributions for the 2nd, 3rd, ..., 11th tokens are all still calculated, but every distribution except the 11th is dropped. This process sounds so wasteful.

  2. It is not clear to me why LLMs are constrained by a context window.

F Gh

2 Answers


Masking is used during inference primarily for consistency and efficiency, not because it is strictly necessary. The transformer's forward pass is designed to process the input sequence in parallel for efficiency, and the mask ensures that, even though all tokens are computed at once, only valid (non-future) attention is applied. For example, when generating the 11th token given a 10-token sequence, the model internally computes the probability distributions for all positions up to the 11th in one go. This is computationally efficient on modern GPU/TPU hardware, where large matrix multiplications are highly optimized: computing the entire matrix at once exploits parallelism better than computing just the 11th token's distribution. While this seems wasteful because only the 11th token's output is needed, it avoids the overhead of dynamically changing the model's computation graph or implementing additional checks. It is also convenient to reuse the same code path for both training and inference, avoiding potential bugs.
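For concreteness, here is a minimal single-head PyTorch sketch (tensor names and shapes are illustrative, not taken from any particular GPT-2 codebase) of how a causal attention layer processes a 10-token prompt in one batched matrix multiplication, after which only the last position's output is kept to predict the 11th token:

```python
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    # q, k, v: (seq_len, d_head) for a single attention head
    seq_len = q.size(0)
    scores = q @ k.T / k.size(-1) ** 0.5                # (seq_len, seq_len)
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    scores = scores.masked_fill(~mask, float("-inf"))   # block attention to future positions
    return F.softmax(scores, dim=-1) @ v                # (seq_len, d_head)

# A 10-token prompt: one batched matmul produces an output row for every position...
seq_len, d_head = 10, 64
q, k, v = (torch.randn(seq_len, d_head) for _ in range(3))
out = causal_attention(q, k, v)

# ...but only the last row is needed to predict the 11th token.
next_token_repr = out[-1]
```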

The fixed context window in GPT-2 and other Transformer-based models is determined by the model architecture. Specifically, it is tied to the fixed size of the positional encodings and of the causal (triangular) masking matrix. Additionally, the memory and computational requirements of self-attention grow quadratically with sequence length, which effectively limits the practical context size on real hardware unless alternative methods such as sparse attention are used.
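As a rough illustration of the positional-encoding limit (using GPT-2 small's published configuration of 1024 positions and 768-dimensional embeddings; the variable names here are just for the sketch), the learned positional-embedding table simply has no row for positions beyond the context window:

```python
import torch
import torch.nn as nn

n_ctx, n_embd = 1024, 768          # GPT-2 small: 1024-token context, 768-dim embeddings
wpe = nn.Embedding(n_ctx, n_embd)  # learned positional-embedding table, shape (1024, 768)

positions = torch.arange(1500)     # pretend we fed a 1500-token prompt
try:
    pos_emb = wpe(positions)       # positions >= 1024 have no learned embedding row
except IndexError as err:
    print("position out of range:", err)
```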

Future models might incorporate more dynamic attention mechanisms or memory-efficient architectures to mitigate these limitations, but for now these constraints are part of how models like GPT-2 are designed and optimized.

cinch

This process sounds so wasteful.

Masking is required during inference too, because we only throw away the computed outputs for all positions except the 11th at the final layer. Since GPT-like models are stacks of decoder layers, the representations of the 1st, 2nd, ..., 10th tokens are fed into every subsequent layer up to the penultimate one. If we turn off the masking, all tokens up to the 10th will take context from all other tokens, including later ones. This might seem okay, but it degrades the model's performance, since it was trained in an auto-regressive manner.
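A small single-layer sketch of this effect (random tensors, illustrative names): dropping the causal mask leaves the last position's output unchanged within this one layer, but changes the outputs for every earlier position, and in a stacked decoder those changed representations are exactly what the next layer would receive:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_head = 10, 64
q, k, v = (torch.randn(seq_len, d_head) for _ in range(3))

scores = q @ k.T / d_head ** 0.5
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()

masked_out = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v
unmasked_out = F.softmax(scores, dim=-1) @ v

# The last row is identical either way (it may attend to everything),
# but every earlier row changes once the mask is removed.
print(torch.allclose(masked_out[-1], unmasked_out[-1]))    # True
print(torch.allclose(masked_out[:-1], unmasked_out[:-1]))  # False
```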

Also, turning off the masking would make it impossible to store the KV cache of the intermediate layers, since the cache would change every time a new token is generated: all earlier tokens would now take the new token into their context.
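Here is a rough sketch of per-layer KV caching during decoding (shapes and names are illustrative); it works precisely because, under causal masking, the cached keys and values of earlier tokens never need to be recomputed when a new token arrives:

```python
import torch
import torch.nn.functional as F

d_head = 64
k_cache = torch.randn(10, d_head)   # keys for tokens 1..10, computed on earlier steps
v_cache = torch.randn(10, d_head)   # values for tokens 1..10

def decode_step(q_new, k_new, v_new, k_cache, v_cache):
    # Append the new token's key/value; the old entries stay valid because,
    # under causal masking, earlier positions never attend to the new token.
    k_cache = torch.cat([k_cache, k_new[None]], dim=0)
    v_cache = torch.cat([v_cache, v_new[None]], dim=0)
    # The new (last) position may attend to the whole cache, which is exactly
    # what the causal mask allows, so no explicit mask is needed for this row.
    scores = q_new @ k_cache.T / d_head ** 0.5
    out = F.softmax(scores, dim=-1) @ v_cache
    return out, k_cache, v_cache

q11, k11, v11 = (torch.randn(d_head) for _ in range(3))
out11, k_cache, v_cache = decode_step(q11, k11, v11, k_cache, v_cache)
```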

It is not clear to me why LLMs are constrained by a context window.

LLMs are restricted by a context window because they are trained to reason over a limited input length. One can use a prompt longer than the context window, but performance degrades severely, since LLMs do not generalize to prompts longer than those they were trained on. And the challenge with training at larger context lengths is the limited training data and the memory requirements.

basujindal