
I'm training a text classifier in PyTorch and I'm seeing a cyclical pattern in the loss curve that I can't explain. The loss drops sharply at the beginning of each epoch and then slowly rises until the next one. The overall convergence, however, looks fine. Here's how it looks:

[loss curve (global)] [loss curve (zoom)]

The model is very basic and I'm using the Adam optimizer with default parameters and a learning rate of 0.001. The batch size is 512. I've checked and tried a lot of things and I'm running out of ideas, but I'm sure I've made a mistake somewhere.

Things I've made sure of:

  • Data is delivered correctly (VQA v1.0 questions).
  • DataLoader is shuffling the dataset.
  • The LSTM's memory is being zeroed correctly.
  • Gradients aren't leaking through the input tensors.

Things I've already tried:

  • Lowering the learning rate. Pattern remains, although amplitude is lower.
  • Training without momentum (plain SGD). Gradient noise masks the pattern a bit, but it's still there.
  • Using a smaller batch size (the gradient noise grows until it partially masks the pattern, but that doesn't actually solve it).

The model

import torch
import torch.nn as nn

class QuestionAnswerer(nn.Module):

    def __init__(self):
        super(QuestionAnswerer, self).__init__()
        self._wemb = nn.Embedding(N_WORDS, HIDDEN_UNITS, padding_idx=NULL_ID)
        self._lstm = nn.LSTM(HIDDEN_UNITS, HIDDEN_UNITS)
        self._final = nn.Linear(HIDDEN_UNITS, N_ANSWERS)

    def forward(self, question, length):
        # question: (seq_len, B) padded word indices; length: (B,) true sequence lengths
        B = length.size(0)
        embed = self._wemb(question)  # (seq_len, B, HIDDEN_UNITS)
        # pick the LSTM output at the last valid timestep of each sequence
        hidden = self._lstm(embed)[0][length - 1, torch.arange(B)]  # (B, HIDDEN_UNITS)
        return self._final(hidden)
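
For completeness, the training loop is roughly the following (a minimal sketch; the dataset, the CrossEntropyLoss choice and N_EPOCHS are assumptions, the optimizer settings and batch size are as described above):

from torch.utils.data import DataLoader

model = QuestionAnswerer()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# train_set is assumed to yield (question, length, answer) batches,
# with question already padded in the layout the model expects
loader = DataLoader(train_set, batch_size=512, shuffle=True)

for epoch in range(N_EPOCHS):
    for question, length, answer in loader:
        optimizer.zero_grad()
        logits = model(question, length)   # (B, N_ANSWERS)
        loss = criterion(logits, answer)
        loss.backward()
        optimizer.step()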
David

2 Answers


It turns out that the zig-zag pattern is an inherent effect of using a word embedding layer. I don't fully understand the phenomenon, but I believe it is strongly related to the embeddings acting as a kind of memory slots, which can change relatively quickly, while the LSTM generates a summary of the sequence so that the model can remember past combinations.
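
A quick way to see the "memory slot" behaviour (an illustrative snippet, not part of my model): only the embedding rows of the words that appear in a batch receive any gradient, so each row is updated more or less independently of the rest.

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
out = emb(torch.tensor([1, 3]))          # only words 1 and 3 appear in this "batch"
out.sum().backward()
print(emb.weight.grad.abs().sum(dim=1))  # non-zero only at rows 1 and 3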

I found this plot of a word2vec training loss curve, and it exhibits the same per-epoch pattern.

[word2vec training loss curve]

Edit

After conducting several experiments, I've isolated the causes. It seems to be an indirect effect of having too much model capacity. In my case, the word embeddings were too large (size 1024) and there were too many classes (2002), which also increases capacity, so the model was effectively learning on an almost per-sample basis. Reducing both gave a smooth-as-silk learning curve and better generalisation.
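
Concretely, the fix amounted to shrinking the relevant constants (the reduced values below are illustrative, not necessarily the ones I settled on):

# before: very large capacity
HIDDEN_UNITS = 1024   # embedding / LSTM size
N_ANSWERS = 2002      # answer classes

# after: reduced capacity (illustrative values)
HIDDEN_UNITS = 256
N_ANSWERS = 1000      # e.g. keep only the most frequent answers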

David

This weird pattern can also be caused by a learning rate that is too high. Check this answer: https://stackoverflow.com/a/49095437/13164928
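
For example, you could start Adam with a lower learning rate and/or decay it during training (a sketch; the values are just a starting point and train_one_epoch is a hypothetical helper for the inner loop):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # 10x lower than 1e-3
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(N_EPOCHS):
    train_one_epoch(model, optimizer)  # hypothetical helper running one pass over the data
    scheduler.step()                   # halve the learning rate every 5 epochs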

ZappaBoy