I wanted to calculate the rough number of floating point operations (or float multiplications) required to generate a response to a prompt.
I am using Llama 3.1 70B as an example:
prompt length (using words instead of tokens for simplicity): `What's the capital of russia` => 5
answer length: `the capital of russia is moscow` => 6
embedding size=8192
number of layers=80
number of heads=64
and so the number of FLOPs should be something like:
single QKV generator network: number of weights = (8192/64) * (8192/64) = 16'384 => 16'384 FLOPs per generator
number of QKV generators = 80 * 64 = 5'120 => 16'384 * 5'120 = 83'886'080 FLOPs per token for QKV
dense network: number of weights = 8192 * 8192 = 67'108'864
number of dense networks = 80 => 67'108'864 * 80 = 5'368'709'120 FLOPs per token
attention calculation = (num_tokens^2 * 8192) * 2 * 80
totaling(num_tokens) (for a single forward pass) = (5'368'709'120 + 83'886'080) * num_tokens + (num_tokens^2 * 8192) * 2 * 80
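In case the formatting above mangles anything, here is the same calculation as a quick Python sketch (it just re-encodes my numbers above; the variable names are mine):

```python
# Rough FLOP estimate for one forward pass over a context of num_tokens tokens,
# re-encoding the per-token numbers above (Llama 3.1 70B config as I listed it).

EMBED = 8192                # embedding size
LAYERS = 80                 # number of layers
HEADS = 64                  # number of heads
HEAD_DIM = EMBED // HEADS   # 8192 / 64 = 128

def totaling(num_tokens: int) -> int:
    # per-head QKV generator: (8192/64) * (8192/64) = 16'384 FLOPs per generator
    qkv_per_head = HEAD_DIM * HEAD_DIM
    # 80 * 64 = 5'120 generators => 83'886'080 FLOPs per token for QKV
    qkv_per_token = qkv_per_head * LAYERS * HEADS
    # one 8192 x 8192 dense network per layer => 5'368'709'120 FLOPs per token
    dense_per_token = EMBED * EMBED * LAYERS
    # attention calculation: (num_tokens^2 * 8192) * 2 * 80
    attention = (num_tokens ** 2) * EMBED * 2 * LAYERS
    return (dense_per_token + qkv_per_token) * num_tokens + attention
```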
and since we have to do this multiple times (once for each token of the output), the total is the sum of
totaling(num_tokens) for num_tokens in {6, 7, 8, 9, 10, 11}
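or, continuing the sketch above:

```python
# Sum the per-pass cost over the six output tokens
# (context lengths 6 through 11), using totaling() from the sketch above.
grand_total = sum(totaling(n) for n in range(6, 12))
print(f"{grand_total:,} FLOPs")
```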
which gets me 279'116'840'960 FLOPs. I have no idea how to verify this, but it seems like an insane number of FLOPs. Can someone tell me if and where I've made a mistake?