
How does a transformer leverage the GPU to be trained faster than RNNs?

I understand that the parameter space of a transformer might be significantly larger than that of an RNN. But why can the transformer architecture leverage multiple GPUs, and why does that accelerate its training?

nbro
YoYO Man

3 Answers


A recurrent neural network (RNN) depends on the hidden state from the previous time step. That is, an RNN is a function of both the input at time $t$ and the hidden state from time $t-1$. This means that we cannot compute the $t$th hidden state without first computing the $(t-1)$th state, which in turn requires the $(t-2)$th state, and so on.
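To make that dependency concrete, here is a minimal sketch of an RNN forward pass. For readability it assumes a scalar hidden state and the update $h_t = \tanh(w_x x_t + w_h h_{t-1})$ with made-up weights `w_x` and `w_h`; real RNNs use vectors and weight matrices, but the loop structure is the same.

```python
import math

def tanh_rnn(xs, w_x, w_h, h0=0.0):
    """Toy scalar RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1}).

    The weights are illustrative assumptions, not trained values.
    """
    hidden_states = []
    h = h0
    for x in xs:
        # Strictly sequential: this step cannot start until the
        # previous iteration has produced h_{t-1}.
        h = math.tanh(w_x * x + w_h * h)
        hidden_states.append(h)
    return hidden_states

hidden = tanh_rnn([1.0, 0.5, -0.3, 2.0], w_x=0.8, w_h=0.5)
print(len(hidden))  # 4
```

However fast each individual step is, the number of steps that must run one after another grows with the sequence length, which is what limits how much of an RNN's work can be spread across GPU cores.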

In contrast, a transformer can process all positions of the sequence in parallel (at least during training) because it has no such recurrent relationship, i.e. a transformer is not a recurrent function. The order of the sequence is captured in other ways, such as through positional encoding. We can see this by looking at how self-attention works.
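The answer does not say which positional encoding is meant; the sketch below assumes the fixed sinusoidal scheme from the original Transformer paper, $PE(pos, 2k) = \sin(pos / 10000^{2k/d})$ and $PE(pos, 2k+1) = \cos(pos / 10000^{2k/d})$ (learned position embeddings are a common alternative). The point it illustrates is that each position's encoding depends only on its own index, so no position has to wait for another.

```python
import math

def sinusoidal_positional_encoding(pos, d_model):
    """Encoding for a single position; d_model is assumed to be even."""
    pe = []
    for i in range(0, d_model, 2):
        # i plays the role of 2k in PE(pos, 2k) and PE(pos, 2k+1)
        angle = pos / (10000 ** (i / d_model))
        pe.append(math.sin(angle))  # even dimension
        pe.append(math.cos(angle))  # odd dimension
    return pe

# Each row depends only on its own position index, so the rows can be
# computed in any order (or all at once); no row waits for the previous one.
n, d_model = 6, 8
forward = [sinusoidal_positional_encoding(p, d_model) for p in range(n)]
backward = [sinusoidal_positional_encoding(p, d_model) for p in reversed(range(n))]
print(forward == backward[::-1])  # True
```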

If we first consider the general attention mechanism framework, then we have a query $\textbf{q}$ and a set of key-value pairs $(\textbf{k}_1, \textbf{v}_1), \ldots, (\textbf{k}_n, \textbf{v}_n)$. For each key, we apply some attention function $\beta$ (such as a neural network or a dot product) to obtain attention scores $a_i = \beta(\textbf{q}, \textbf{k}_i)$. We then define an attention vector $\textbf{a}$ whose $i$th element is the $i$th attention score, and take a softmax of this vector to obtain attention weights $\alpha_i$, where $\alpha_i$ is the $i$th element of $\mbox{softmax}(\textbf{a})$. The output of the attention mechanism for query $\textbf{q}$ is then the weighted sum $\sum_{i=1}^n \alpha_i \textbf{v}_i$.
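A minimal sketch of this single-query attention in plain Python follows. It assumes $\beta$ is the scaled dot product $\beta(\textbf{q}, \textbf{k}_i) = \textbf{q}^\top \textbf{k}_i / \sqrt{d}$, as in the original Transformer; the answer leaves $\beta$ general, and the helper names here are illustrative.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attend(q, keys, values):
    """Attention output for a single query q over pairs (k_i, v_i)."""
    d = len(q)
    # a_i = beta(q, k_i), with beta assumed to be the scaled dot product
    scores = [dot(q, k) / math.sqrt(d) for k in keys]
    alphas = softmax(scores)  # attention weights alpha_i
    # output = sum_i alpha_i * v_i, computed one coordinate at a time
    return [sum(a * v[j] for a, v in zip(alphas, values))
            for j in range(len(values[0]))]

out = attend([1.0, 0.0],
             keys=[[1.0, 0.0], [0.0, 1.0]],
             values=[[10.0, 0.0], [0.0, 10.0]])
print(out)  # the first value dominates, since q is closest to the first key
```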

Now that we have the necessary background on attention, we can look at self-attention, which is the backbone of the Transformer. Given a sequence $\{\textbf{x}_1, \ldots, \textbf{x}_n\}$, we define the queries, keys and values all to be these $\textbf{x}_i$ (in practice each is first multiplied by a learned projection matrix, but that does not change the argument). Note that previously we had only a single query, whereas here we have $n$ queries, one per position, and this is really what allows the Transformer to parallelise the processing of the sequence. Let $\textbf{Q}, \textbf{K}, \textbf{V}$ be the matrices of queries, keys and values (e.g. the $i$th row of $\textbf{Q}$ is the $i$th query, and similarly for the others). Self-attention is then simply attention over these queries, keys and values; the name "self" comes from the fact that the queries, keys and values all come from the same sequence, with the $i$th of each representing the $i$th element of the sequence.

We can now write the attention scores as $a_{i,j} = \beta(\textbf{q}_i, \textbf{k}_j)$, which gives a matrix of scores (with $n$ queries and $n$ keys, the matrix is $n \times n$), and take the softmax row-wise to get the attention weights (again an $n \times n$ matrix). If we call the matrix of attention weights $\textbf{A}$, then the output of a self-attention layer is $\textbf{A} \textbf{V}$. As you can see, there is no recurrence here, so everything is parallelisable: e.g. the computation can be broken up and put onto multiple GPUs at the same time. This would not be possible with an RNN, as you would have to wait for the output of the previous time step.
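The sketch below computes $\textbf{A}\textbf{V}$ row by row and then splits the query rows into two groups, which stand in for two GPUs (this is only a Python list split, not real multi-device code). It follows the simplification in the text, $\textbf{Q} = \textbf{K} = \textbf{V} = \textbf{X}$ with no learned projections or multiple heads, and assumes $\beta$ is the scaled dot product. Because output row $i$ needs only $\textbf{q}_i$ plus the shared $\textbf{K}$ and $\textbf{V}$, the two halves never wait on each other.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention_rows(X, rows):
    """Rows of A V for the given query indices, with Q = K = V = X."""
    d = len(X[0])
    out = []
    for i in rows:
        # Row i of the n x n score matrix: beta(q_i, k_j) for every j
        scores = [sum(a * b for a, b in zip(X[i], X[j])) / math.sqrt(d)
                  for j in range(len(X))]
        weights = softmax(scores)  # row i of A
        # Row i of A V: weighted sum of all value vectors
        out.append([sum(w * X[j][c] for j, w in enumerate(weights))
                    for c in range(d)])
    return out

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
full = self_attention_rows(X, range(len(X)))

# Pretend two "devices" each handle half of the queries independently.
part_a = self_attention_rows(X, [0, 1])
part_b = self_attention_rows(X, [2, 3])
print(part_a + part_b == full)  # True
```

In a real framework the whole computation is expressed as a few matrix multiplications, which GPUs execute very efficiently; the row-wise loop here only makes the independence between queries visible.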

David

The issue with recurrent models is that they do not allow parallelization during training. Sequential models perform better with more memory, but they have trouble learning long-term dependencies.

Transformers, on the other hand, rely on self-attention, which speeds up how fast the model can translate from one sequence to another. Self-attention establishes dependencies between input and output and focuses on the relevant parts of the input sequence, which lets the model dispense with recurrence and convolution entirely. In RNNs, by contrast, sequential computation inhibits parallelization.


If a transformer model has a context size of 1024 tokens, then each training sequence serves as 1024 training examples: with causal masking, the model predicts the next token at every position in parallel, each within its own context of preceding tokens.
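The sketch below lists the (context, target) pairs a causal language model is trained on from a single sequence, together with the causal mask that makes it legitimate to compute them all in one forward pass. The helper names and the toy word-level tokens are hypothetical, and the model itself is omitted. A window of $L$ tokens yields $L - 1$ pairs, so 1025 tokens give 1024 next-token predictions.

```python
def next_token_examples(tokens):
    """(context, target) pairs from one sequence (hypothetical helper).

    With a causal mask, position t attends only to tokens[:t + 1] and is
    trained to predict tokens[t + 1]; a transformer scores all of these
    positions in the same forward pass.
    """
    return [(tokens[: t + 1], tokens[t + 1]) for t in range(len(tokens) - 1)]

def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

tokens = ["The", "cat", "sat", "on", "the", "mat"]
examples = next_token_examples(tokens)
for context, target in examples:
    print(context, "->", target)
```

The mask is the key design choice: it stops each position from seeing the tokens it is supposed to predict, so all positions can be trained simultaneously without leaking the answers.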

For an RNN to get the same training signal from the same data, it has to start from the first token and iteratively predict the next one. At each step it carries the historical context in its hidden state, which is necessary because it sees only the current token plus that hidden state when predicting the next token.

So, to 'squeeze the juice' out of the same amount of training data, it has to perform a long chain of sequential steps: to process the $n$th token, it must first complete all the previous computations.

On the other hand, to predict the next token, a transformer model directly uses all the previous tokens within its context window. Hence, whereas a transformer can start predicting given any context, an RNN first needs to 'load' the context into its hidden state by iterating through the data, and can make predictions only once it has done its homework.

Prem