
AdamW in HuggingFace is different from AdamW in PyTorch

See original GitHub issue

❓ Question

I just noticed that the implementation of AdamW in HuggingFace is different from the one in PyTorch. The HuggingFace AdamW first performs the gradient update and then applies the weight decay. However, in the paper (Decoupled Weight Decay Regularization, https://arxiv.org/abs/1711.05101) and in PyTorch's implementation, AdamW first applies the weight decay and then performs the gradient update.

I was wondering if the two approaches are the same. Thanks! (In my opinion, they are not the same procedure.)

HuggingFace:

for group in self.param_groups:
    for p in group["params"]:
        ...
        # Decay the first and second moment running average coefficient
        # In-place operations to update the averages at the same time
        exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        step_size = group["lr"]
        if group["correct_bias"]:  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        p.data.addcdiv_(-step_size, exp_avg, denom)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        # Add weight decay at the end (fixed version)
        if group["weight_decay"] > 0.0:
            p.data.add_(-group["lr"] * group["weight_decay"], p.data)

PyTorch:

for group in self.param_groups:
    for p in group['params']:
        ...
        # Perform stepweight decay
        p.data.mul_(1 - group['lr'] * group['weight_decay'])

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

        step_size = group['lr'] / bias_correction1

        p.data.addcdiv_(-step_size, exp_avg, denom)
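
To make the ordering difference concrete, here is a minimal toy sketch (not part of the original issue; it collapses the whole Adam step into a single scalar adam_update, standing in for step_size * exp_avg / denom, and ignores bias correction) that applies both update orders to the same starting parameter:

import torch

# Toy scalar comparison of the two weight-decay orderings.
lr, weight_decay = 1e-3, 1e-2
p0 = torch.tensor(1.0, dtype=torch.float64)
adam_update = torch.tensor(0.05, dtype=torch.float64)  # stands in for step_size * exp_avg / denom

# HuggingFace ordering: Adam step first, then decay the already-updated weights.
p_hf = p0 - adam_update
p_hf = p_hf - lr * weight_decay * p_hf   # p.data.add_(-lr * wd, p.data)

# PyTorch ordering: decay the weights first, then take the Adam step.
p_pt = p0 * (1 - lr * weight_decay)      # p.data.mul_(1 - lr * wd)
p_pt = p_pt - adam_update

print(p_hf.item(), p_pt.item())
# The gap equals lr * weight_decay * adam_update; since the Adam step is itself
# proportional to lr, the two orderings agree up to a second-order term in lr.
print((p_hf - p_pt).item(), (lr * weight_decay * adam_update).item())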

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 7 (1 by maintainers)

Top GitHub Comments

4 reactions
hhaoyan commented, Mar 16, 2021

Update: they are indeed the same. PyTorch’s implementation is just too confusing to understand.

1 reaction
yhCyan commented, Jul 30, 2020

I noticed this question too. The two code snippets are clearly different, because in the HuggingFace version p.data has already been changed by p.data.addcdiv_(-step_size, exp_avg, denom) before the weight decay is applied. But I can't understand why.
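
A short symbolic check (again a sketch, not from the thread; the symbol names are mine) of what the already-changed p.data amounts to: writing delta for the full Adam step and expanding both orderings shows they differ only by lr * wd * delta. This assumes sympy is available:

import sympy as sp

theta, delta, lr, wd = sp.symbols('theta delta lr wd')

theta_hf = (theta - delta) * (1 - lr * wd)  # HuggingFace: Adam step, then decay the updated weights
theta_pt = theta * (1 - lr * wd) - delta    # PyTorch: decay the weights, then Adam step

print(sp.expand(theta_hf - theta_pt))       # prints delta*lr*wd

Since delta is itself proportional to the learning rate, the gap is second order in lr, which is consistent with the "they are indeed the same" comment above.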


