Two bugs in AdamW
Environment info
- `transformers` version: 4.13.0.dev0
- Platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.17
- Python version: 3.9.7
- PyTorch version (GPU?): 1.10.0+cu113 (True)
- Tensorflow version (GPU?): 2.7.0 (False)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
@thomwolf and @stas00 should be able to help, based on git blame.
Information
There are two bugs in the implementation of AdamW.
Here’s the current code: https://github.com/manuelciosici/transformers/blob/04683c0659aacf31a1e1df8aa2e6cf7b447a6f12/src/transformers/optimization.py#L324-L371
Weight decay bug
Look at lines 369-370. The weight decay is multiplied with `p.data`, which no longer corresponds to theta_{t-1}, since `p.data` has already been updated in place by the `addcdiv_` call a few lines earlier. Line 12 of Algorithm 2 from the original AdamW paper shows that the weight decay should be multiplied with the previous step’s parameters (i.e., theta_{t-1}); the update is written out below, since the screenshot does not carry over here.
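Line 12 of Algorithm 2 (Loshchilov & Hutter, "Decoupled Weight Decay Regularization"), written out in LaTeX in place of the screenshot, reads:

```latex
\theta_t \leftarrow \theta_{t-1}
    - \eta_t \left( \frac{\alpha \, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
    + \lambda \, \theta_{t-1} \right)
```

Note that the decay term \lambda \theta_{t-1} uses the parameters from before the Adam step.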
From what I can tell, this is a regression, since the original AdamW implementation in `transformers` applied weight decay properly. Here’s the commit that introduced the bug: https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0
For confirmation that weight decay is currently buggy, see the original AdamW implementation, where, on line 74, the weight decay is multiplied with the old parameters as opposed to the new parameters that are calculated on line 71.
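To make the ordering issue concrete, here is a minimal sketch (my own illustrative tensors and hyperparameters, not the actual `transformers` code) showing that the two orders of the in-place updates produce different parameters:

```python
import torch

torch.manual_seed(0)
lr, weight_decay, step_size = 0.1, 0.1, 0.1
p = torch.randn(4)           # theta_{t-1}
exp_avg = torch.randn(4)     # first moment m_t
denom = torch.rand(4) + 0.5  # stands in for sqrt(v_hat_t) + eps

# Current (buggy) order: the Adam step modifies the parameters first,
# so the decay below is applied to theta_t instead of theta_{t-1}.
buggy = p.clone()
buggy.addcdiv_(exp_avg, denom, value=-step_size)
buggy.add_(buggy, alpha=-lr * weight_decay)

# Order matching Algorithm 2, line 12: the decay term uses theta_{t-1}.
# Applying the decay before the Adam step is one equivalent way to do this,
# since the Adam term does not depend on the current parameter value.
fixed = p.clone()
fixed.add_(fixed, alpha=-lr * weight_decay)
fixed.addcdiv_(exp_avg, denom, value=-step_size)

print((buggy - fixed).abs().max())  # nonzero: the two orders disagree
```

For comparison, `torch.optim.AdamW` applies the decay by scaling the parameters (`p.mul_(1 - lr * weight_decay)`) before its Adam update, which is another equivalent way of decaying theta_{t-1}.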
Denominator computation bug
The second bug appears in the computation of the denominator corresponding to line 10 of Algorithm 2 above. In the current code (see the link in the Information section), on line 351, the denominator excludes the division by `math.sqrt(bias_correction2)`. On line 357, the division by `math.sqrt(bias_correction2)` appears (folded into `step_size`), but by this time `eps` has already been added to `denom`, making the division not equivalent to line 10 of Algorithm 2.
From what I can tell, this bug was also introduced as part of the commit https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0. The previous line, `update = next_m / (next_v.sqrt() + group['e'])`, was correct.
For confirmation that the denominator is not properly calculated, see the original AdamW implementation, where, on line 64, the denominator is computed.
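Here is a small sketch of the numerical consequence (again with made-up values, not the actual `transformers` code): folding `math.sqrt(bias_correction2)` into `step_size` after `eps` has already been added to `denom` effectively rescales `eps` by `1 / math.sqrt(bias_correction2)`.

```python
import math
import torch

beta2, step, eps, lr = 0.999, 10, 1e-6, 1e-3
exp_avg_sq = torch.rand(4, dtype=torch.float64)  # second moment v_t
bias_correction2 = 1.0 - beta2 ** step
# (bias_correction1 is omitted since it affects both variants identically)

# Current code: eps is added first, then sqrt(bias_correction2) is folded
# into step_size, which rescales eps along with sqrt(v_t).
denom_buggy = exp_avg_sq.sqrt().add_(eps)
step_size_buggy = lr * math.sqrt(bias_correction2)
# effective update: lr * m / (sqrt(v_t / bias_correction2) + eps / sqrt(bias_correction2))

# Line 10 of Algorithm 2 (and torch.optim.AdamW): bias-correct v_t first,
# then add eps to the square root.
denom_fixed = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
step_size_fixed = lr
# effective update: lr * m / (sqrt(v_hat_t) + eps)

# At the same step size lr, the effective denominators differ by the eps term:
print(denom_buggy / math.sqrt(bias_correction2) - denom_fixed)
# ~= eps / sqrt(bias_correction2) - eps, i.e. largest in the early steps
```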
To reproduce
Steps to reproduce the behavior:
- Check out the branch at https://github.com/manuelciosici/transformers/tree/reveal_broken_adamw
- Run the unit tests in `tests/test_optimization.py`
- Tests `test_compare_adamw_no_weight_decay` and `test_compare_adamw_with_weight_decay` should fail (see the attached failed_tests.txt)
Expected behavior
The two implementations of AdamW should produce matching parameter updates.
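For reference, here is a minimal standalone check in the same spirit as those unit tests (my own sketch, not the tests from the branch; the variable names and the toy objective are made up, while `transformers.AdamW`, its `correct_bias` flag, and `torch.optim.AdamW` are the real APIs):

```python
import torch
from transformers import AdamW  # the implementation under discussion

torch.manual_seed(0)
init = torch.randn(10)
w_hf = torch.nn.Parameter(init.clone())
w_pt = torch.nn.Parameter(init.clone())

# Identical hyperparameters; eps is set explicitly because the defaults differ
# (1e-6 in transformers vs 1e-8 in torch).
opt_hf = AdamW([w_hf], lr=1e-3, eps=1e-8, weight_decay=0.1, correct_bias=True)
opt_pt = torch.optim.AdamW([w_pt], lr=1e-3, eps=1e-8, weight_decay=0.1)

for _ in range(100):
    for w, opt in ((w_hf, opt_hf), (w_pt, opt_pt)):
        opt.zero_grad()
        loss = (w ** 2).sum()  # same deterministic objective for both
        loss.backward()
        opt.step()

# With the two bugs present the trajectories diverge; once fixed, the two
# should agree up to floating-point tolerance.
print(torch.max(torch.abs(w_hf - w_pt)).item())
```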
Proposed fix
Check out the branch at https://github.com/manuelciosici/transformers/tree/fix_adamw. It contains both the unit tests above and a fix for the two bugs described above.
I can make a PR once we agree on the two bugs and the fix.
Comments
The NVIDIA engineers have been profiling a few things, and torch’s AdamW is faster than ours (apparently apex’s is even faster), so I will add this to the performance docs once I’m able to benchmark it when your PR is ready, @manuelciosici.
https://github.com/huggingface/transformers/pull/14708
@stas00 Thank you. I will work on this during the weekend.