
I am fine-tuning Mistral-7B with Hugging Face PEFT and quantization. In my training loop, I am printing the gradient norms for each batch, and the values seem a bit unusual.

```python
# Print gradients
for name, param in model_init.named_parameters():
    if param.grad is not None:
        print(f'Gradient for {name}: {param.grad.norm()}')
```

I am trying to understand why all the gradient values are 0 except for iteration 1 (counting from iteration 0).

```
iteration 0
...
Gradient for base_model.base_model.model.model.layers.31.self_attn.q_proj.lora_B.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.k_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.k_proj.lora_B.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.v_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight: 0.0

iteration 1
...
Gradient for base_model.base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight: 0.0142822265625
Gradient for base_model.base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: 3.953125
Gradient for base_model.base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: 0.185546875

iteration n
...
Gradient for base_model.base_model.model.model.layers.31.self_attn.k_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.k_proj.lora_B.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.v_proj.lora_A.default.weight: 0.0
Gradient for base_model.base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight: 0.0
```

Isn't this a bit unusual? Is it the vanishing gradient problem?

For context:
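The base model is loaded in 4-bit and wrapped with a LoRA adapter on the attention projections, roughly like the sketch below (simplified; the exact quantization and LoRA settings here are placeholders, not my real config):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit quantized base model (placeholder config values)
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=bnb_config, device_map="auto"
)

# LoRA on the attention projections that appear in the gradient printout
lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base, lora_config)
```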

```python
# custom class
class BinaryClassification(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(0.05)
        #self.classifier = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        outputs = self.base_model(x)
        dropout_output = self.dropout(outputs.logits)
        relu_output = self.relu(dropout_output[:, -1, :])
        probs = self.sigmoid(relu_output)  # Apply sigmoid to logits to get probabilities
        #print('forward probs', probs)
        return probs

model_init = BinaryClassification(peft_model)
```

Optimizer

```python
criterion = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(model_init.parameters(), lr=0.001, eps=1e-08, weight_decay=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
```

Loss

```python
def calc_loss_batch(input_batch, target_batch):

    output = model_init(input_batch)

    # Reshape target to match the shape of probs
    target_batch = target_batch.unsqueeze(1)

    if output.shape != target_batch.shape:
        raise Exception("Shape mismatch between input logits and target label")

    # Logits of last output token
    loss = criterion(output, target_batch)

    return loss

```


1 Answer


Normally, an adapter for binary classification might look similar to what you've described, but with the classifier enabled and without the ReLU before the sigmoid (a sigmoid applied after a ReLU can only output values >= 0.5). You may also want to try using the input to the output embedding layer (the last hidden state) rather than its output (the vocab logits) if possible, and to remove the sigmoid and use BCEWithLogitsLoss instead, which can be more numerically stable.
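A minimal sketch of what I mean (this assumes the PEFT-wrapped model passes output_hidden_states through to the base model, and that hidden_size is the base model's hidden dimension, e.g. 4096 for Mistral-7B):

```python
import torch
import torch.nn as nn

class BinaryClassificationHead(nn.Module):
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(0.05)
        self.classifier = nn.Linear(hidden_size, 1)  # one raw logit per sequence

    def forward(self, x):
        # Use the last hidden state (input to the LM head) instead of the vocab logits
        outputs = self.base_model(x, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1][:, -1, :]
        # No ReLU and no Sigmoid: return the raw logit
        return self.classifier(self.dropout(last_hidden.float()))

# BCEWithLogitsLoss applies the sigmoid internally
criterion = torch.nn.BCEWithLogitsLoss()
```

With this, calc_loss_batch stays the same except that the model now returns a raw logit rather than a probability.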

For this sort of model, I think it's unlikely that the gradients would vanish like that if the adapter and training procedure were correct. As the adapter is currently written, whenever the base model's logits are less than 0 (which in some models could consistently be the case), the ReLU's gradient is 0, so 0s backpropagate no matter what the loss is. If this is the issue you're facing, you would see forward probabilities of exactly 0.5. Additionally, dropout_output[:, -1, :] takes the logits at the last token position, so you probably want to check that the last token is not degenerate in some way (for example, almost always being . or, here, a pad token). If neither of these is the case, perhaps there is some issue with your loss function, labels, or inputs; however, even if those were incorrect, I think they are unlikely to be causing 0 gradients here.
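Two quick checks along those lines, as a sketch (it assumes the tokenizer used to build input_batch is available as tokenizer):

```python
with torch.no_grad():
    # 1) If (nearly) all outputs are exactly 0.5, the ReLU is zeroing the logits
    probs = model_init(input_batch)
    print("fraction of outputs equal to 0.5:", (probs == 0.5).float().mean().item())

    # 2) Check what the last token position actually contains (e.g. a pad token)
    print("last input tokens:", tokenizer.convert_ids_to_tokens(input_batch[:, -1].tolist()))
```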

Greg S