[HF Trainer] [new optimizer] add `AnyPrecisionAdamW` (bf16)
Feature request
PyTorch just merged https://github.com/pytorch/torchdistx/pull/52, which adds `AnyPrecisionAdamW` (bf16 support, with room for future new dtypes).
We should add it to our HF Trainer arsenal.
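For context, here is a minimal sketch of what using the new optimizer directly looks like. The `torchdistx.optimizers` import path and the keyword names are taken from the torchdistx PR and should be treated as assumptions until verified against your install:

```python
import torch
from torchdistx.optimizers import AnyPrecisionAdamW  # assumption: torchdistx import path

model = torch.nn.Linear(16, 4)

# AdamW drop-in that can keep the optimizer states (momentum/variance) in bf16
# instead of fp32, roughly halving optimizer memory.
optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
    momentum_dtype=torch.bfloat16,   # assumption: kwarg names as in the torchdistx PR
    variance_dtype=torch.bfloat16,
    use_kahan_summation=True,        # compensated summation for low-precision updates
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
```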
This is open to the community - it shouldn’t be too difficult to add by checking how the existing optimizers are integrated. Here are some pointers to start unraveling:
https://github.com/huggingface/transformers/blob/e88e9ff045347c9d92d85806a6987dc7ebcbdd5b/src/transformers/training_args.py#L393-L394 and https://github.com/huggingface/transformers/blob/e88e9ff045347c9d92d85806a6987dc7ebcbdd5b/src/transformers/training_args.py#L94-L106
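Those lines define the `optim` training argument and the `OptimizerNames` enum it is validated against. Roughly, the first step would be adding a new enum value along these lines; the value name `adamw_anyprecision` is only an illustration here, not a settled API:

```python
# Sketch of the addition to src/transformers/training_args.py
from transformers.utils import ExplicitEnum  # already imported inside training_args.py


class OptimizerNames(ExplicitEnum):
    """Stores the acceptable string identifiers for optimizers."""

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_ANYPRECISION = "adamw_anyprecision"  # hypothetical new value for this feature
```

A user would then opt in with `TrainingArguments(..., optim="adamw_anyprecision")` (again, the string is illustrative).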
The key, of course, is the documentation and tests. Checking the existing tests and working from there is what’s needed.
One would start by mimicking the integration of the other optimizers. In this case it’d follow the path of `adamw_torch`, as that’s the nearest similar optimizer.
It might also help to look at the previous PRs that added new optimizers, e.g. find the PR that added `adamw_bnb_8bit` - that could be a good model to copy from, and it shows the scope of work that needs to be done. Except this one should be simpler than `adamw_bnb_8bit`, as it just plugs in a core PyTorch optimizer; that’s why I said `adamw_torch` is another good model.
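For the Trainer side, the scope is roughly one more branch in `Trainer.get_optimizer_cls_and_kwargs`, mirroring the existing `adamw_torch` branch. The sketch below is heavily simplified; the `ADAMW_ANYPRECISION` branch, the dtype kwargs, and the error message are assumptions, not the actual implementation:

```python
import torch
from torch.optim import AdamW
from transformers.training_args import OptimizerNames  # with the hypothetical ADAMW_ANYPRECISION value sketched above


def get_optimizer_cls_and_kwargs(args):
    """Simplified sketch of Trainer.get_optimizer_cls_and_kwargs; the real method covers many more optimizers."""
    optimizer_kwargs = {"lr": args.learning_rate}
    adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}

    if args.optim == OptimizerNames.ADAMW_TORCH:
        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:  # hypothetical new branch
        try:
            # AnyPrecisionAdamW lives in torchdistx, which currently needs pytorch-nightly
            from torchdistx.optimizers import AnyPrecisionAdamW
        except ImportError:
            raise ValueError("Trainer tried to instantiate AnyPrecisionAdamW but torchdistx is not installed!")
        optimizer_cls = AnyPrecisionAdamW
        optimizer_kwargs.update(adam_kwargs)
        # keep the optimizer states in bf16; hard-coded here, could be made configurable later
        optimizer_kwargs.update({"momentum_dtype": torch.bfloat16, "variance_dtype": torch.bfloat16})
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
    return optimizer_cls, optimizer_kwargs
```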
Please remember that this requires pytorch-nightly, as this new feature hasn’t made it into pytorch-1.13 yet. So you will need to install it from https://pytorch.org/get-started/locally/ (choose Preview (Nightly)).
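A quick way to sanity-check the environment before starting (again assuming the `torchdistx.optimizers` import path):

```python
import torch

# Expect a nightly/dev build string here (e.g. a "1.13.0.dev" version), not a 1.12.x release.
print("torch:", torch.__version__)

try:
    from torchdistx.optimizers import AnyPrecisionAdamW  # noqa: F401
    print("AnyPrecisionAdamW is importable")
except ImportError as err:
    print("torchdistx / AnyPrecisionAdamW not available:", err)
```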
Thank you!
Top GitHub Comments
Awesome, that would make a good quality test then.
Let’s continue the discussion in the PR https://github.com/huggingface/transformers/pull/18961 so it’s more “actionable” 😃
That’s very helpful, Less - thank you for sharing these use cases and the details!
I will leave the stage to @atturaioe to ask questions, as he has been doing the heavy lifting on this task.