
I wanted to calculate the rough number of floating-point operations (or float multiplications) required to generate a response to a prompt.

I am using Llama 3.1 70B as an example:

- prompt length (using words instead of tokens for simplification): `What's the capital of russia` => 5
- answer length: `the capital of russia is moscow` => 6
- embedding size = 8192
- number of layers = 80
- number of heads = 64

and so the amount of FLOPs should be something like:

single QKV generator network: number of weights = (8192/64) * (8192/64) = 16'384 => 16'384 FLOPs per generator

number of QKV generators = 80 * 64 = 5'120 => 16'384 * 5'120 = 83'886'080 FLOPs per token for QKV

dense network number of weights = 8192 * 8192 = 67'108'864

number of dense networks = 80 => 67'108'864 * 80 = 5'368'709'120 FLOPs per token

attention calculation = (num_tokens^2 * 8192) * 2 * 80

total_flops(num_tokens) (for a single forward pass) = (5'368'709'120 + 83'886'080) * num_tokens + (num_tokens^2 * 8192) * 2 * 80
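
Written out as a small Python sketch (the constant names and `total_flops` are my own; the numbers are the ones derived above):

```python
EMBED  = 8192   # embedding size
LAYERS = 80     # number of layers
HEADS  = 64     # number of heads

# 16'384 weights per QKV generator, 80 * 64 generators => 83'886'080
QKV_FLOPS_PER_TOKEN = (EMBED // HEADS) ** 2 * LAYERS * HEADS
# 8192 * 8192 weights per dense network, 80 of them => 5'368'709'120
DENSE_FLOPS_PER_TOKEN = EMBED * EMBED * LAYERS

def total_flops(num_tokens: int) -> int:
    """FLOPs for one forward pass over a context of num_tokens tokens."""
    attention = num_tokens ** 2 * EMBED * 2 * LAYERS
    return (DENSE_FLOPS_PER_TOKEN + QKV_FLOPS_PER_TOKEN) * num_tokens + attention
```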

and since we have to do this multiple times (once for each token of the output), the total is the sum of

total_flops(num_tokens) for num_tokens in {6, 7, 8, 9, 10, 11}

which gets me 279'116'840'960 FLOPs. I have no idea how to verify this, but it seems like an insane number of FLOPs. Can someone tell me if and where I've made a mistake?
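
For reference, a minimal sketch of that sum using the function above (the context lengths 6 through 11 assume one extra token per generated answer word):

```python
# One forward pass per generated output word, with the context growing
# from 6 to 11 tokens (5 prompt words + 1..6 answer words):
total = sum(total_flops(n) for n in range(6, 12))
print(f"{total:,} FLOPs")  # on the order of 2.8e11, close to the figure quoted above
```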
