Sharded T5X checkpoints can't be converted to PyTorch?
System Info
- transformers version: 4.21.1
- Platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.17
- Python version: 3.8.10
- Huggingface_hub version: 0.2.1
- PyTorch version (GPU?): 1.12.1+cu113 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
- Train using T5X to get a checkpoint that's bigger than 10GB.
- Use the official conversion script, e.g.:

```bash
python transformers/models/t5/convert_t5x_checkpoint_to_flax.py \
    --t5x_checkpoint_path checkpoint_100000/ \
    --config_name google/t5-v1_1-xxl \
    --flax_dump_folder_path checkpoint_converted
```
This yields, e.g.:

```
$ ls checkpoint_converted
config.json
flax_model-00001-of-00005.msgpack
flax_model-00002-of-00005.msgpack
flax_model-00003-of-00005.msgpack
flax_model-00004-of-00005.msgpack
flax_model-00005-of-00005.msgpack
flax_model.msgpack.index.json
```
- You can load the converted checkpoint with the Flax class:

```python
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("checkpoint_converted/")
```
but you can't load it with the PyTorch class:

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("checkpoint_converted/", from_flax=True)
```

because of this error:

```
OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory
```
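Until the generic loader handles sharded Flax checkpoints, one possible workaround is to load the checkpoint with the Flax class and copy the weights into a PyTorch model by hand. This is a rough, untested sketch: it assumes both models fit in host RAM and that the helper `load_flax_weights_in_pytorch_model` from `transformers.modeling_flax_pytorch_utils` is available in your version.

```python
from transformers import (
    T5Config,
    T5ForConditionalGeneration,
    FlaxT5ForConditionalGeneration,
)
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

# Load the sharded Flax checkpoint (this already works, as shown above).
config = T5Config.from_pretrained("checkpoint_converted/")
flax_model = FlaxT5ForConditionalGeneration.from_pretrained("checkpoint_converted/")

# Build a randomly initialized PyTorch model and copy the Flax weights into it.
pt_model = T5ForConditionalGeneration(config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)

# Save a regular PyTorch checkpoint ("checkpoint_converted_pt/" is just an example
# output path); it will be sharded automatically if it exceeds the default 10GB.
pt_model.save_pretrained("checkpoint_converted_pt/")
```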
Expected behavior
I think it would be nice to be able to load sharded Flax checkpoints through the generic `from_pretrained(..., from_flax=True)` path as well: this would be useful, e.g., for converting big Flax checkpoints to PyTorch. The machinery for loading sharded Flax checkpoints appears to be mostly implemented, but it isn't yet hooked into the `T5ForConditionalGeneration.from_pretrained` method. I can work on a PR if all this checks out, but I wanted to see if there was something I was missing.
Relevant parts of the code:
- the check for the Flax sharded file type
- support for sharded PyTorch -> Flax conversion (but not vice versa)
- the function that loads sharded Flax checkpoints
- the function that would likely need to be modified to detect a sharded checkpoint and then call the one above
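For the last item, here is a rough illustration (not actual transformers code) of the kind of detection that would be needed when `from_flax=True`: besides `flax_model.msgpack`, also look for the sharded index file that `save_pretrained` writes (assumed here to be `flax_model.msgpack.index.json`, matching the listing above).

```python
import os

FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"

def find_flax_checkpoint(folder: str):
    """Return (path, is_sharded) for the Flax checkpoint in `folder`, or (None, False)."""
    single_file = os.path.join(folder, FLAX_WEIGHTS_NAME)
    index_file = os.path.join(folder, FLAX_WEIGHTS_INDEX_NAME)
    if os.path.isfile(single_file):
        return single_file, False
    if os.path.isfile(index_file):
        # Sharded checkpoint: the caller would resolve the shard files listed
        # in the index and hand them to the sharded Flax loading function.
        return index_file, True
    return None, False
```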
Top GitHub Comments
Update: increasing the shard size worked for me! But yes, I needed to grab a large-RAM instance to do the conversion. Not a bad stopgap. Given that this solved my issue, I might not be able to look at this in the next few weeks, but I'll leave the issue open for now in case I come back to it.
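For reference, "increasing the shard size" presumably means changing the final save call in the conversion script; `max_shard_size` is a standard `save_pretrained` argument (its default is "10GB"), though the exact variable names in the script may differ.

```python
# At the end of convert_t5x_checkpoint_to_flax.py (variable names are illustrative):
# write the converted model as a single msgpack file instead of 10GB shards, so
# that from_pretrained(..., from_flax=True) can read it. Needs enough RAM/disk
# for the whole model; "50GB" is just an example value larger than the checkpoint.
flax_model.save_pretrained(flax_dump_folder_path, max_shard_size="50GB")
```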
Actually, I think this is relatively important: the main way to pretrain T5 is via Google's T5X repo, so having a good Flax -> PyTorch conversion on our side matters.
It’s likely that more T5X checkpoints will come out and it’d be nice to be able to directly convert them to PyTorch.
Gently pinging the author of the conversion script, @stefan-it (in case you have any ideas), as well as @younesbelkada and @ArthurZucker, in case you can find some time to help @jmhessel 😃