RFC: split checkpoint load/save for huge models
🚀 Feature request
While discussing with the PyTorch devs the ability to load/save a state_dict at a finer granularity, without materializing the whole state_dict in memory, we ran into an additional issue: the model file itself is simply too large. I'd like to propose that transformers support multi-part checkpoints.
Reasons for the need:
- the hub limitation: CloudFront does not support >20GB files, so downloads via S3 can't be fast for those large files
- the current PyTorch issue of loading the whole state_dict into memory, which requires 2x the model size in memory; checkpoint conversion is quite demanding on memory as well, for the same reason
- in general it's a potential issue for users with an imperfect up/down internet connection; uploading/downloading 25GB files is still not easy for everyone
Possible solutions:
- as mentioned here, SplitCheckpoint already implements a possible solution, which saves each state_dict key separately
- as solution 1, but saving groups of keys, e.g. each layer's keys together in one pickled state_dict per layer. I looked at some large models and they have a huge number of keys; even t5-small has ~150. But this approach would be more complicated, since we would now need to define the container block, and it will differ from model to model. Maybe group by sub-module? So the first solution is probably much simpler.
The only addition I'd propose is to name the files with the full key name, rather than obscure names like m18.pt as implemented by SplitCheckpoint, which require an extra file to do look-ups.
So my proposal is:
```
config.json
merges.txt
README.md
tokenizer.json
vocab.json
pytorch_model/map.pt
pytorch_model/shared.weight.pt
pytorch_model/encoder.embed_tokens.weight.pt
[...]
pytorch_model/encoder.block.3.layer.0.SelfAttention.v.weight.pt
[...]
pytorch_model/decoder.block.5.layer.1.EncDecAttention.q.weight.pt
[...]
pytorch_model/lm_head.weight.pt
```
These are all raw files, not members of any archive, and map.pt just holds the list of keys in their original order, for when the OrderedDict ordering matters.
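To make the proposed layout concrete, here is a minimal sketch of what per-key save/load could look like; the names `save_split_checkpoint` and `load_split_checkpoint` are placeholders, not an existing transformers API:

```python
import os
from collections import OrderedDict

import torch


def save_split_checkpoint(state_dict, save_dir):
    """Save each state_dict entry as its own file, named after the full key."""
    os.makedirs(save_dir, exist_ok=True)
    # map.pt only records the key order, so the OrderedDict can be rebuilt on load
    torch.save(list(state_dict.keys()), os.path.join(save_dir, "map.pt"))
    for key, tensor in state_dict.items():
        torch.save(tensor, os.path.join(save_dir, f"{key}.pt"))


def load_split_checkpoint(save_dir):
    """Rebuild the state_dict one tensor at a time; the tensors could also be
    fed into the model incrementally instead of building the full dict."""
    keys = torch.load(os.path.join(save_dir, "map.pt"))
    state_dict = OrderedDict()
    for key in keys:
        state_dict[key] = torch.load(os.path.join(save_dir, f"{key}.pt"))
    return state_dict
```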
The cost of the first solution is somewhat slower save/load. I haven't benchmarked it, but IO will be the bottleneck here, and the ZIP archive currently gets unravelled one tensor at a time anyway, so the difference is likely to be negligible.
Other solutions are welcome.
Other examples of split checkpoints:
- DeepSpeed's pipeline parallelism (PP) saves each layer as a separate checkpoint, which makes it possible to quickly change the PP degree at run time.
Threshold:
- we need to define the threshold at which we automatically switch to this multi-part format, unless the user overrides the default. We can probably use the size of the model as the measure; I think it should be 3B parameters or even less (a rough sketch of such a check follows this list). At 3B parameters the resulting file sizes are:
  - 6GB in fp16/bf16
  - 12GB in fp32
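A rough sketch of such a check, assuming we go by parameter count (the helper name and the 3B default are illustrative only):

```python
def should_use_multipart(model, param_threshold=3_000_000_000):
    """Decide whether a model is large enough to warrant the multi-part format."""
    num_params = sum(p.numel() for p in model.parameters())
    # e.g. 3B params -> ~6GB in fp16/bf16 (2 bytes/param) or ~12GB in fp32 (4 bytes/param)
    return num_params >= param_threshold
```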
@patrickvonplaten, @patil-suraj, @LysandreJik, @sgugger, @julien-c
Plus, we're trying to stay away from pickle these days 😃
Just posting a small script I used to shard any model: https://gist.github.com/younesbelkada/382016361580b939a87edcddc94c6593. People may want to use it in the future to push sharded models!
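For illustration only (this is not the gist above), a minimal size-based sharding sketch could look like the following; the function name and the 5GB default are assumptions:

```python
import torch


def shard_state_dict(state_dict, max_shard_size_bytes=5 * 1024**3):
    """Greedily split a state_dict into shards whose tensor bytes stay under the limit."""
    shards, current, current_size = [], {}, 0
    for key, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if current and current_size + tensor_size > max_shard_size_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[key] = tensor
        current_size += tensor_size
    if current:
        shards.append(current)
    return shards


# Example usage: save each shard as its own file
# shards = shard_state_dict(model.state_dict())
# for i, shard in enumerate(shards, start=1):
#     torch.save(shard, f"pytorch_model-{i:05d}-of-{len(shards):05d}.bin")
```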