
Why do LLMs learn so fast during inference but, ironically, so slowly during training? That is, if you teach an AI a new concept in a prompt, it will learn and use the concept perfectly and flawlessly throughout the whole prompt, after just one shot. Yet if you train it on just a single sample, that sample will not influence its behavior at all; it will essentially forget. Why can't RNNs use whatever is happening during inference, rather than gradient descent, to update their weights and thus learn? In other words, can't the attention mechanism itself be used to update weights, rather than some cost function?

MaiaVictor

4 Answers


There is a huge difference between what happens with the information during training and during inference, and one cannot be used for the other.

Let me start with an analogy to the human brain (it is not a very good analogy and has several flaws, but it gives an intuition that I will build on later):
If I tell you "A palindrome is a number, a word or a sentence that reads the same forward and backward. One example palindrome is 15651" then you will now know what a palindrome is and you will be able to work with this new concept. If I tell the same to a newborn baby, it will not. It takes years to bring a baby to the point that it is able to understand my previous sentence.

Enough of the analogy. Let's have a look at the RNNs:
RNNs are neural networks with weights. Unlike some other networks, they have something that you can call an internal state. Weights and internal state are different:

  • The internal state serves as a memory that stores information about previously processed input, e.g. a new concept that was explained earlier.
  • The weights define how new information changes the internal state and how input and internal state produce some output.
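To make the distinction concrete, here is a minimal sketch of a scalar Elman-style RNN step in Python. The weight values `W_X` and `W_H` are hypothetical stand-ins for trained weights; the point is that processing a sequence only changes the state `h`, never the weights.

```python
import math

# Hypothetical "trained" weights. They stay fixed while the network runs.
W_X = 1.0  # how strongly the current input affects the state
W_H = 0.5  # how much of the previous state is carried over

def rnn_step(h: float, x: float) -> float:
    """One step of a scalar Elman-style RNN: new state from old state and input."""
    return math.tanh(W_H * h + W_X * x)

def run(inputs: list[float]) -> float:
    """Process a sequence and return the final internal state."""
    h = 0.0  # the internal state starts empty for every new sequence
    for x in inputs:
        h = rnn_step(h, x)  # only h changes; W_X and W_H are only read
    return h

print(run([1.0, -0.5, 0.25]))  # final state after three inputs
```

A real RNN uses weight matrices and state vectors instead of scalars, but the split is the same: the state carries the "memory" of the current sequence, the weights encode how to update it.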

Untrained neural networks typically have randomly initialized weights. New input will then cause more or less arbitrary output and updates of the internal state. So if you give a new concept to an untrained neural network as input, it will not learn the new concept; it will just update the internal state to meaningless numbers and produce random gibberish as output.

If you train the model, the weights are updated to serve a given purpose. This can be a relatively simple model, e.g. one that detects whether a tweet is positive or negative. In this case, the network would only be trained on tweets, and the internal state would only represent the positivity of the previous words and maybe whether the last word was "not", to distinguish "I am happy" from "I am not happy". In practice it would probably be much more detailed and not so easy to interpret, but something like this.
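A hand-written sketch of the kind of state such a tweet classifier might carry: the state is a pair (running positivity score, was the previous word "not"?). The word lists and the update rule are hypothetical; a trained RNN would learn something like this implicitly in its weights, and far less interpretably.

```python
# Hypothetical word lists standing in for what a trained network would learn.
POSITIVE = {"happy", "good", "great"}
NEGATIVE = {"sad", "bad", "awful"}

State = tuple[int, bool]  # (positivity score so far, previous word was "not")

def update(state: State, word: str) -> State:
    """Hand-coded stand-in for a trained RNN's state update."""
    score, negate = state
    polarity = 1 if word in POSITIVE else -1 if word in NEGATIVE else 0
    if negate:
        polarity = -polarity  # "not happy" counts as negative
    return score + polarity, word == "not"

def sentiment(tweet: str) -> str:
    state: State = (0, False)
    for word in tweet.lower().split():
        state = update(state, word)
    score = state[0]
    return "positive" if score > 0 else "negative" if score < 0 else "neutral"

print(sentiment("I am happy"))      # positive
print(sentiment("I am not happy"))  # negative
```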

But if you build an LLM, you will train on much more heterogeneous data and for tasks that involve understanding new concepts. In this case, the weights of the model will be learned in a way that allows the network to process new concepts and store the essence of a concept in the internal state.

In short: teaching the network new concepts (which is an update of the internal state) can only happen because a long prior training of the weights enables the LLM to do so.

Avoiding backpropagation: there is some recent work exploring new ways of training that avoid the costly backpropagation and instead perform two forward passes, namely the Forward-Forward algorithm, but to my knowledge it is not used for LLMs yet. And even if it were, one would still need to train on a huge amount of data to learn weights that allow the network to process new concepts as input.
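For the curious, a rough single-layer sketch of the Forward-Forward idea (Hinton, 2022) in Python: each layer is trained with its own local objective, a "goodness" score (sum of squared activations) that should be high for real (positive) data and low for negative data, so no gradient is propagated backward through the network. The threshold, learning rate and the two hand-picked input vectors are hypothetical, and details of the original algorithm (input normalization between layers, how negative data is generated) are omitted.

```python
import math
import random

THETA = 2.0  # goodness threshold (hypothetical value)

def forward(W: list[list[float]], x: list[float]) -> list[float]:
    """A single ReLU layer: y_j = max(0, sum_i W[j][i] * x[i])."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W]

def goodness(y: list[float]) -> float:
    """Goodness of a layer's activity: the sum of squared activations."""
    return sum(v * v for v in y)

def ff_update(W: list[list[float]], x: list[float], positive: bool, lr: float = 0.01) -> None:
    """One local Forward-Forward step for one layer, applied in place.

    Pushes goodness above THETA for positive data and below it for negative
    data, using only this layer's own forward pass (no backward pass
    through other layers).
    """
    y = forward(W, x)
    p = 1.0 / (1.0 + math.exp(-(goodness(y) - THETA)))  # P(x is positive)
    # Derivative of the logistic loss w.r.t. goodness: -(1 - p) or +p.
    dloss_dg = -(1.0 - p) if positive else p
    for j, row in enumerate(W):
        for i in range(len(row)):
            # d goodness / d W[j][i] = 2 * y_j * x_i (zero where the ReLU is off)
            row[i] -= lr * dloss_dg * 2.0 * y[j] * x[i]

random.seed(0)
W = [[random.uniform(0.0, 0.5) for _ in range(4)] for _ in range(3)]
pos = [1.0, 0.0, 1.0, 0.0]  # hypothetical "real" sample
neg = [0.0, 1.0, 0.0, 1.0]  # hypothetical "negative" sample
for _ in range(200):
    ff_update(W, pos, positive=True)
    ff_update(W, neg, positive=False)
print(goodness(forward(W, pos)), goodness(forward(W, neg)))
```

Note that this still needs many repeated updates to move the weights, which is the point made above: avoiding backpropagation does not remove the need for lots of training.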

Laurel
Broele

They are not "learning" during inference at all.

Learning is the process of updating the weights of the model (to lower loss). This does not happen during inference. The model weights stay the same.

When you are "teaching an AI a new concept", you are just giving it some context, which improves its ability to predict what comes next. The weights do not get updated, and updating them is the computationally expensive part.
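A toy illustration of the difference, assuming a one-weight linear model rather than an LLM: a forward pass only reads the weight, while a training step changes it, and with a typically small learning rate (here a hypothetical `1e-4`) a single sample barely moves the model's behavior.

```python
def predict(w: float, x: float) -> float:
    """Inference: a forward pass reads the weight but never changes it."""
    return w * x

def sgd_step(w: float, x: float, y: float, lr: float) -> float:
    """Training: one gradient-descent step on the squared error (w*x - y)**2."""
    grad = 2.0 * (w * x - y) * x
    return w - lr * grad

w = 0.5
# Any number of forward passes leave w untouched.
outputs = [predict(w, 2.0) for _ in range(1000)]

# One training sample (target 4.0) with a small learning rate moves w only slightly.
w_new = sgd_step(w, x=2.0, y=4.0, lr=1e-4)
print(predict(w, 2.0), predict(w_new, 2.0))  # still far from the target 4.0
```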

In a meta context, I guess you can call this "learning", and this is effectively what Microsoft is doing in its new venture with GPT: letting it search for things on the fly. Probably a lot of interesting research using such techniques will surface soon :)

shatz

As others have pointed out, what you call "learning" at inference is nothing more than providing more context. The model can indeed memorize things in its short-term memory, but only for the task at hand. You suggest that we could make a model with an infinite contextual memory, but then it would mix all tasks together. It would literally be as if you had to recite all the numbers you had ever calculated, counted or seen before starting a new calculation.

Hence, contextualization is only useful for short-term tasks, and it only works thanks to the slow learning phase you have to go through the first time around, which is more formally called the "convergence process".

So what you are looking for is in fact a way to make the convergence process faster, and more precisely one-shot or zero-shot learning. If you look beyond LLMs (Large Language Models) and RNNs (Recurrent Neural Networks), there are many other AI models that can do one-shot or even zero-shot learning, such as memory models like the Gripon-Berrou neural network. One-shot learning models can learn the first time they see an example and generalize from it. Zero-shot learning models can even learn without being presented any examples, by generalizing from others or by transferring knowledge from another field.
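A heavily simplified sketch of the clique-based idea behind Gripon-Berrou networks, assuming the basic version: neurons are grouped into clusters, a message selects one neuron per cluster, and storing it simply connects those neurons into a clique with binary (0/1) connections. Storage is a single pass over one example, i.e. one-shot. Retrieving an erased symbol picks, in that cluster, the neuron with the most connections to the known neurons. The class and method names are hypothetical, and the iterative retrieval rules of the original papers are omitted.

```python
from itertools import combinations
from typing import Optional

class CliqueMemory:
    """Simplified Gripon-Berrou-style associative memory (illustrative sketch)."""

    def __init__(self, clusters: int, neurons_per_cluster: int) -> None:
        self.clusters = clusters
        self.neurons_per_cluster = neurons_per_cluster
        # Binary weights: an undirected edge between two (cluster, neuron) nodes exists or not.
        self.edges: set[frozenset[tuple[int, int]]] = set()

    def store(self, message: tuple[int, ...]) -> None:
        """One-shot storage: connect the chosen neuron of every cluster into a clique."""
        nodes = list(enumerate(message))  # (cluster, neuron) pairs
        for a, b in combinations(nodes, 2):
            self.edges.add(frozenset((a, b)))

    def recall(self, partial: list[Optional[int]]) -> list[Optional[int]]:
        """Fill erased symbols (None) with the best-connected neuron in each erased cluster."""
        known = [(c, n) for c, n in enumerate(partial) if n is not None]
        result = list(partial)
        for c, n in enumerate(partial):
            if n is None:
                scores = [
                    sum(frozenset(((c, j), k)) in self.edges for k in known)
                    for j in range(self.neurons_per_cluster)
                ]
                result[c] = max(range(self.neurons_per_cluster), key=scores.__getitem__)
        return result

mem = CliqueMemory(clusters=4, neurons_per_cluster=4)
mem.store((0, 1, 2, 3))  # each message is seen exactly once
mem.store((3, 2, 1, 0))
print(mem.recall([0, 1, None, 3]))  # [0, 1, 2, 3]
```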

For example, Text2Video-Zero is a recently published text-to-video generator that did NOT learn from any video, but instead reuses the weights of Stable Diffusion, which was trained on still images. What this algorithm does is cleverly generalize what was learned from still images into a coherent sequence of images with the same style, hence mimicking motion, at no additional cost. Of course, it's not completely zero-shot, because it has to be provided with Stable Diffusion's weights first, but essentially zero-shot learning means that you can reuse a model that was made for one purpose for another purpose, for free (i.e., you can run inference directly, with no need to re-learn anything).

Technically, one-shot and zero-shot learning typically require another kind of architecture, more brain-like (i.e., with discrete 0/1 synaptic weights). Long convergence processes are usually required by networks using floating-point weights (i.e., artificial neurons in the McCulloch-Pitts tradition). Floating-point weights are not at all biologically plausible: they are a mathematical abstraction that condenses several functions of biological neural networks into fewer abstractions that are more amenable to programming.

Likewise, convolution layers in CNNs (convolutional neural networks) are another abstraction of how biological systems integrate big populations of neurons, but here we can use a much smaller population of artificial neurons and more optimized instruction sets to do the same work as the brain. Keep in mind that for many purposes in AI, current computers are much less efficient than the human brain, which is why all these highly synthetic reproductions, optimized for the machine but very remote from how real biological systems work, are necessary. Here, long convergence (i.e., long learning) is an unavoidable artifact of how we model our artificial neurons and synapses: with floating-point numbers instead of discrete (binary) values, and with mathematical functions for integration instead of analog biological integration (which is both more fine-grained and simpler than numerical functions; see for example the Veritasium videos about analog computers, as biological systems have similar properties and advantages).

RNNs are, in a way, the opposite approach and problem: they use a more biologically plausible property, recursivity, but we have a hard time devising algorithms that train recursive networks efficiently. So here, it's the opposite of what can be observed with CNNs and LLMs: the long convergence is due to current science providing only inefficient learning algorithms when recursivity is involved. The last few years saw tremendous progress on this, with very clever algorithms, but it's still very far from how neatly biological systems manage recursivity.

All that is to say, to answer your question directly: the current LLM and RNN models can't learn in zero/one-shot from the get-go because nobody has found a way to mathematically formulate such a model. Maybe someone will in the near future, maybe it will take decades, but for now it is the slowly converging LLM and RNN models that work, and they are the ones that provide you with hyped tools such as ChatGPT.

Personally, I think we won't get there until we find out how analog biological neural systems work and then develop new computer technologies to mimic them. There is already a lot of work in this direction, such as reprogramming biological neurons via RNA signalling or mixing them with silicon neurons, but it's still far from the "real deal". There are at least hundreds of different types of neurons, and many other neural cell types whose functions are not completely understood. We are far from fully understanding biological neural systems, but progress is continuous and steady.

Disclaimer: I am both an AI researcher and a clinical neuroscientist and I studied some computational neuroscience.


EDIT: A small update to extend my explanation above for the technically and philosophically inclined: at its most fundamental level, learning can be defined as the ability of a system to modify its structure to reflect some input signal, and memory as the system itself that modifies its structure according to input signals. In biological systems, there are two types of memory: short-term and long-term. Artificial recurrent neural network models try to mimic this, notably the very famous LSTM model (Long Short-Term Memory), itself a precursor of the GPT models. By convention, in machine learning we call the tweaking of the weights, i.e., of the long-term memory, "learning". But there is indeed also a short-term memory with its own state, and AI researchers don't call updating it learning, although technically it is learning by all standards; the only differences are the exact method used and how long the memory is retained.

And just as there are models that modify (learn) their short-term memory at inference but not their long-term memory, there are models that tweak their long-term memory at inference, notably Bayesian models, as often used in weather forecasting.
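As a minimal example of a Bayesian model whose parameters change with every observation at inference time, here is a conjugate Beta-Bernoulli update in Python, e.g. for estimating the probability of rain from daily yes/no observations. This textbook model is an illustrative choice, not what operational weather forecasting systems actually use.

```python
from dataclasses import dataclass

@dataclass
class BetaBernoulli:
    """Posterior Beta(alpha, beta) over the probability of an event."""
    alpha: float = 1.0  # prior pseudo-count of "event happened" (uniform prior)
    beta: float = 1.0   # prior pseudo-count of "event did not happen"

    def update(self, happened: bool) -> None:
        """Online Bayesian update: each observation shifts the posterior immediately."""
        if happened:
            self.alpha += 1.0
        else:
            self.beta += 1.0

    @property
    def mean(self) -> float:
        """Posterior mean estimate of the event probability."""
        return self.alpha / (self.alpha + self.beta)

model = BetaBernoulli()
for rained in [True, True, False]:
    model.update(rained)
print(model.mean)  # 0.6, i.e. the mean of Beta(3, 2)
```

Here there is no separate training phase: the same update runs on every new observation, so the model's long-term parameters keep adapting while it is being used.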

So LLMs and RNNs learn fast during inference because they are designed to learn only in their short-term memory, with the large set of long-term-memory weights learned beforehand. But future improvements in the technology may very well make it possible to design networks that also learn their long-term memory "online", in real time, in a stochastic manner with a guarantee of convergence.

gaborous

It's kind of like short-term memory versus long-term memory. Giving a language model a small amount of information at inference time allows it to use that information, and so you might say that the model has "learned" that information, but this "learning" isn't really useful in the long term.

For RNNs, the problem is that the state vector only contains a limited amount of information. You can tell an RNN something once, but as you give it more information, it will forget what you told it previously. So if you have a large amount of information that you want your RNN to be able to access, then providing that information as input during inference won't do the trick; you need to train it.
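A deliberately simplified illustration, assuming a scalar linear RNN with a fixed decay factor (real RNNs have vector states and, in LSTMs, learned gates that can hold information much longer): whatever the state stored about the first input is progressively overwritten by later inputs, so its influence shrinks geometrically.

```python
def final_state(inputs: list[float], decay: float = 0.5) -> float:
    """Scalar linear RNN: h_t = decay * h_{t-1} + x_t, starting from h_0 = 0."""
    h = 0.0
    for x in inputs:
        h = decay * h + x
    return h

def influence_of_first(n_later: int, later_input: float = 0.3) -> float:
    """How much the final state differs between a first input of 1.0 and 0.0,
    after n_later further inputs."""
    rest = [later_input] * n_later
    return final_state([1.0] + rest) - final_state([0.0] + rest)

for n in (0, 5, 20):
    print(n, influence_of_first(n))  # shrinks like 0.5 ** n
```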

For transformers, the problem is that the amount of time it takes the model to process a token of input is proportional to the number of tokens it's already processed. If you have just a small amount of information that you want the transformer to learn, that's not a problem, but if you try to give a transformer a very large amount of information as input, that will make inference very slow.
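The arithmetic behind this: if processing token t costs time proportional to t (it attends to itself and every earlier token), the total for n tokens is 1 + 2 + ... + n = n(n + 1)/2, so doubling the input roughly quadruples the attention work. A small Python helper, counting only query-key dot products for one attention head in one layer and ignoring everything else a transformer computes:

```python
def attention_dot_products(n_tokens: int) -> int:
    """Query-key dot products in causal self-attention (one head, one layer):
    token t attends to itself and the t - 1 tokens before it."""
    return sum(t for t in range(1, n_tokens + 1))  # equals n * (n + 1) // 2

for n in (1_000, 2_000, 4_000):
    print(n, attention_dot_products(n))  # roughly 4x more work per doubling
```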

Note that language models are sometimes permanently "taught" things by means of input instead of training. For example, it's been reported that ChatGPT and Bing Chat have a hard-coded prompt that's always present at the beginning of the input, and which contains some information about what the developers want the model to do.
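A minimal sketch of that pattern, assuming a plain-text prompt format: the same fixed instructions are prepended to every request, so the model "knows" them without any weight update. `SYSTEM_PROMPT` and `build_model_input` are hypothetical placeholders, not the actual ChatGPT or Bing Chat prompt or API.

```python
# Hypothetical placeholder text, not any product's real hidden prompt.
SYSTEM_PROMPT = "You are a helpful assistant. Answer politely and concisely."

def build_model_input(conversation: list[str], system_prompt: str = SYSTEM_PROMPT) -> str:
    """Prepend the fixed developer prompt to the conversation on every call.
    The model's weights are untouched; the 'teaching' lives in the input."""
    return "\n".join([system_prompt, *conversation])

print(build_model_input(["User: What is a palindrome?"]))
```

The trade-off is the one described above for transformers: the fixed prompt is reprocessed as part of the input on every request, so it must stay short.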

Sophie Swett