
Can the decoder in a transformer model be parallelized like the encoder?

As far as I understand, the encoder has all the tokens in the sequence to compute the self-attention scores. But for a decoder, this is not possible (in both training and testing), as self-attention is calculated based on the outputs of previous time steps. Even if we consider techniques like teacher forcing, where we concatenate the expected output with the obtained one, the decoder still takes sequential input from the previous time step.

In this case, apart from the improvement in capturing long-term dependencies, is using a transformer decoder better than, say, an LSTM, when comparing purely on the basis of parallelization?

shiredude95

4 Answers


Can the decoder in a transformer model be parallelized like the encoder?

Generally NO:

Your understanding is completely right. In the decoder, the output of each step is fed to the bottom decoder in the next time step, just like an LSTM.

Also, as in LSTMs, the self-attention layer needs to attend to earlier positions in the output sequence in order to compute the current output, which makes straightforward parallelisation impossible.

However, when decoding during training, there is a frequently used procedure that doesn't take the model's own output at step t as input at step t+1, but instead takes the ground-truth output at step t. This procedure is called 'teacher forcing' and makes the decoder parallelisable during training (see the sketch below). You can read more about it here.
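As a rough illustration (my own sketch, not part of the original answer), here is what teacher forcing looks like in PyTorch: the whole ground-truth target sequence is pushed through the decoder in a single forward pass, and a causal mask keeps each position from attending to later positions. All sizes and tensors below are made up for demonstration.

```python
import torch
import torch.nn as nn

# toy sizes, chosen only for illustration
vocab, d_model, src_len, tgt_len = 100, 32, 7, 5

embed = nn.Embedding(vocab, d_model)
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead=4), num_layers=2)

memory = torch.randn(src_len, 1, d_model)      # stand-in for the encoder output
gold = torch.randint(0, vocab, (tgt_len, 1))   # ground-truth target tokens (teacher forcing)
causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)

# one parallel forward pass produces decoder states for ALL target positions;
# the causal mask stops position t from attending to positions > t
states = decoder(embed(gold), memory, tgt_mask=causal_mask)   # shape: (tgt_len, 1, d_model)
```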

For a detailed explanation of how the Transformer works, I suggest reading this article: The Illustrated Transformer.

Is using a transformer-decoder better than, say, an LSTM when comparing purely on the basis of parallelization?

YES:

Parallelization is the main drawback of RNNs in general. Put simply, RNNs can memorize but not parallelize, while CNNs can parallelize but not memorize. Transformers are so powerful because they combine both: parallelization (at least partially) and memorization.

In Natural Language Processing, for example, where RNNs used to be so effective, if you take a look at the GLUE leaderboard you will find that most of the world-leading algorithms today are Transformer-based (e.g. BERT by Google, GPT by OpenAI).

For a better understanding of why Transformers are better than CNNs, I suggest reading this Medium article: How Transformers Work.

HLeb

Can the decoder in a transformer model be parallelized like the encoder?

The correct answer is: computation in a Transformer decoder can be parallelized during training, but not during actual translation (or, in a wider sense, generating output sequences for new input sequences during a testing phase).

What exactly is parallelized?

Also, it's worth mentioning that "parallelization" in this case means computing encoder or decoder states in parallel for all positions of the input sequence. Parallelization over several layers is not possible: the first layer of a multi-layer encoder or decoder still needs to finish computing all positions in parallel before the second layer can start computing (see the sketch below).
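To make this concrete, here is a minimal sketch (my own illustration: a single attention head with random weights, not any particular trained model) of masked self-attention inside one decoder layer. All positions are handled by a few matrix multiplications at once, but the next layer can only start once this one has finished.

```python
import torch

seq_len, d_model = 5, 16
x = torch.randn(seq_len, d_model)                      # states for ALL positions from the layer below
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

scores = (x @ W_q) @ (x @ W_k).T / d_model ** 0.5      # one matmul covers every position at once
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))     # position t may only attend to positions <= t
out = torch.softmax(scores, dim=-1) @ (x @ W_v)        # all seq_len outputs, still one parallel pass

# only now can the NEXT layer run, taking `out` as its input: layers are sequential, positions are not
```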

Why can the decoder be parallelized position-wise during training?

For each position in the input sequence, a Transformer decoder produces a decoder state as an output. (The decoder state is then used to eventually predict a token in the target sequence.)

In order to compute one decoder state for a particular position in the sequence of states, the network consumes as inputs: 1) the entire input sequence and 2) the target words that were generated previously.

During training, the target words generated previously are known, since they are taken from the target side of our parallel training data. This is the reason why computation can be factored over positions.

During inference (also called "testing", or "translation"), the target words previously generated are predicted by the model, and computing decoder states must be performed sequentially for this reason.
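Here is a hedged sketch of that sequential inference loop, assuming a hypothetical trained `model` that maps a source sequence and the current target prefix to next-token logits (the names `src`, `BOS_ID`, `EOS_ID` and `max_len` are placeholders of mine, not from the answer):

```python
import torch

BOS_ID, EOS_ID, max_len = 1, 2, 50      # placeholder special tokens and length limit

def greedy_decode(model, src):
    """Sequential decoding: each new target word depends on the previously predicted ones."""
    generated = [BOS_ID]
    for _ in range(max_len):
        tgt_prefix = torch.tensor(generated).unsqueeze(1)   # shape (t, batch=1)
        logits = model(src, tgt_prefix)                      # decoder re-runs over the whole prefix
        next_token = logits[-1, 0].argmax().item()           # pick the most likely next word
        generated.append(next_token)
        if next_token == EOS_ID:
            break
    return generated
```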

Comparison to RNN models

While Transformers can parallelize over input positions during training, an encoder-decoder model based on RNNs cannot parallelize over positions at all. This means that Transformers are generally faster to train, while RNNs are faster for inference.

This observation leads to the nowadays common practice of training Transformer models and then using sequence-level distillation to learn an RNN model that mimics the trained Transformer, for faster inference.

Mathias Müller

I can't see that this has been mentioned yet: there are ways to generate text non-sequentially using a non-autoregressive transformer, where you produce the entire response to the context at once. This typically produces worse accuracy scores, because there are interdependencies within the text being produced. A model translating "thank you" could say "vielen danke" or "danke schön", but whereas an autoregressive model knows which word to say next based on the previous decoding steps, a non-autoregressive model can't, so it could also produce "danke danke" or "vielen schön". There is some research suggesting that the accuracy gap can be closed, though: https://arxiv.org/abs/2012.15833.
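A minimal sketch of the non-autoregressive idea (my simplification, not the cited paper's actual model): the decoder is fed position queries instead of previously generated words and runs without a causal mask, so every output token is predicted in one pass. Real non-autoregressive models add things like length prediction and iterative refinement on top of this.

```python
import torch
import torch.nn as nn

d_model, vocab, src_len, tgt_len = 64, 1000, 10, 6   # toy sizes

decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead=4), num_layers=2)
to_vocab = nn.Linear(d_model, vocab)

memory = torch.randn(src_len, 1, d_model)    # stand-in for the encoder output
queries = torch.randn(tgt_len, 1, d_model)   # positional queries, NOT previously generated tokens

states = decoder(queries, memory)            # no causal mask: positions are filled independently
tokens = to_vocab(states).argmax(-1)         # all tgt_len output tokens appear in a single step
```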

Ben

The PyTorch implementation of TransformerDecoderLayer doesn't feed its own predicted output back in (it can be used with or without teacher-forced labels); it predicts the whole sequence at once. BUT if you are using a stack of TransformerDecoderLayers (number of decoder layers > 1), each layer uses the output of the previous layer, even if the tgt sequence (teacher forcing) is used (see the sketch below).
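A small sketch of what is meant (toy tensors of my own, not a full model): even with teacher forcing, only the first layer sees the ground-truth embeddings; every later layer consumes the previous layer's output.

```python
import torch
import torch.nn as nn

d_model, nhead, src_len, tgt_len = 32, 4, 7, 5
layer1 = nn.TransformerDecoderLayer(d_model, nhead)
layer2 = nn.TransformerDecoderLayer(d_model, nhead)

tgt = torch.randn(tgt_len, 1, d_model)       # teacher-forced target embeddings, whole sequence at once
memory = torch.randn(src_len, 1, d_model)    # encoder output
causal = nn.Transformer.generate_square_subsequent_mask(tgt_len)

h1 = layer1(tgt, memory, tgt_mask=causal)    # first layer gets the (masked) ground-truth sequence
h2 = layer2(h1, memory, tgt_mask=causal)     # second layer gets the FIRST layer's output, not tgt
```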

Dima