[trainer] new in pytorch: `torch.optim._multi_tensor` faster optimizers
Back in September PyTorch introduced `torch.optim._multi_tensor` (https://github.com/pytorch/pytorch/pull/43507), which should be much more efficient for situations with lots of small feature tensors (as in transformers) and thus should show an appreciable speedup in training. If someone is interested in the progress of this project, here is the stack to track: https://github.com/pytorch/pytorch/pull/48223
This feature is currently at an alpha stage, so users can try it by simply replacing `torch.optim` with `torch.optim._multi_tensor` in the HF Trainer or their own trainer. Eventually it will replace `torch.optim`, so otherwise there is nothing we need to do.
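As a concrete illustration, here is a minimal sketch of trying the multi-tensor AdamW with the Trainer by building the optimizer explicitly and passing it in; `model`, `training_args`, and `train_dataset` are assumed to already exist and are not from this issue:

```python
# Minimal sketch: build the experimental multi-tensor AdamW explicitly and hand
# it to the Trainer via the `optimizers` argument. `model`, `training_args`, and
# `train_dataset` are assumed to be defined elsewhere.
import torch
from transformers import Trainer

optimizer = torch.optim._multi_tensor.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # Passing (optimizer, None) lets the Trainer create its usual LR scheduler
    # around the supplied optimizer.
    optimizers=(optimizer, None),
)
trainer.train()
```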
@blefaudeux, who alerted me to this improvement, suggested it should give good speedups for DDP/Sharded DDP training.
If resources allow, it'd be good to run some benchmarks. Please feel free to beat me to it.
Thanks to @blefaudeux for the heads up, and @izdeby for working on this enhancement and clarifying where things are at.
Heads up to @sgugger and @patrickvonplaten - nothing else needs to be done.
Yes, I was just about to revisit it.
edit: I thought you might have wanted to work on that, but the pytorch team asks us to run a profiler on it and all, so I will probably look into testing it out again.
— original comment —
Do you want to take the lead on this experiment, @jaketae?
The new `--optim` HF Trainer flag just got merged, so you can quickly implement `--optim adamw_torch_multi_tensor` in the same way as `--optim adamw`.
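Until such an option lands, a hypothetical sketch of the wiring could look like the following; the subclass name and the simplified weight-decay grouping are illustrative only, not the Trainer's actual implementation:

```python
# Hypothetical sketch: override `create_optimizer` to build the experimental
# multi-tensor AdamW with the usual Trainer hyperparameters. The subclass and
# the simplified weight-decay grouping below are illustrative, not merged code.
import torch
from transformers import Trainer


class MultiTensorTrainer(Trainer):
    def create_optimizer(self):
        if self.optimizer is None:
            # Simplified split: no weight decay for biases and LayerNorm weights.
            decay, no_decay = [], []
            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                if "bias" in name or "LayerNorm" in name:
                    no_decay.append(param)
                else:
                    decay.append(param)
            grouped_params = [
                {"params": decay, "weight_decay": self.args.weight_decay},
                {"params": no_decay, "weight_decay": 0.0},
            ]
            # Swap in the experimental multi-tensor implementation of AdamW.
            self.optimizer = torch.optim._multi_tensor.AdamW(
                grouped_params,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        return self.optimizer
```

Using `MultiTensorTrainer` in place of `Trainer` would then exercise the multi-tensor path without touching the `--optim` plumbing.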
You can use this tool for benchmarking if it helps: https://github.com/huggingface/transformers/pull/14934. I think it's pretty stable now, and I will propose it as a PR.
You must have a really strange bottleneck in that test, if neither the latest fairscale nor these optimizers change anything? These optimizers are measurably faster in isolation, and sure enough we see a difference in the fairscale CI, even on a dummy job / small model (see, for instance, the last two jobs).
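For context, here is a rough isolation microbenchmark sketch along the lines of what is being described; the tensor counts, sizes, and step counts below are arbitrary choices and not taken from the fairscale CI:

```python
# Rough, illustrative microbenchmark (arbitrary sizes, wall-clock timing): step
# the stock AdamW and the multi-tensor AdamW over many small parameter tensors,
# which is roughly the shape of a transformer's parameter list.
import time
import torch


def bench(optim_cls, n_tensors=500, numel=1024, steps=100, device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    params = [torch.randn(numel, device=device, requires_grad=True) for _ in range(n_tensors)]
    opt = optim_cls(params, lr=1e-3)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        for p in params:
            p.grad = torch.randn_like(p)  # fake gradients so opt.step() has work to do
        opt.step()
        opt.zero_grad()
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


for name, cls in [
    ("torch.optim.AdamW", torch.optim.AdamW),
    ("torch.optim._multi_tensor.AdamW", torch.optim._multi_tensor.AdamW),
]:
    print(f"{name}: {bench(cls):.3f}s")
```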