As I understand it, the forward pass for a single transformer layer looks roughly as follows:
```python
x += self_attention(x)
x = layernorm(x)
x += ffn(x)
```
Breaking that down a bit (excuse the hand-waving, this is meant to be illustrative):
```python
def self_attention(x):
    qkv_list = get_qkvs(x)  # one (q, k, v) triple per head
    # scaled dot-product attention, per head
    heads = [softmax(q @ t(k) / sqrt(d_head)) @ v for (q, k, v) in qkv_list]
    concatenated_heads = concat(heads)
    projected_values = concatenated_heads @ W_O  # projection matrix for concatenated heads
    return projected_values
```
and
```python
def ffn(x):
    return relu(x @ W_1 + b1) @ W_2 + b2
```
Perhaps I am missing something obvious, but I note that if you drop the layernorm between the self-attention and FFN modules, you have two linear projections in a row ($W_O$ followed by $W_1$). $W_O$ is a costly matrix: its dimensions are (n_head * d_head) x d_embed = (d_embed x d_embed), so it contributes n_layers * d_embed^2 parameters over the course of the network.
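To put a rough number on that, here's a back-of-the-envelope calculation using GPT-2-small-ish dimensions (d_embed = 768, n_layers = 12 — my choice of example, not anything specified above):

```python
# Rough parameter count spent on W_O across the whole network,
# assuming GPT-2-small-like dimensions (d_embed = 768, n_layers = 12).
d_embed, n_layers = 768, 12
w_o_params = n_layers * d_embed**2
print(f"{w_o_params:,}")  # 7,077,888 -- about 7M parameters on W_O alone
```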
My question is: if you drop that layernorm, it appears the $W_O$ matrix would be entirely redundant, and you could save a substantial number of parameters (and the corresponding compute) by fusing these two operations into a single learned matrix (that matrix would both mix the streams from the different heads, as $W_O$ does, and up-project into the FFN hidden layer, as $W_1$ does; see the sketch after the list below). This means either:
- The layernorm is incredibly important, and I assume this has been shown somewhere empirically?
- I've missed something completely obvious about the forward pass, or the $W_O$ matrix is not redundant for some reason even if you drop the layernorm
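For concreteness, here is the kind of fusion I have in mind, as a minimal numpy sketch with made-up dimensions (all the names below are mine): without a layernorm in between, the two projections compose into a single matrix by associativity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_embed, d_ffn = 10, 64, 256

concat_heads = rng.standard_normal((n_tokens, d_embed))  # stand-in for concat(heads)
W_O = rng.standard_normal((d_embed, d_embed))            # head-mixing projection
W_1 = rng.standard_normal((d_embed, d_ffn))              # FFN up-projection

# With nothing nonlinear in between, the two linear maps collapse into one learned matrix:
W_fused = W_O @ W_1                                      # shape (d_embed, d_ffn)

out_two_step = (concat_heads @ W_O) @ W_1
out_fused = concat_heads @ W_fused
print(np.allclose(out_two_step, out_fused))              # True (up to float error)
```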
Would love to be corrected!