
I’ve been reading about non-autoregressive models (like NATs or diffusion models) and how they can generate outputs in parallel instead of step-by-step like autoregressive models. That sounds fast in theory, but in practice, they often need multiple refinement steps (e.g., denoising in diffusion models or iterative decoding in NATs) to get good quality.

So I’m wondering:

  1. Are there any benchmarks that show how many refinement steps (and the corresponding time) are needed to match autoregressive model accuracy, and how the accuracy scales with the number of steps?

  2. More practically, to reach the same level of accuracy, how does the total inference time (including all the refinement iterations) compare to autoregressive models that just decode one token at a time?

Have any companies (like Google, Meta, DeepMind, etc.) shared real-world benchmarks, blog posts, or papers on this?

Thanks!

LearnerAL

1 Answer


Efficiency has so many variables that it is hard to make a fair comparison. Still, consider a 100×100 pixel image, which is reasonably small.

Now, for diffusion models, 100×100 is not much: a convolutional U-Net handles it without crazy computational requirements, or you can use some fancy ViT that operates on patches. Roughly speaking, you need about 20 diffusion steps to generate something nice. A minimal sketch of this is below, after the autoregressive comparison.
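For intuition, here is a minimal sketch of why the step count, not the pixel count, drives diffusion inference cost. The `dummy_denoiser` and the update rule are placeholders, not a real sampler:

```python
import torch

# Toy stand-in for a trained denoising network (U-Net / ViT); the point is
# only that sampling cost = number of steps, independent of generation order.
def dummy_denoiser(x, t):
    return 0.1 * x

def sample(denoiser, steps=20, shape=(1, 3, 100, 100)):
    x = torch.randn(shape)          # start from pure Gaussian noise
    for t in reversed(range(steps)):
        eps = denoiser(x, t)        # ONE forward pass per step: 20 in total
        x = x - eps / steps         # simplified update; real samplers
                                    # (DDPM/DDIM) use learned noise schedules
    return x

img = sample(dummy_denoiser)        # 20 network calls, whatever the resolution
```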

For autoregressive models, you need one forward pass per output value. Since the image has 3 RGB channels, that means $3 \times 100 \times 100 = 30{,}000$ forward passes to generate a single image. If that is not already crazy enough, note that a SoTA autoregressive model (a Transformer) must then handle a context length of 30,000, and full attention over that context materializes $30{,}000^2 = 900{,}000{,}000$ floating-point values. At 4 bytes each, that is a 3.6 GB attention matrix.
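A quick back-of-envelope check of those numbers (assuming a naive fp32 attention matrix, which is what the 3.6 GB figure above assumes):

```python
seq_len = 3 * 100 * 100        # one token per RGB value: 30,000
attn_entries = seq_len ** 2    # 900,000,000 entries in the attention matrix
gb = attn_entries * 4 / 1e9    # 4 bytes per fp32 value
print(f"{seq_len=}, {attn_entries=}, {gb=:.1f} GB")  # -> gb=3.6 GB
# ...and standard attention builds one such matrix per head, per layer.
```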

Now, obviously this is extremely application-dependent. You should also consider what type of GPUs you are using, whether you can use fancy attention mechanisms in the transformer, whether you can use fp8 instead of fp32, and all that kind of stuff. Yet it still comes down to 20 forward passes of a U-Net against 30,000 forward passes of a transformer.
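To make that comparison concrete, a toy calculation under the assumptions above; the 100× factor is an arbitrary illustrative guess for how much more a full-image U-Net pass might cost than a single KV-cached transformer decode step:

```python
diffusion_passes = 20          # full-image U-Net forward passes
ar_passes = 3 * 100 * 100      # one transformer decode step per output value

# Even under the (made-up) assumption that one U-Net pass costs as much as
# 100 cached transformer decode steps, diffusion still comes out ahead:
unet_cost_ratio = 100
print(diffusion_passes * unet_cost_ratio, "vs", ar_passes)  # 2000 vs 30000
```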

Alberto