Issue converting Flax model to PyTorch
See original GitHub issue

When using the following script to convert a trained Flax model to PyTorch, the converted model seems to perform extremely poorly.
from transformers import RobertaForMaskedLM

# Load the Flax weights (flax_model.msgpack) from the current directory
# and re-save them as PyTorch weights (pytorch_model.bin).
model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model.save_pretrained("./")
To compare the two checkpoints directly, the following script loads both the Flax model and a converted PyTorch model from the Hub and prints their logits for the same all-zeros input:

from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM
import numpy as np
import torch

model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish", from_flax=True)

# Batch of 2 sequences of length 128, all zeros.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)
Comparing the two gives the following output:
tensor([[[ 1.7789, -13.5291, -11.2138, ..., -5.2875, -9.3274, -4.7912],
[ 2.3076, -13.4161, -11.1511, ..., -5.3181, -9.0602, -4.6083],
[ 2.6451, -13.4425, -11.0671, ..., -5.2838, -8.8323, -4.2280],
...,
[ 1.9009, -13.6516, -11.2348, ..., -4.9726, -9.3278, -4.6060],
[ 2.0522, -13.5394, -11.2804, ..., -4.9960, -9.1956, -4.5691],
[ 2.2570, -13.5093, -11.2640, ..., -4.9986, -9.1292, -4.3310]],
[[ 1.7789, -13.5291, -11.2138, ..., -5.2875, -9.3274, -4.7912],
[ 2.3076, -13.4161, -11.1511, ..., -5.3181, -9.0602, -4.6083],
[ 2.6451, -13.4425, -11.0671, ..., -5.2838, -8.8323, -4.2280],
...,
[ 1.9009, -13.6516, -11.2348, ..., -4.9726, -9.3278, -4.6060],
[ 2.0522, -13.5394, -11.2804, ..., -4.9960, -9.1956, -4.5691],
[ 2.2570, -13.5093, -11.2640, ..., -4.9986, -9.1292, -4.3310]]],
grad_fn=<AddBackward0>)
[[[ 0.1418128 -14.170926 -11.12649 ... -7.542998 -10.79537
-9.382975 ]
[ 1.7505689 -13.178099 -10.356588 ... -6.794136 -10.567211
-8.6670065 ]
[ 2.0270724 -13.522658 -10.372475 ... -7.0110755 -10.396935
-8.419178 ]
...
[ 0.19080782 -14.390833 -11.399942 ... -7.469897 -10.715849
-9.234054 ]
[ 1.3052869 -13.332332 -10.702984 ... -6.9498534 -10.813769
-8.608736 ]
[ 1.6442876 -13.226774 -10.59941 ... -7.0290956 -10.693554
-8.457008 ]]
[[ 0.1418128 -14.170926 -11.12649 ... -7.542998 -10.79537
-9.382975 ]
[ 1.7505689 -13.178099 -10.356588 ... -6.794136 -10.567211
-8.6670065 ]
[ 2.0270724 -13.522658 -10.372475 ... -7.0110755 -10.396935
-8.419178 ]
...
[ 0.19080782 -14.390833 -11.399942 ... -7.469897 -10.715849
-9.234054 ]
[ 1.3052869 -13.332332 -10.702984 ... -6.9498534 -10.813769
-8.608736 ]
[ 1.6442876 -13.226774 -10.59941 ... -7.0290956 -10.693554
-8.457008 ]]]
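Eyeballing the printed logits already shows the mismatch; to quantify it, one can diff the two arrays directly. A minimal sketch, assuming logits_pt and logits_fx from the script above:

import numpy as np

logits_pt_np = logits_pt.detach().numpy()  # torch tensor -> numpy
logits_fx_np = np.asarray(logits_fx)       # jax array -> numpy
print(np.max(np.abs(logits_pt_np - logits_fx_np)))

For a faithful conversion the maximum absolute difference is usually tiny (a common rule of thumb is below ~1e-3); here the logits disagree by whole units, which points at the weights themselves rather than numerical noise.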
Issue Analytics
- Created: 2 years ago
- Comments: 11 (5 by maintainers)
Top GitHub Comments
Awesome. Just to clarify: once I’m done with training, this script should help me convert the model to PyTorch.
Hi @jppaolim !
In my case, I had loaded earlier weights of the model (from the first few epochs) instead of the fully-trained weights from the last training epoch. Loading the right model weights fixed it for me.
Another way to fix it might be training for longer.
Hope this helps! 😃
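A minimal sketch of that fix, assuming the fully-trained Flax weights live in a hypothetical local directory ./checkpoint-final:

from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM

# Load the last-epoch Flax checkpoint, not an early one;
# "./checkpoint-final" is a placeholder path.
flax_model = FlaxRobertaForMaskedLM.from_pretrained("./checkpoint-final")
flax_model.save_pretrained("./export")  # writes flax_model.msgpack + config

# Convert those exact weights to PyTorch and save pytorch_model.bin next to them.
pt_model = RobertaForMaskedLM.from_pretrained("./export", from_flax=True)
pt_model.save_pretrained("./export")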