[HF Trainer] [new optimizer] add `AnyPrecisionAdamW` (bf16)

See original GitHub issue

Feature request

PyTorch just merged https://github.com/pytorch/torchdistx/pull/52, which adds AnyPrecisionAdamW (bf16 support now, with room for other new dtypes in the future).

We should add it to our HF Trainer arsenal.
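
For orientation, here is a minimal sketch of using the new optimizer directly. The import path is an assumption based on torchdistx's current layout, and the keyword names are the ones mentioned later in this thread (momentum_dtype, variance_dtype, use_kahan_summation):

```python
import torch

# torchdistx is the package that ships AnyPrecisionAdamW; the import path is
# an assumption based on its layout at the time of writing.
from torchdistx.optimizers import AnyPrecisionAdamW

model = torch.nn.Linear(16, 16)

# bf16 momentum/variance states, with Kahan summation compensating for the
# reduced precision of the optimizer states.
optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-3,
    momentum_dtype=torch.bfloat16,
    variance_dtype=torch.bfloat16,
    use_kahan_summation=True,
)
```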

This is open to the community - it shouldn't be too difficult to add if you follow how the existing optimizers are integrated. Here are some pointers to start unraveling it:

https://github.com/huggingface/transformers/blob/e88e9ff045347c9d92d85806a6987dc7ebcbdd5b/src/transformers/training_args.py#L393-L394 and https://github.com/huggingface/transformers/blob/e88e9ff045347c9d92d85806a6987dc7ebcbdd5b/src/transformers/training_args.py#L94-L106
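
To make those pointers concrete, the new choice would most likely be registered as an extra value in the OptimizerNames enum in training_args.py, something along these lines (the list of existing members is abridged, and the new member name and its string value are placeholders, not a decided name):

```python
# Sketch of the addition to src/transformers/training_args.py; only the last
# entry is new, and both its name and string value are placeholders.
from transformers.utils import ExplicitEnum


class OptimizerNames(ExplicitEnum):
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_ANYPRECISION = "adamw_anyprecision"  # new, hypothetical
```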

The key, of course, is the documentation and tests: check the existing tests and work from there.
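
For a sense of shape, a test for the new option might look roughly like the sketch below. It leans on the Trainer.get_optimizer_cls_and_kwargs helper that the current optimizer tests exercise; the optim string and the expected kwargs are placeholders and only make sense once the new value actually exists:

```python
# Hypothetical pytest-style test, modeled loosely on the existing optimizer
# selection tests; the values below are placeholders.
import torch
from transformers import Trainer, TrainingArguments


def test_anyprecision_adamw_is_selected(tmp_path):
    args = TrainingArguments(output_dir=str(tmp_path), optim="adamw_anyprecision")
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

    assert optimizer_cls.__name__ == "AnyPrecisionAdamW"
    assert optimizer_kwargs["momentum_dtype"] == torch.bfloat16
```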

Start by looking at how the other optimizers are integrated and mimic that.

In this case you'd follow the path of adamw_torch, as it's the nearest similar optimizer.

It might help to look at the previous PRs that added new optimizers, e.g. find the PR that added adamw_bnb_8bit - that could be a good model to copy from, and it shows the scope of the work that needs to be done. This one should be simpler than adamw_bnb_8bit, though, since it just plugs in a core PyTorch optimizer - which is why I said adamw_torch is another good model.
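
To show what "following the path of adamw_torch" roughly means on the trainer.py side, here is a simplified, self-contained sketch of the kind of branch Trainer.get_optimizer_cls_and_kwargs would gain. The optim string, the dtype defaults and the error message are all illustrative assumptions, not the final implementation:

```python
import torch


def get_optimizer_cls_and_kwargs(args):
    # Simplified stand-in for Trainer.get_optimizer_cls_and_kwargs: `args` is a
    # TrainingArguments-like object with optim / learning_rate / adam_* fields.
    optimizer_kwargs = {"lr": args.learning_rate}
    adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}

    if args.optim == "adamw_torch":
        from torch.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == "adamw_anyprecision":  # hypothetical new value
        try:
            from torchdistx.optimizers import AnyPrecisionAdamW
        except ImportError:
            raise ValueError("Please install torchdistx to use AnyPrecisionAdamW.")

        optimizer_cls = AnyPrecisionAdamW
        optimizer_kwargs.update(adam_kwargs)
        # bf16 optimizer states with Kahan summation; these defaults could
        # later be exposed as user-configurable options.
        optimizer_kwargs.update(
            dict(
                use_kahan_summation=True,
                momentum_dtype=torch.bfloat16,
                variance_dtype=torch.bfloat16,
            )
        )
    else:
        raise ValueError(f"Unsupported optimizer: {args.optim}")

    return optimizer_cls, optimizer_kwargs
```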

Please remember that this requires pytorch-nightly, as this new feature hasn't made it into pytorch-1.13 yet, so you will need to install it from https://pytorch.org/get-started/locally/ (choose Preview (Nightly)).
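
A quick way to check that your environment is ready for this (a nightly torch build plus torchdistx, which is where AnyPrecisionAdamW currently lives):

```python
# Environment sanity check before starting on the feature.
import importlib.util

import torch

print("torch version:", torch.__version__)  # expect a .dev / nightly build for now
print("torchdistx installed:", importlib.util.find_spec("torchdistx") is not None)
```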

Thank you!

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 12 (8 by maintainers)

Top GitHub Comments

1 reaction
stas00 commented, Sep 12, 2022

Setting momentum_dtype and variance_dtype to torch.float32 and use_kahan_summation=False brings AnyPrecision back to the traditional AdamW optimizer, so you can quickly compare training with BF16 (pure, or variance only) against it.

Awesome, that would make a good quality test then.
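
For reference, the fp32 baseline configuration described above would look something like this (same assumed torchdistx import path as in the earlier sketch; in a real comparison each variant would get its own copy of the model):

```python
import torch
from torchdistx.optimizers import AnyPrecisionAdamW  # assumed import path

model = torch.nn.Linear(8, 8)

# fp32 states and no Kahan summation: this configuration should behave like
# the traditional AdamW, giving a baseline to compare the bf16 variants against.
baseline_optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-3,
    momentum_dtype=torch.float32,
    variance_dtype=torch.float32,
    use_kahan_summation=False,
)
```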

Let’s continue the discussion in the PR https://github.com/huggingface/transformers/pull/18961 so it’s more “actionable” 😃

1 reaction
stas00 commented, Sep 12, 2022

That’s very helpful, Less - thank you for sharing these use cases and the details!

I will leave the stage to @atturaioe to ask questions, as he has been doing the heavy lifting on this task.
