I am fine-tuning BLIP-2 on the RSICD dataset using LoRA, running on Colab with an A100. Strangely, the learning rate I set in the code below seems to have no effect: I can set it to 10^55 or to 10^(-55), and the loss still jumps around by roughly the same amounts.
To show what I'm talking about, I created a public Colab notebook here. In this notebook I create an ExponentialLR scheduler with gamma=0.1, stepped after every individual optimizer step, i.e. the learning rate drops by an order of magnitude on each step, just to demonstrate my point. When I print the loss and the learning rate, the learning rate indeed decreases by a factor of 10 each step, and yet the loss keeps jumping around by the same amounts, which suggests that the learning rate being printed is not actually taking effect.
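The scheduler part of the notebook looks roughly like this (a sketch, not copied verbatim from the notebook):

# Rough sketch of the notebook's scheduler demo (not copied verbatim):
# the learning rate is multiplied by gamma=0.1 after every optimizer step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

# inside the training loop, after loss.backward() and optimizer.step():
print("learning rate before:", optimizer.param_groups[0]["lr"])
scheduler.step()
print("learning rate after:", optimizer.param_groups[0]["lr"])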
Does anyone know what might be causing this? Why is the loss still jumping around so much even with a learning rate that is essentially 0?
My code is also copied below.
import torch
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Blip2ForConditionalGeneration, AddedToken
from peft import LoraConfig, get_peft_model
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
processor.num_query_tokens = model.config.num_query_tokens
image_token = AddedToken("<image>", normalized=False, special=True)
processor.tokenizer.add_tokens([image_token], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
model.config.image_token_index = len(processor.tokenizer) - 1
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 10
learning_rate = 2e-5
batch_size = 16
lora_alpha = 32
lora_dropout = 0.05
lora_dim = 8
targetData = torch.load("/content/drive/Shareddrives/TEMFOM/target_data1.pt")
config = LoraConfig(
    r=lora_dim,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=["q_proj", "k_proj"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
def fine_tune(model, train_dataloader, optimizer, n_epochs, model_name="fine-tuned"):
    for epoch in range(n_epochs):
        for idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            # pixel values are cast to float16 here (the model itself was loaded in bfloat16)
            pixel_values = batch.pop("pixel_values").to(device, torch.float16)
            # captions are used both as inputs and as labels
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                labels=input_ids,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Loss:", loss.item())
    # `direct` is a save-path prefix defined elsewhere in the notebook
    model.save_pretrained(direct + model_name)
    return model
# ImageCaptioningDataset and collate_fn are defined elsewhere in the notebook
train_dataset = ImageCaptioningDataset(targetData, processor)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collate_fn,
)
fine_tuned_model = fine_tune(model, train_dataloader, optimizer, n_epochs, model_name="fine-tuned")
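One check that could be added inside the inner loop (sketched below; it is not part of the notebook) is to snapshot one trainable LoRA weight right before optimizer.step() and print how much it changes, to see whether the printed learning rate is actually being applied:

# Hypothetical check (not in the notebook): take one trainable LoRA parameter
# and measure how much a single optimizer.step() actually changes it.
name, param = next((n, p) for n, p in model.named_parameters() if p.requires_grad)
before = param.detach().clone()
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(name, "max abs update:", (param.detach() - before).abs().max().item())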
An example of the output (which you can also see in the colab notebook) is below. As you can see, the loss still makes very large jumps even as the learning rate decreases below 10^-50.
Epoch: 0
Loss: 5.318110466003418
learning rate before: 2e-10
learning rate after: 2.0000000000000002e-11
Loss: 4.520220756530762
learning rate before: 2.0000000000000002e-11
learning rate after: 2.0000000000000004e-12
Loss: 4.288538455963135
learning rate before: 2.0000000000000004e-12
learning rate after: 2.0000000000000006e-13
Loss: 4.518784999847412
learning rate before: 2.0000000000000006e-13
learning rate after: 2.0000000000000006e-14
Loss: 4.479536533355713
learning rate before: 2.0000000000000006e-14
learning rate after: 2.0000000000000005e-15
Loss: 4.55037260055542
learning rate before: 2.0000000000000005e-15
learning rate after: 2.0000000000000007e-16
Loss: 4.3671770095825195
learning rate before: 2.0000000000000007e-16
learning rate after: 2.0000000000000008e-17
Loss: 4.405301570892334
learning rate before: 2.0000000000000008e-17
learning rate after: 2.000000000000001e-18
Loss: 4.015200138092041
learning rate before: 2.000000000000001e-18
learning rate after: 2.000000000000001e-19
Loss: 4.679333209991455
learning rate before: 2.000000000000001e-19
learning rate after: 2.000000000000001e-20
Loss: 3.9693050384521484
learning rate before: 2.000000000000001e-20
learning rate after: 2.0000000000000013e-21
Loss: 4.0665154457092285
learning rate before: 2.0000000000000013e-21
learning rate after: 2.0000000000000015e-22
Loss: 4.336864471435547
learning rate before: 2.0000000000000015e-22
learning rate after: 2.0000000000000017e-23
Loss: 4.552571773529053
learning rate before: 2.0000000000000017e-23
learning rate after: 2.0000000000000017e-24
Loss: 4.480626106262207
learning rate before: 2.0000000000000017e-24
learning rate after: 2.0000000000000017e-25
Loss: 4.519588947296143
learning rate before: 2.0000000000000017e-25
learning rate after: 2.0000000000000018e-26
Loss: 4.422896862030029
learning rate before: 2.0000000000000018e-26
learning rate after: 2.000000000000002e-27
Loss: 3.851675033569336
learning rate before: 2.000000000000002e-27
learning rate after: 2.000000000000002e-28
Loss: 3.561893939971924
learning rate before: 2.000000000000002e-28
learning rate after: 2.000000000000002e-29
Loss: 4.885611534118652
learning rate before: 2.000000000000002e-29
learning rate after: 2.0000000000000023e-30
Loss: 4.571497440338135
learning rate before: 2.0000000000000023e-30
learning rate after: 2.0000000000000024e-31
Loss: 4.3077521324157715
learning rate before: 2.0000000000000024e-31
learning rate after: 2.0000000000000026e-32
Loss: 3.834765911102295
learning rate before: 2.0000000000000026e-32
learning rate after: 2.000000000000003e-33
Loss: 4.235876560211182
learning rate before: 2.000000000000003e-33
learning rate after: 2.000000000000003e-34
Loss: 4.281957626342773
learning rate before: 2.000000000000003e-34
learning rate after: 2.000000000000003e-35
Loss: 4.0060648918151855
learning rate before: 2.000000000000003e-35
learning rate after: 2.0000000000000032e-36
Loss: 4.274528503417969
learning rate before: 2.0000000000000032e-36
learning rate after: 2.0000000000000035e-37
Loss: 4.298925876617432
learning rate before: 2.0000000000000035e-37
learning rate after: 2.0000000000000036e-38
Loss: 4.506286144256592
learning rate before: 2.0000000000000036e-38
learning rate after: 2.0000000000000038e-39
Loss: 4.11824369430542
learning rate before: 2.0000000000000038e-39
learning rate after: 2.000000000000004e-40
Loss: 4.141360759735107
learning rate before: 2.000000000000004e-40
learning rate after: 2.000000000000004e-41
Loss: 4.402781963348389
learning rate before: 2.000000000000004e-41
learning rate after: 2.0000000000000042e-42
Loss: 4.450037002563477
learning rate before: 2.0000000000000042e-42
learning rate after: 2.000000000000004e-43
Loss: 4.273048400878906
learning rate before: 2.000000000000004e-43
learning rate after: 2.0000000000000041e-44
Loss: 4.774006366729736
learning rate before: 2.0000000000000041e-44
learning rate after: 2.0000000000000043e-45
Loss: 3.908968687057495
learning rate before: 2.0000000000000043e-45
learning rate after: 2.0000000000000043e-46
Loss: 3.9161949157714844
learning rate before: 2.0000000000000043e-46
learning rate after: 2.0000000000000043e-47
Loss: 4.0039896965026855
learning rate before: 2.0000000000000043e-47
learning rate after: 2.0000000000000045e-48
Loss: 3.8200762271881104
learning rate before: 2.0000000000000045e-48
learning rate after: 2.0000000000000044e-49
Loss: 4.692992687225342
learning rate before: 2.0000000000000044e-49
learning rate after: 2.0000000000000045e-50
Loss: 4.407190799713135
learning rate before: 2.0000000000000045e-50
learning rate after: 2.0000000000000048e-51
Loss: 4.065435886383057
learning rate before: 2.0000000000000048e-51
learning rate after: 2.000000000000005e-52
Loss: 3.7482211589813232
learning rate before: 2.000000000000005e-52
learning rate after: 2.000000000000005e-53
Loss: 4.571844100952148
learning rate before: 2.000000000000005e-53
learning rate after: 2.000000000000005e-54
Loss: 4.8389458656311035
learning rate before: 2.000000000000005e-54
learning rate after: 2.000000000000005e-55
Loss: 3.70975923538208
learning rate before: 2.000000000000005e-55
learning rate after: 2.000000000000005e-56
Loss: 3.7369227409362793