
In a transformer (or a GPT-style, decoder-only model), at the end of the decoder blocks but before the final linear layer you have X vectors (one for each of the X tokens fed into the decoder). We then want to compute the probabilities for the next token of the sequence. What exactly do we feed to the linear layer? Is it just the last embedding, i.e. the hidden state corresponding to the last token of the input sequence?

I've seen some YouTube tutorials on how to build mini GPTs, but I never quite understood why they feed all X vectors/hidden states at the end of the decoder blocks to the linear layer rather than just the last one. Wouldn't you end up with X probability distributions when in reality you only want one? And if we do want all X probability distributions, wouldn't that completely defeat the point of the masked self-attention, since we would be trying to predict words that are already in the input sequence, so essentially "cheating"?

2 Answers


Welcome to AI Stack Exchange!

I understand the confusion. Inference (next token prediction) seems really counterintuitive and inefficient for transformers. And it is! The transformer is very efficient during training because it can be parallelized. It is, however, inefficient at inference because it cannot be parallelized.

For transformer inference, you feed the context (your prompt) to the transformer model. It predicts the next word for each of the words in the prompt, but you only need the prediction for the last one.

A bit of pseudocode might help in understanding how a transformer can be used to generate new tokens:

# Start with some context of tokens
context = ...

# Generate new tokens
for i in range(N_TOKENS_TO_GENERATE):
    prediction = transformer(context)                 # Get predictions for every position in the context
    next_token = multinomial(prediction.get_last())   # Sample the next token from the distribution at the last position
    context = concatenate((context, next_token))      # Append it to the context and repeat

Now, this is the intuitive way of doing it. There are plenty of small optimizations that make inference more efficient in practice, but you cannot get around feeding the context back in every time you add a new token. This is also why an application such as ChatGPT generates its output word for word.
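For concreteness, here is a minimal runnable sketch of the same loop in PyTorch. The tiny embedding-plus-linear "model" is just a stand-in for a real decoder stack, and the starting token ids are made up; the point is only that the logits of the last position are used to pick the next token:

import torch
import torch.nn as nn

# Stand-in "model": embedding + linear instead of a real decoder stack.
# It still maps token ids of shape (1, seq_len) to logits of shape (1, seq_len, vocab_size).
vocab_size, d_model = 100, 32
embed = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size)

def model(ids):
    return lm_head(embed(ids))

context = torch.tensor([[42, 7]])                  # made-up starting token ids, shape (1, 2)

N_TOKENS_TO_GENERATE = 20
for _ in range(N_TOKENS_TO_GENERATE):
    with torch.no_grad():
        logits = model(context)                             # (1, seq_len, vocab_size): one distribution per position
    probs = torch.softmax(logits[:, -1, :], dim=-1)         # keep only the last position
    next_token = torch.multinomial(probs, num_samples=1)    # sample one token id, shape (1, 1)
    context = torch.cat((context, next_token), dim=1)       # append and feed the whole context back in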

A small note on the side: you talk about 'hidden states' in the transformer as if there were a recurrence going on (as in GRUs/LSTMs/RNNs). Transformers have no such recurrence or hidden state; they operate solely on attention (hence the paper's title, 'Attention Is All You Need', alluding to the fact that no recurrence is needed).

Hope this helps :)

Robin van Hoorn

No: each next-token prediction comes from a single one of the output vectors, as you suspected. It has to, because otherwise there would be no way to parallelize the predictions during training with a consistent set of parameters.
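To make the training-time parallelism concrete, here is a small sketch (the shapes and the random logits tensor are illustrative; the logits would really come from the transformer). Each of the X positions gets its own distribution and its own target, so all X predictions contribute to one loss in a single pass, and the causal mask is what keeps position i from seeing the token it is asked to predict:

import torch
import torch.nn.functional as F

batch, n_context, n_vocab = 2, 8, 100
tokens = torch.randint(n_vocab, (batch, n_context + 1))   # one extra token so targets can be shifted

inputs  = tokens[:, :-1]    # what the model sees, positions 0 .. n_context-1
targets = tokens[:, 1:]     # the target at position i is simply the token at position i+1

# Stand-in for transformer(inputs): one row of logits per position.
logits = torch.randn(batch, n_context, n_vocab, requires_grad=True)

# One cross-entropy term per position, all computed in parallel.
loss = F.cross_entropy(logits.reshape(-1, n_vocab), targets.reshape(-1))
loss.backward()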

My understanding is from https://transformer-circuits.pub/2021/framework/index.html, section "High Level Architecture":

What you are referring to as the "X vectors", they refer to as T(t) in that figure: a tensor of logits with shape [n_context, n_vocab]. The input to the decoder, t, is the [n_context, n_vocab]-shaped tensor of one-hot encoded tokens (see the "Notation" section at the end of that article).
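In code, those shapes look roughly like this (random tensors stand in for the real embeddings and decoder; this is only to pin down the notation):

import torch
import torch.nn.functional as F

n_context, n_vocab = 8, 100
token_ids = torch.randint(n_vocab, (n_context,))

# t: the one-hot encoded input tokens, shape [n_context, n_vocab]
t = F.one_hot(token_ids, num_classes=n_vocab).float()

# T(t): the logits the model produces, also shape [n_context, n_vocab].
# Row i is the predicted distribution over the token that follows position i.
T_of_t = torch.randn(n_context, n_vocab)             # stand-in for the real model output

# For autoregressive generation only the last row is needed.
next_token_probs = torch.softmax(T_of_t[-1], dim=-1)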

Note also that the original "Attention Is All You Need" paper re-uses the (transposed) embedding matrix to de-embed each of the X vectors coming out of the final decoder layer (see Section 3.4 of that paper).
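A minimal sketch of that weight tying in PyTorch (names and sizes are illustrative, not taken from either paper):

import torch
import torch.nn as nn

n_vocab, d_model = 100, 64

embedding = nn.Embedding(n_vocab, d_model)          # embedding matrix, shape (n_vocab, d_model)
lm_head = nn.Linear(d_model, n_vocab, bias=False)   # the "de-embedding" / output projection

# Tie the weights: nn.Linear computes x @ weight.T, so sharing the embedding
# matrix directly gives the transposed de-embedding described in Section 3.4.
lm_head.weight = embedding.weight

hidden = torch.randn(8, d_model)                    # stand-in for the X output vectors of the decoder
logits = lm_head(hidden)                            # shape (8, n_vocab), one distribution per position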