-
Notifications
You must be signed in to change notification settings - Fork 287
Add a bias_correction_v flag to scale_by_amsgrad to align with the original AMSGrad paper and Pytorch/tensorflow impl #1423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
|
Wouldn't setting Is the point to not use the bias at all or change the order of the operations as discussed here: pytorch/pytorch#142323 Pytorch applies this the bias in amsgrad to this day: https://github.com/pytorch/pytorch/blob/2164b661219ab0a76aa018e955ba3d8e8f99c083/torch/optim/adam.py#L509 But tensorflow does not (I think): https://github.com/keras-team/keras/blob/f6c4ac55692c132cd16211f4877fac6dbeead749/keras/src/optimizers/adam.py#L130-L150 |
vroulet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change!
Let's just quickly agree on the new argument name and we should merge.
| `None` then the `dtype` is inferred from `params` and `updates`. | ||
| bias_correction_v: Whether to apply bias correction to the second moment | ||
| estimate before taking the elementwise maximum (``nu_max``). Set to | ||
| ``False`` to match the original AMSGrad paper and PyTorch/Keras behavior. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"match pytorch behavior" -> no, see the conversation.
So just say "set to False to match original AMSGrad paper"
| eps=eps, | ||
| eps_root=eps_root, | ||
| mu_dtype=mu_dtype, | ||
| bias_correction_v=bias_correction_v |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bias_correction_v is not a great name.
bias_correction_nu is already better
debias_nu may even be better but "bias_correction" as a boolean argument is already used in e.g. rmsprop (shame on me for that naming).
@rdyro what do you think?
I think optax original implementation is the one that makes most sense (doing the bias correction after taking the max does not make sense to me). |
vroulet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the paper, neither mu nor nu had bias corrections.
So to fully align with the paper give the option to remove both.
Namely have debias_mu: bool = True, debias_nu: bool = True for example. (or bias_correction_mu: bool = True, bias_correction_nu: bool = True)
Change the code accordingly
Resolves #1389