
Sharded T5X checkpoints can't be converted to PyTorch?


System Info

  • transformers version: 4.21.1
  • Platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.17
  • Python version: 3.8.10
  • Huggingface_hub version: 0.2.1
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@patrickvonplaten

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

  1. Train using t5x to get a checkpoint that’s bigger than 10GB.
  2. Use the official conversion script, e.g.,
python transformers/models/t5/convert_t5x_checkpoint_to_flax.py \
    --t5x_checkpoint_path checkpoint_100000/ \
    --config_name google/t5-v1_1-xxl \
    --flax_dump_folder_path checkpoint_converted

This yields, e.g.,

$ ls checkpoint_converted
config.json
flax_model-00001-of-00005.msgpack
flax_model-00002-of-00005.msgpack
flax_model-00003-of-00005.msgpack
flax_model-00004-of-00005.msgpack
flax_model-00005-of-00005.msgpack
flax_model.msgpack.index.json
  3. You can load the converted (sharded Flax) checkpoint like this:
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("checkpoint_converted/")

but you can’t load like this:

model = T5ForConditionalGeneration.from_pretrained("checkpoint_converted/", from_flax=True)

because of an error:

OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory

Expected behavior

I think it would be nice to be able to load sharded Flax checkpoints through the more generic class: this would be useful, e.g., for converting big Flax checkpoints to PyTorch. The machinery for loading sharded Flax weights appears to be mostly implemented, but it isn't yet connected to the T5ForConditionalGeneration.from_pretrained method. I can work on a PR if all this checks out, but wanted to see if there was something I was missing.
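In the meantime, a manual conversion seems possible that bypasses from_flax=True entirely. This is just a minimal sketch, assuming the full model fits in host RAM; the output folder name (checkpoint_pytorch/) is a placeholder:

from transformers import FlaxT5ForConditionalGeneration, T5Config, T5ForConditionalGeneration
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

# The Flax class already understands flax_model.msgpack.index.json,
# so it can load the sharded checkpoint produced by the conversion script.
config = T5Config.from_pretrained("checkpoint_converted/")
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("checkpoint_converted/")

# Copy the Flax parameters into a freshly initialized PyTorch model and save it.
pt_model = T5ForConditionalGeneration(config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
pt_model.save_pretrained("checkpoint_pytorch/")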

Relevant parts of the code:

checking for the flax sharded file type:

https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/modeling_utils.py#L1997-L2029

https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/modeling_flax_utils.py#L659-L685

support for sharded pytorch --> flax (but not vice-versa):

https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py

function that loads sharded flax checkpoints:

https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_utils.py#L424-L468

function that might need to be modified to detect a sharded checkpoint, and then call the above:

https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L239
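For illustration only (this is not the actual transformers code), the missing detection step is roughly: if flax_model.msgpack is absent, fall back to the sharded index file and derive the shard list from its weight_map:

import json
import os

FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"

def resolve_flax_weight_files(checkpoint_dir):
    """Return the Flax weight file(s) to load from checkpoint_dir."""
    single = os.path.join(checkpoint_dir, FLAX_WEIGHTS_NAME)
    if os.path.isfile(single):
        return [single]
    index = os.path.join(checkpoint_dir, FLAX_WEIGHTS_INDEX_NAME)
    if os.path.isfile(index):
        with open(index) as f:
            shard_names = sorted(set(json.load(f)["weight_map"].values()))
        return [os.path.join(checkpoint_dir, name) for name in shard_names]
    raise OSError(f"No Flax weights found in {checkpoint_dir}")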


Top GitHub Comments

1 reaction
jmhessel commented, Nov 18, 2022

Update: Increasing the shard size worked for me! But yes, I needed to grab a large-RAM instance to do the conversion. Not a bad stopgap. Given that this solved my issue, I might not be able to look at this in the next few weeks, but I'll leave the issue open for now in case I come back to it.
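For anyone else hitting this, here is a sketch of the "increase the shard size" stopgap, assuming enough host RAM and a transformers version whose Flax save_pretrained accepts max_shard_size (the folder names below are placeholders):

from transformers import FlaxT5ForConditionalGeneration, T5ForConditionalGeneration

# Re-save the sharded Flax checkpoint as a single flax_model.msgpack...
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("checkpoint_converted/")
flax_model.save_pretrained("checkpoint_unsharded/", max_shard_size="100GB")

# ...which from_pretrained(..., from_flax=True) can then find and convert.
pt_model = T5ForConditionalGeneration.from_pretrained("checkpoint_unsharded/", from_flax=True)
pt_model.save_pretrained("checkpoint_pytorch/")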

1 reaction
patrickvonplaten commented, Nov 17, 2022

Actually, I think this is relatively important: the main way to pretrain T5 is via Google's T5X repo, so having a good Flax -> PyTorch conversion on our side matters.

It’s likely that more T5X checkpoints will come out and it’d be nice to be able to directly convert them to PyTorch.

Maybe gently pinging the author of the conversion script, @stefan-it (in case you have any ideas), and @younesbelkada and @ArthurZucker in case you can find some time to help @jmhessel 😃
