As far as I understand, a Transformer's time complexity grows quadratically with the sequence length. As a result, to keep training feasible, a maximum sequence length is set, and to allow batching, all shorter sequences are padded.

However, once a Transformer is trained and we want to run it on a single sequence at inference time, the computational cost is far lower than during training. Thus, it seems reasonable to want to run the Transformer on a longer input sequence at inference time. From a technical perspective, this should be feasible.

I keep reading online that a Transformer cannot be run on a sequence longer than those seen during training. Why is this? Is it because the network weights will be unfamiliar with sequences of this length? Or is it more fundamental?

hanugm
chessprogrammer

3 Answers

4

Transformer models have limited sequence length at inference time because of positional embeddings. But there are workarounds.

Self-attention in a transformer does not distinguish the order of keys/values; it works as if the sequence were a bag of words.
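A minimal sketch of the bag-of-words behaviour, assuming plain scaled dot-product attention for a single query and no positional information. The vectors and the helper name `attend` are made up for illustration:

```python
import math

def attend(query, keys, values):
    """Scaled dot-product attention for one query vector (pure Python)."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]

# Illustrative toy data, not taken from any real model.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
query = [0.5, -0.25]

original = attend(query, keys, values)

# Reorder the key/value pairs together: same "bag", different order.
order = [2, 0, 1]
shuffled = attend(query, [keys[i] for i in order], [values[i] for i in order])

# The two outputs agree up to floating-point rounding.
same = all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(original, shuffled))
print(same)
```

Without some positional signal, the output for a query is the same whichever order the keys and values arrive in.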

So to expose the sequence order to the model, one typically adds an extra "positional embedding" vector to each input token embedding. This extra positional information then allows the model to construct primitives like "an attention head that looks at the previous token".

There are several variants of how to do this exactly.

One way is to use an extra trainable parameter matrix [L,M] where L is maximum input length and M is model dimension. After training a model this way, it simply becomes impossible to embed tokens with position > L. So we get a hard limit for which there is no workaround.
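To illustrate the hard limit, here is a small sketch that treats the trained [L, M] matrix as a plain lookup table. The sizes, the random values standing in for trained weights, and the function name `embed_position` are all assumptions for illustration:

```python
import random

MAX_LEN = 4    # L: maximum input length used in training (illustrative value)
MODEL_DIM = 3  # M: model dimension (illustrative value)

random.seed(0)
# Stand-in for a trained [L, M] parameter matrix: one row per position.
position_table = [
    [random.gauss(0.0, 1.0) for _ in range(MODEL_DIM)] for _ in range(MAX_LEN)
]

def embed_position(pos):
    """Look up the learned vector for a position; there is no row past the table."""
    if pos < 0 or pos >= MAX_LEN:
        raise IndexError(f"position {pos} has no trained row (table has {MAX_LEN})")
    return position_table[pos]

print(len(embed_position(MAX_LEN - 1)))  # the last trained position works

try:
    embed_position(MAX_LEN)  # one past the end: nothing was ever trained here
except IndexError as err:
    print("lookup failed:", err)
```

Any row added after training would hold untrained values, which is why the limit is hard in this variant.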

Another way is to use a non-trainable matrix [L,M] initialized with a specially crafted set of sinusoids (the original "Attention Is All You Need" paper does this). A model trained this way would not have a hard length limit, because this matrix of sinusoids can be extended to arbitrary size.
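As a rough sketch of why the sinusoid matrix can be extended, the function below computes one row on demand using the sine/cosine formula from the paper. The function name and `TRAIN_LEN` are illustrative, and an even model dimension is assumed:

```python
import math

def sinusoidal_encoding(pos, model_dim):
    """Positional vector for one position (sin on even dims, cos on odd dims)."""
    vec = []
    for i in range(0, model_dim, 2):
        # i is the even dimension index, so the exponent is i / model_dim.
        angle = pos / (10000 ** (i / model_dim))
        vec.append(math.sin(angle))
        vec.append(math.cos(angle))
    return vec[:model_dim]

TRAIN_LEN = 512  # illustrative training length, not from the answer

# A position far beyond the training length is computed the same way
# as any other position; no stored table is needed.
far = sinusoidal_encoding(10 * TRAIN_LEN, 8)
print(len(far), all(-1.0 <= x <= 1.0 for x in far))
```

This only shows that the encoding is defined for every position. It says nothing about how well a trained model handles positions it never saw.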

Yet another way is to use "relative positional encoding". With this, for each pair of tokens you take the relative position, which would be in the range [-L ... +L], and embed that instead. Then you inject this vector into the attention layer in the right way (section 3.3 in https://arxiv.org/pdf/1901.02860.pdf).

Now you still kind of have a limit of [-L, L]. But you can always "clip" the relative position to this range, and pretend that all pairs of tokens at distance >= L have the same relative position ("very far from each other"). This allows running inference with longer inputs.
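The clipping trick can be sketched in a few lines. The function names and the value of `L` below are illustrative, and the embedding table itself is left out:

```python
def clipped_relative_position(query_pos, key_pos, max_distance):
    """Offset key_pos - query_pos, clipped to [-max_distance, max_distance]."""
    offset = key_pos - query_pos
    return max(-max_distance, min(max_distance, offset))

def embedding_index(query_pos, key_pos, max_distance):
    """Shift the clipped offset into a row index of a (2 * max_distance + 1)-row table."""
    return clipped_relative_position(query_pos, key_pos, max_distance) + max_distance

L = 4  # illustrative maximum relative distance seen in training

# Offsets from query position 10 to several key positions.
print([clipped_relative_position(10, k, L) for k in (0, 7, 10, 12, 100)])
# → [-4, -3, 0, 2, 4]
```

Every pair further apart than L shares the row for -L or +L, so the table never needs more than 2L + 1 rows, however long the input is.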

1

To some extent, this is true. The position-wise feedforward layers can be added or removed to fit the sequence length. The matrix operations can similarly be scaled to fit the sequence length.

However, the computational cost comes from the matrix operations in the attention layer. Those are not trained: the scaled dot-product attention itself has no trained parameters (see Figure 2 and equation (1) in Vaswani et al.). So those operations have to be computed during inference as well.

Another challenge would be the output layer. That layer is a regular feedforward layer and thus has a fixed input size; that is, you cannot add new parameters during inference.

Of course, there is a caveat to this: there are now transformers that allow recurrence, such as Transformer-XL and Memformer. These do, in a way, allow longer input sequences than the "max sequence length".

Avatrin
1

If you use something like Rotary Positional Embedding, as the Llama 3 family of models does, then technically speaking you can run inference on arbitrarily long sequences (assuming you have enough computational resources). The transformer itself does not have any inherent limitation.
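A simplified sketch of the rotary idea, not the implementation of any particular model: each pair of dimensions is rotated by an angle proportional to the token position, and the resulting query/key score depends only on the offset between the two positions. The helper names, the `base` value, and the vectors are assumptions for illustration:

```python
import math

def rotate(vec, pos, base=10000.0):
    """Rotate consecutive pairs of vec by angles proportional to pos (RoPE-style)."""
    dim = len(vec)  # assumed even
    out = []
    for i in range(0, dim, 2):
        theta = pos / (base ** (i / dim))
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def score(q, q_pos, k, k_pos):
    """Dot product of a rotated query and a rotated key."""
    return sum(a * b for a, b in zip(rotate(q, q_pos), rotate(k, k_pos)))

# Illustrative toy vectors.
q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

near = score(q, 5, k, 2)        # offset of 3, early in the sequence
far = score(q, 2005, k, 2002)   # same offset of 3, much later positions

print(math.isclose(near, far, abs_tol=1e-9))
```

Nothing in this computation refers to a maximum length, so the mechanism can be applied at any position.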

The real question is more like: will it work well if no such sequence has ever been seen in training? Will it understand how to interpret a token at -2000 from the current position if, in training, it has seen tokens at most -1000 from the current position?

I tried to override the max_length configuration in llama3-7b, just for fun, and got this warning: "This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all."

It is common to use shorter sequences at inference, but not longer ones.

Peter Franek