I'm studying the series of wav2vec papers, in particular vq-wav2vec and wav2vec 2.0, and I have trouble understanding some details of the quantization procedure.
The broader context is this: they take raw audio and first convert it to "features" $z$ via a convolutional network. Each feature $z$ is then projected to a "quantized" element $\hat{z}$ from a given finite codebook (or a concatenation of finitely many finite codebooks). To find $\hat{z}$, they compute scores $l_j$ for each codebook entry $v_j$, convert these scores to Gumbel-Softmax probabilities $p_j$ (via a formula that is not deterministic: it involves sampling random noise from some distribution), and then use these probabilities $p_j$ to choose $\hat{z}$. Further stages of the pre-training pipeline are trained to predict the $\hat{z}$'s, either by predicting the "future" from the "past" or by reconstructing masked segments.
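For concreteness, the non-deterministic formula I mean is the one from the wav2vec 2.0 paper (dropping the group index $g$ and writing $\tau$ for the temperature):

$$p_j = \frac{\exp\big((l_j + n_j)/\tau\big)}{\sum_{k=1}^{V}\exp\big((l_k + n_k)/\tau\big)}, \qquad n_j = -\log(-\log(u_j)),\quad u_j \sim \mathcal{U}(0,1).$$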
My question is about this sentence:
During the forward pass, $i = \text{argmax}_j p_j$ and in the backward pass, the true gradient of the Gumbel-Softmax outputs is used.
- I have trouble seeing what exactly happens in the loss function and in back-propagation. Could someone please help me break this down in detail?
Here are my attempts to make sense of it (I'm using the notation $\hat{z}$ for the quantized vectors; in the second paper they use $q$):
(1) I would say that during the forward pass, random variables $n_j$ from the Gumbel distribution are sampled anew every time (for every training example) to compute the Gumbel-Softmax probabilities $p_j$.
(1a) In back-propagation, these $n_j$'s are kept constant, and $p_j$ is treated as a function of the $l_j$'s only (see the first sketch after these points).
(2) The loss function has two parts here: a contrastive loss and a diversity loss.
(2a) Based on the description, I would say that the contrastive loss uses the "sampled" vectors $\hat{z}_j$, and the probabilities never appear (not even in the back-propagation of this part of the loss).
(2b) I would believe that the diversity loss, which only uses the probabilities $p_{g,v}$, is where the true gradient of the Gumbel-Softmax is actually used, as this loss is responsible for maximizing the entropy. This part of the gradient probably does not use the sampled values $\hat{z}_j$ (see the second sketch below).
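To make (1) and (1a) concrete, here is a minimal PyTorch-style sketch of how I imagine the straight-through step; the names `logits`, `codebook` and `tau` are my own, and this is only my reading, not the actual fairseq code:

```python
import torch
import torch.nn.functional as F

def gumbel_st_quantize(logits, codebook, tau=1.0):
    """Straight-through Gumbel-Softmax quantization (my understanding, not the fairseq code).

    logits:   (batch, V)  scores l_j for the V codebook entries
    codebook: (V, d)      the entries v_j
    returns:  (batch, d)  the quantized vectors z_hat, plus the probabilities p_j
    """
    # (1) sample fresh Gumbel noise n_j for every example and compute p_j
    probs = F.gumbel_softmax(logits, tau=tau, hard=False)        # differentiable p_j

    # forward: pick i = argmax_j p_j and build the hard one-hot selection
    index = probs.argmax(dim=-1)
    hard = F.one_hot(index, num_classes=probs.shape[-1]).type_as(probs)

    # straight-through: the forward value is `hard`, but the backward gradient
    # flows through `probs` (the "true gradient of the Gumbel-Softmax outputs"),
    # with the noise n_j treated as constants, as in (1a)
    onehot = hard - probs.detach() + probs

    # z_hat is the selected codebook entry in the forward pass; in the backward
    # pass it behaves like a convex combination of all entries
    z_hat = onehot @ codebook
    return z_hat, probs
```

As far as I can tell, `F.gumbel_softmax(logits, tau=tau, hard=True)` implements exactly this forward-hard / backward-soft behaviour, which is how I read "the true gradient of the Gumbel-Softmax outputs is used".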
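And for (2b), a sketch of how I picture the diversity loss, using only the batch-averaged probabilities and not the sampled vectors (again my own names, a single codebook group, and ignoring the $1/(GV)$ normalisation):

```python
import torch

def diversity_loss(probs):
    """Negative entropy of the batch-averaged code probabilities
    (my reading, single group, not the actual fairseq implementation).

    probs: (batch, V) Gumbel-Softmax probabilities p_j per example
    """
    # average the per-example distributions over the batch
    p_bar = probs.mean(dim=0)                                    # (V,)
    # minimizing the negative entropy pushes p_bar towards uniform codebook usage
    return (p_bar * torch.log(p_bar + 1e-7)).sum()
```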
Is this approximately correct?
If yes, then I still fail to understand what exactly is happening in the vq-wav2vec paper. The sentence
During the forward pass, $i = \text{argmax}_j p_j$ and in the backward pass, the true gradient of the Gumbel-Softmax outputs is used.
is there as well, but I cannot see any part of the loss function (in that paper) where the probabilities are explicitly used (the way they are in the diversity loss above).