
I recently stumbled upon the Llama3 HF docs: https://huggingface.co/docs/transformers/main/en/model_doc/llama3

The Llama3 models were trained using bfloat16, but the original inference uses float16. The checkpoints uploaded on the Hub use torch_dtype = 'float16', which will be used by the AutoModel API to cast the checkpoints from torch.float32 to torch.float16. Training the model in float16 is not recommended and is known to produce nan; as such, the model should be trained in bfloat16.
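
For context, this is the loading behaviour the docs are describing. A rough sketch on my side, assuming the gated meta-llama/Meta-Llama-3-8B repo id and a recent transformers version:

```python
import torch
from transformers import AutoModelForCausalLM

# With torch_dtype="auto" the weights are loaded in the dtype recorded in the
# checkpoint config (float16 for the Hub checkpoints mentioned above).
model_fp16 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype="auto"
)
print(model_fp16.dtype)  # torch.float16

# The dtype can also be overridden to load directly in bfloat16.
model_bf16 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
)
print(model_bf16.dtype)  # torch.bfloat16
```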

Does this mean that the weights are always stored in float16 during training, that on each forward pass we cast them to bfloat16, and that the gradients as well as the activations are stored in bfloat16?

If that is the case, what about LoRA? Should the compute dtype for fine-tuning also be bfloat16? And since the weights are stored in float16, should inference be performed in float16, as that would allow for more precise computations?

vanerk

1 Answer


Let's start with the difference between bfloat16 and float16: https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/

As the blog above explains, bfloat16 has more exponent bits than float16, which means it can represent much larger values but with less mantissa precision. As a result, it is far less likely to overflow during training than float16.
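
You can see this trade-off directly in PyTorch. A quick sketch, assuming torch is installed:

```python
import torch

# Compare the representable range (max) and relative precision (eps) of each dtype.
for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)

# torch.float16  -> max: 65504.0      eps: ~9.77e-04  (5 exponent bits, 10 mantissa bits)
# torch.bfloat16 -> max: ~3.39e+38    eps: ~7.81e-03  (8 exponent bits, 7 mantissa bits)
```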

Once you load the weights in a given dtype, they keep that dtype throughout training unless you explicitly convert them; they are not cast back and forth on every forward pass. If the checkpoint is stored in float16 and you want to train in bfloat16, you only pay a one-time conversion cost when you cast the weights.
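
A rough sketch of that one-time cast, again assuming the meta-llama/Meta-Llama-3-8B checkpoint and enough memory to hold it:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the Hub checkpoint in its stored dtype (float16).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16
)

# Convert every parameter and buffer once; after this, training runs entirely in bfloat16.
model = model.to(torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16
```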

If your GPU supports bfloat16, use it for fine-tuning; otherwise keep the dtype as it is.
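
Here is a sketch of how you might pick the dtype for LoRA fine-tuning based on hardware support. It uses the peft library; the LoraConfig values and checkpoint name are only illustrative, and a CUDA GPU is assumed:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Prefer bfloat16 on GPUs that support it (Ampere or newer), fall back to float16 otherwise.
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=compute_dtype
)

# Illustrative LoRA configuration; adapter weights are created in the base model's dtype.
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()
```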