Issue converting Flax model to Pytorch

When using the following script to convert a trained Flax model to PyTorch, the converted model seems to perform extremely poorly.

# Script 1: load the Flax checkpoint in the current directory and save it as PyTorch weights.
from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model.save_pretrained("./")

# Script 2: compare the logits of the Flax model and its PyTorch conversion.
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM
import numpy as np
import torch

model_fx = FlaxRobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish")
model_pt = RobertaForMaskedLM.from_pretrained("birgermoell/roberta-swedish", from_flax=True)

# Dummy batch: 2 sequences of 128 tokens, all with token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)

Comparing the two gives the following output (PyTorch logits first, then Flax logits).

tensor([[[  1.7789, -13.5291, -11.2138,  ...,  -5.2875,  -9.3274,  -4.7912],
         [  2.3076, -13.4161, -11.1511,  ...,  -5.3181,  -9.0602,  -4.6083],
         [  2.6451, -13.4425, -11.0671,  ...,  -5.2838,  -8.8323,  -4.2280],
         ...,
         [  1.9009, -13.6516, -11.2348,  ...,  -4.9726,  -9.3278,  -4.6060],
         [  2.0522, -13.5394, -11.2804,  ...,  -4.9960,  -9.1956,  -4.5691],
         [  2.2570, -13.5093, -11.2640,  ...,  -4.9986,  -9.1292,  -4.3310]],

        [[  1.7789, -13.5291, -11.2138,  ...,  -5.2875,  -9.3274,  -4.7912],
         [  2.3076, -13.4161, -11.1511,  ...,  -5.3181,  -9.0602,  -4.6083],
         [  2.6451, -13.4425, -11.0671,  ...,  -5.2838,  -8.8323,  -4.2280],
         ...,
         [  1.9009, -13.6516, -11.2348,  ...,  -4.9726,  -9.3278,  -4.6060],
         [  2.0522, -13.5394, -11.2804,  ...,  -4.9960,  -9.1956,  -4.5691],
         [  2.2570, -13.5093, -11.2640,  ...,  -4.9986,  -9.1292,  -4.3310]]],
       grad_fn=<AddBackward0>)
[[[  0.1418128  -14.170926   -11.12649    ...  -7.542998   -10.79537
    -9.382975  ]
  [  1.7505689  -13.178099   -10.356588   ...  -6.794136   -10.567211
    -8.6670065 ]
  [  2.0270724  -13.522658   -10.372475   ...  -7.0110755  -10.396935
    -8.419178  ]
  ...
  [  0.19080782 -14.390833   -11.399942   ...  -7.469897   -10.715849
    -9.234054  ]
  [  1.3052869  -13.332332   -10.702984   ...  -6.9498534  -10.813769
    -8.608736  ]
  [  1.6442876  -13.226774   -10.59941    ...  -7.0290956  -10.693554
    -8.457008  ]]

 [[  0.1418128  -14.170926   -11.12649    ...  -7.542998   -10.79537
    -9.382975  ]
  [  1.7505689  -13.178099   -10.356588   ...  -6.794136   -10.567211
    -8.6670065 ]
  [  2.0270724  -13.522658   -10.372475   ...  -7.0110755  -10.396935
    -8.419178  ]
  ...
  [  0.19080782 -14.390833   -11.399942   ...  -7.469897   -10.715849
    -9.234054  ]
  [  1.3052869  -13.332332   -10.702984   ...  -6.9498534  -10.813769
    -8.608736  ]
  [  1.6442876  -13.226774   -10.59941    ...  -7.0290956  -10.693554
    -8.457008  ]]]
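
The gap between the two printouts is easier to judge as numbers than by eye. The following is a minimal sketch, not part of the original issue: the helper name `compare_logits` and the default tolerance `atol=1e-3` are assumptions chosen for illustration, and the sample values are just the first three logits of the first position from each printout above.

```python
import numpy as np


def compare_logits(logits_a, logits_b, atol=1e-3):
    """Summarize how far apart two logit arrays of the same shape are.

    Hypothetical helper; atol=1e-3 is an assumed tolerance, not a documented one.
    """
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    abs_diff = np.abs(a - b)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        "allclose": bool(np.allclose(a, b, atol=atol)),
        # Fraction of positions where both arrays rank the same token highest.
        "argmax_agreement": float((a.argmax(-1) == b.argmax(-1)).mean()),
    }


# Illustrative values only: first three logits of the first position
# from the PyTorch and Flax printouts.
pt_sample = np.array([[[1.7789, -13.5291, -11.2138]]])
fx_sample = np.array([[[0.1418128, -14.170926, -11.12649]]])
report = compare_logits(pt_sample, fx_sample)
print(report)
```

With the real models from the script, the arguments would be `logits_pt.detach().numpy()` (the PyTorch tensor carries a `grad_fn`, so it needs `detach()` first) and `np.asarray(logits_fx)`.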

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 11 (5 by maintainers)

Top GitHub Comments

BirgerMoell commented, Jul 7, 2021 (1 reaction)

Awesome. Just to clarify: once I’m done with training, this script should help me convert the model to PyTorch.

from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("...", from_flax=True)
model.save_pretrained("./")

w11wo commented, May 9, 2022 (0 reactions)

Hi @jppaolim!

In my case, I loaded the earlier weights of the model (from the first few epochs), instead of the fully-trained model weights from the last training epoch. Loading the right model weights fixed it for me.

Another way to fix it might be training for longer.

Hope this helps! 😃
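
The fix described in this comment came down to loading the final checkpoint rather than an early one. A minimal sketch of picking the highest-step checkpoint directory follows. It assumes (hypothetically) that the training run saves sub-directories named `checkpoint-<step>`; the function name `latest_checkpoint` and that directory layout are illustrative and not taken from the issue.

```python
import re
from pathlib import Path


def latest_checkpoint(run_dir, pattern=r"checkpoint-(\d+)$"):
    """Return the sub-directory of run_dir with the highest step number, or None.

    Hypothetical helper: the "checkpoint-<step>" naming scheme is an assumption.
    """
    best_step, best_path = -1, None
    for child in Path(run_dir).iterdir():
        match = re.search(pattern, child.name)
        # Compare steps as integers so that 10000 ranks above 2000.
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path
```

Comparing step numbers as integers matters here: a plain string sort would place `checkpoint-2000` after `checkpoint-10000`. The returned path could then be passed to `RobertaForMaskedLM.from_pretrained(path, from_flax=True)` as in the conversion script, provided that directory actually holds the Flax weights.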
