Let's focus on the case of decoder-only transformers, where I am using Algorithm 10 from "Formal Algorithms for Transformers" by Mary Phuong and Marcus Hutter as a reference.
https://i.sstatic.net/ZWC9o.png
Previously I thought that the maximum context length was very much built into the transformer, for example as the dimension of some weight matrix. After studying this algorithm I am surprised, because it seems more like an artificial restriction! Since long contexts are a topic of active research, I would like to know whether I am misunderstanding something.
The way I see it, if I had access to the weights of GPT-2 right now, I could almost execute it on any number of tokens I like right away (given sufficient memory). The MHA algorithm is simply carried out over a longer sequence, and nothing in the computation itself depends on a maximum length (see the sketch after the list below). There are only two issues, which are the points where the context window $l_{max}$ appears:
- The positional encoding has only $l_{max}$ positions
- During training the weights were never optimized to attend over more than $l_{max}$ tokens.
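To make this concrete, here is a minimal NumPy sketch of masked single-head self-attention (my own illustration, not the paper's pseudocode). The only places the sequence length enters are the shape of the input and the causal mask, both of which are determined on the fly:

```python
import numpy as np

def causal_self_attention(X, W_q, W_k, W_v):
    """Single-head masked self-attention over a sequence of arbitrary length.

    X: (l, d_model) token representations; l can be anything.
    W_q, W_k, W_v: (d_model, d_head) projection matrices.
    Nothing here refers to a maximum context length l_max.
    """
    l, _ = X.shape
    Q, K, V = X @ W_q, X @ W_k, X @ W_v               # (l, d_head) each
    scores = Q @ K.T / np.sqrt(K.shape[-1])           # (l, l) attention logits
    mask = np.triu(np.ones((l, l), dtype=bool), k=1)  # causal mask, built from l
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                                # (l, d_head)

# Runs the same for l = 10 or l = 1000 (memory permitting):
d_model, d_head = 16, 8
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
print(causal_self_attention(rng.normal(size=(10, d_model)), W_q, W_k, W_v).shape)
print(causal_self_attention(rng.normal(size=(1000, d_model)), W_q, W_k, W_v).shape)
```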
But these issues seem rather easy to resolve:
Use some positional encoding which has infinitely many positions. The first encoding vectors are nicely spread out while the later ones lie closer to each other, due to the nature of fitting an infinite sequence of vectors into more or less a unit ball/sphere. But this is not an issue: it is natural for the positional encoding to become vaguer the further a token lies in the past.
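One concrete example of an encoding that is defined for arbitrarily many positions is the sinusoidal encoding from the original transformer paper; it does not compress distant positions in quite the way I describe above, but it shows that nothing forces a finite table of $l_{max}$ rows (sketch below, my own code):

```python
import numpy as np

def sinusoidal_positional_encoding(positions, d_model):
    """Sinusoidal encoding from 'Attention Is All You Need'.

    Defined for any integer position, so there is no finite table of
    l_max rows baked into the model.
    positions: 1-D array of positions (can go far beyond training lengths).
    Returns: (len(positions), d_model) array.
    """
    positions = np.asarray(positions, dtype=np.float64)[:, None]  # (l, 1)
    dims = np.arange(0, d_model, 2, dtype=np.float64)[None, :]    # (1, d_model/2)
    angles = positions / np.power(10000.0, dims / d_model)        # (l, d_model/2)
    pe = np.empty((positions.shape[0], d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

# Positions far past any training-time l_max are still well-defined:
print(sinusoidal_positional_encoding([0, 1023, 10**6], d_model=8))
```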
Train 50% of the time on context lengths around $l_{max}$, 25% on context lengths around $2 l_{max}$, 12.5% on context lengths around $4 l_{max}$, and so on.
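As a sketch of what I mean by this schedule (the cap on doublings and the exact distribution are arbitrary choices of mine):

```python
import numpy as np

def sample_training_length(l_max, rng, max_doublings=6):
    """Sample a context length for one training batch under the proposed
    schedule: ~50% at l_max, ~25% at 2*l_max, ~12.5% at 4*l_max, ...
    """
    k = min(rng.geometric(p=0.5) - 1, max_doublings)  # k = 0 w.p. 1/2, 1 w.p. 1/4, ...
    return l_max * 2**k

rng = np.random.default_rng(0)
lengths = [sample_training_length(1024, rng) for _ in range(10_000)]
print({l: lengths.count(l) for l in sorted(set(lengths))})
```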
I can imagine the following issues appearing:
A) Memory becomes larger than what is available on a single "unit" (a GPU?), so you have to start moving data back and forth to execute your transformer. That is terribly inefficient during training and also at inference, so training on such large context windows becomes pointless in practice (see the rough calculation after this list).
B) Perhaps the transformer just doesn't learn well with this procedure for some reason.
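To put a number on issue A: naively materializing the $l \times l$ attention matrix scales quadratically in the sequence length. A back-of-the-envelope calculation (my own, assuming fp16 entries and no memory-saving attention tricks):

```python
# Rough cost of materializing the (l x l) attention matrix naively,
# per head and per layer, at fp16 (2 bytes per entry).
def attention_matrix_bytes(l, bytes_per_entry=2):
    return l * l * bytes_per_entry

for l in (1_024, 32_768, 1_000_000):
    gib = attention_matrix_bytes(l) / 2**30
    print(f"l = {l:>9,}: {gib:,.3f} GiB per head per layer")
# l = 1,000,000 gives roughly 1.8 TiB for a single head of a single layer,
# which is why long-context work avoids materializing the full matrix.
```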
These still seem like rather "soft" issues, though. As far as I can tell, I could, in theory, use the architecture of GPT-2 (with a modified positional encoding) to create LLMs with a context window of 1,000,000 tokens. So, am I missing something?
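For concreteness, here is where $l_{max}$ actually shows up in GPT-2's weights, as far as I can tell: a learned positional-embedding table with $l_{max}$ rows. This sketch assumes the Hugging Face `transformers` library and that I recall the `wte`/`wpe` attribute names correctly:

```python
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")
print(model.wte.weight.shape)  # token embeddings:    (vocab_size, d_model)
print(model.wpe.weight.shape)  # position embeddings: (n_positions = 1024, d_model)
# Everything else (attention, MLPs, layer norms) is length-agnostic; only this
# 1024-row table, plus the training distribution, ties the model to l_max.
```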
Thank you!
