[Flax] Torch fp16 model weights not upcast when loaded in Flax
In some scenarios, one may want to load a Flax model directly from pre-trained PyTorch model weights. In this process, the original dtype of the PyTorch weights is maintained when they are loaded into Flax. For models such as bart-large, whose PyTorch weights are stored in fp16 on the Hub, this can result in a Flax model with weights in an undesirable dtype. This is highlighted by the following code snippet, which first loads a FlaxSpeechEncoderDecoderModel from entirely fp32 PyTorch weights, and then again from fp32 encoder weights and fp16 decoder weights:
from transformers import FlaxSpeechEncoderDecoderModel
# fp32 PyTorch weights
encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'
decoder_id = 'hf-internal-testing/tiny-random-bart'
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("-----------From fp32 PyTorch weights-----------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")
# same decoder as previously, but with weights downcast to fp16
decoder_id = 'sanchit-gandhi/tiny-random-bart-fp16'
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("---------From fp32/fp16 PyTorch weights---------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")
Output:
-----------From fp32 PyTorch weights-----------
Encoder dtype: float32
Decoder dtype: float32
---------From fp32/fp16 PyTorch weights---------
Encoder dtype: float32
Decoder dtype: float16
Having a model stored in two different dtypes raises issues with training: Optax optimisers expect the model to maintain one uniform dtype. Furthermore, the default assumption is that all Flax model weights are in fp32.
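One way to surface this is to collect the set of leaf dtypes across the whole parameter pytree. A minimal sketch using jax.tree_util, where model is the mixed-precision Flax model loaded above:
import jax
# Gather the dtype of every parameter leaf in the Flax params pytree
param_dtypes = {leaf.dtype for leaf in jax.tree_util.tree_leaves(model.params)}
print(param_dtypes)  # e.g. {dtype('float32'), dtype('float16')} for the fp32/fp16 case above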
This weight conversion is handled by the general conversion script: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py. Would it be wise to inform the user of the potentially erroneous model dtype in this scenario? If informed, they could then call the to_fp32 method from modeling_flax_utils to upcast the weights to fp32:
https://github.com/huggingface/transformers/blob/a9604067225219e132abdff2793f78ead798453b/src/transformers/modeling_flax_utils.py#L231
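For reference, upcasting with to_fp32 (the FlaxPreTrainedModel method linked above) would look roughly like this, continuing from the model loaded in the snippet above:
# Upcast every parameter leaf to fp32; to_fp32 returns a new params pytree
model.params = model.to_fp32(model.params)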
The user warning for the Flax .from_pretrained method was implemented in #16762. As an extreme edge case, and following an extensive offline discussion, it was decided that the fp16 PyTorch weights for bart-large will remain as is. The original checkpoint has been reconverted and uploaded in fp32 to a separate repo for those wishing to explicitly use full-precision weights: https://huggingface.co/patrickvonplaten/bart-large-fp32. Note that the fp16 weights should not be an issue for any PyTorch models: the PyTorch from_pretrained method automatically upcasts model weights to fp32.

Agree with @sanchit-gandhi here. I'm in favour of adding a warning and letting the user know that weights are not fp32.
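As a quick sanity check of that last point, a minimal sketch, assuming the facebook/bart-large checkpoint is still stored in fp16 on the Hub:
from transformers import BartForConditionalGeneration
# PyTorch from_pretrained instantiates the model in its default dtype (fp32)
# unless torch_dtype is passed, so the fp16 Hub weights are upcast on load
pt_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
print(next(pt_model.parameters()).dtype)  # expected: torch.float32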