I am trying to train a decoder-only transformer model. The dataset is left-padded to a fixed length so that sequences of tokens can be batched. However, when I pass input through a multi-head attention layer (nn.MultiheadAttention) with both a key padding mask and a causal attention mask, I get nan values in the output. I believe this is because left padding causes an entire row of the combined mask to be masked out for the padding tokens, leaving those queries with nothing to attend to.

Minimal example with sequence of length 5:

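import torch
import torch.nn as nn

mha = nn.MultiheadAttention(8, 2, batch_first=True)  # embed_dim=8, num_heads=2
x = torch.randn(1, 5, 8)  # (batch, seq_len, embed_dim)
# Causal mask: True marks a future key that a query may not attend to
attn_mask = ~torch.tril(torch.ones(5, 5)).bool()
# Position 0 is a left-padding token
key_padding_mask = torch.tensor([[True, False, False, False, False]])
print(attn_mask)
print(key_padding_mask)
# Forward pass with both masks (this is where the nan values appear)
out, _ = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)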

This prints what I expect for attn_mask and key_padding_mask:

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])
tensor([[ True, False, False, False, False]])

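When the two are combined, the key padding mask masks column 0 in every row, so the first row becomes entirely True. The first query (the padding token itself) therefore has every key masked: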

tensor([[ True,  True,  True,  True,  True],
        [ True, False,  True,  True,  True],
        [ True, False, False,  True,  True],
        [ True, False, False, False,  True],
        [ True, False, False, False, False]])
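
To double-check the combined mask above, I reproduced it by hand. This is a sketch that assumes the two masks are merged with a logical OR (a key is masked for a query if either mask marks it); `combined` and `fully_masked_rows` are just my own names, not anything from the PyTorch API:

```python
import torch

seq_len = 5
# Causal mask: True above the diagonal, i.e. query i may not see key j > i
attn_mask = ~torch.tril(torch.ones(seq_len, seq_len)).bool()
# Key padding mask for one left-padded sequence: position 0 is padding
key_padding_mask = torch.tensor([[True, False, False, False, False]])

# Broadcast the per-key padding mask over all query rows and OR it with the
# causal mask. Shapes: (1, 5, 5) | (1, 1, 5) -> (1, 5, 5)
combined = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1)

# A query row with every key masked has nothing to attend to
fully_masked_rows = combined.all(dim=-1)
print(combined[0])
print(fully_masked_rows)  # only the first query row (the padding token) is fully masked
```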

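Because of this, the entire first row of attention scores becomes -inf, and the softmax over a row of all -inf values produces nan (the normalizer is a sum of zeros, so every weight is 0/0). The problem doesn't appear with right padding: the column of True values is then at the end, and since key 0 is always a real token visible to every query, no row is entirely True. How does training a causal transformer work with left-padded data without the output containing nan entries?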
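
Here is a small check of both claims in the previous paragraph. It is only a sketch: `scores` stands in for one row of real attention logits, and `left_pad` / `right_pad` are hypothetical single-sequence padding masks I made up for the comparison:

```python
import torch

seq_len = 5
causal = ~torch.tril(torch.ones(seq_len, seq_len)).bool()

# One query row whose keys are all masked: every score is -inf
scores = torch.full((seq_len,), float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights)  # every entry is nan: sum(exp(-inf)) is 0, so each weight is 0/0

# One padding token, placed on the left vs. on the right
left_pad = torch.tensor([True, False, False, False, False])
right_pad = torch.tensor([False, False, False, False, True])

# The padding mask broadcasts across columns (keys) of every query row
left_rows = (causal | left_pad).all(dim=-1)
right_rows = (causal | right_pad).all(dim=-1)
print(left_rows.any().item())   # True: query 0 only sees masked keys
print(right_rows.any().item())  # False: key 0 is real and visible to every query
```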
