[Flax] Torch fp16 model weights not upcast when loaded in Flax

In some scenarios, one may want to load a Flax model directly from pre-trained PyTorch model weights. In this process, the original dtype of the PyTorch weights is preserved when they are loaded into Flax. For models such as bart-large, which has its PyTorch weights stored in fp16 on the Hub, this can result in a Flax model with weights in an undesirable dtype. The following code snippet highlights this: it first loads a FlaxSpeechEncoderDecoderModel from entirely fp32 PyTorch weights, and then again from fp32 encoder weights and fp16 decoder weights:

from transformers import FlaxSpeechEncoderDecoderModel

# fp32 PyTorch weights
encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'
decoder_id = 'hf-internal-testing/tiny-random-bart'

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("-----------From fp32 PyTorch weights-----------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")

# same decoder as before, but with weights downcast to fp16
decoder_id = 'sanchit-gandhi/tiny-random-bart-fp16'

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("---------From fp32/fp16 PyTorch weights---------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")

Output:

-----------From fp32 PyTorch weights-----------
Encoder dtype: float32
Decoder dtype: float32

---------From fp32/fp16 PyTorch weights---------
Encoder dtype: float32
Decoder dtype: float16

Having a model stored in two different dtypes raises issues with training: Optax optimisers expect the model to maintain one uniform dtype. Furthermore, the default assumption is that all Flax model weights are in fp32.

This weight conversion is handled by the general conversion script: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py. Would it be wise to inform the user of the potentially erroneous model dtype in this scenario? If informed, they could then call the to_fp32 method from modeling_flax_utils to upcast the weights to fp32: https://github.com/huggingface/transformers/blob/a9604067225219e132abdff2793f78ead798453b/src/transformers/modeling_flax_utils.py#L231
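As a rough illustration of the check a user could run after being warned, the sketch below walks a nested parameter tree and reports every leaf whose dtype is not float32. The helper name `non_fp32_leaves` is hypothetical and not part of transformers, and the `SimpleNamespace` objects are stand-ins for real arrays so that the example runs without JAX installed. The tree layout is a simplified version of the one printed above.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def non_fp32_leaves(params, prefix=()):
    """Return (path, dtype) pairs for every leaf whose dtype is not float32.

    Hypothetical helper: works on any nested mapping whose leaves expose
    a `.dtype` attribute. Leaves without a dtype are ignored.
    """
    found = []
    for key, value in params.items():
        path = prefix + (str(key),)
        if isinstance(value, Mapping):
            # Recurse into sub-trees such as params['decoder'].
            found.extend(non_fp32_leaves(value, path))
        elif str(getattr(value, "dtype", "float32")) != "float32":
            found.append(("/".join(path), str(value.dtype)))
    return found


# Stand-in tree mimicking the mixed fp32/fp16 result shown above.
# SimpleNamespace replaces real arrays (an assumption for illustration only).
params = {
    "encoder": {"masked_spec_embed": SimpleNamespace(dtype="float32")},
    "decoder": {"embed_tokens": {"embedding": SimpleNamespace(dtype="float16")}},
}

mismatched = non_fp32_leaves(params)
print(mismatched)  # [('decoder/embed_tokens/embedding', 'float16')]
```

With a real model, the same check could be applied to model.params, and any mismatch resolved with the to_fp32 method mentioned above, for example `model.params = model.to_fp32(model.params)`.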

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

1 reaction
sanchit-gandhi commented, Apr 14, 2022

The user warning for the Flax .from_pretrained method was implemented in #16762. Because this is an extreme edge case, and following an extensive offline discussion, it was decided that the fp16 PyTorch weights for bart-large will remain as they are. The original checkpoint has been reconverted and uploaded in fp32 to another repo for those wishing to explicitly use full-precision weights: https://huggingface.co/patrickvonplaten/bart-large-fp32 Note that the fp16 weights should not be an issue for any PyTorch models: the PyTorch .from_pretrained method automatically upcasts model weights to fp32.

0 reactions
patil-suraj commented, Apr 13, 2022

Agree with @sanchit-gandhi here. I’m in favour of adding a warning and letting the user know that weights are not fp32.

