
AMSGrad implementation differs from PyTorch/TensorFlow #1389

@sklenard

Description


Hello,

We noticed that models trained with the AMSGrad optimizer in Optax tend to yield slightly poorer results than the same models trained with PyTorch. The two implementations differ in where the running maximum of the second moment is taken relative to bias correction:

Current Optax implementation (the max is taken over the *bias-corrected* second moment):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= m_t / (1-\beta_1^t), \qquad \hat v_t = v_t / (1-\beta_2^t) \\
\hat v_t^{\max} &= \max\!\left(\hat v_{t-1}^{\max},\, \hat v_t\right) \\
\theta_t &= \theta_{t-1} - \alpha\, \hat m_t \big/ \left(\sqrt{\hat v_t^{\max}} + \epsilon\right)
\end{aligned}
$$

PyTorch implementation (the max is taken over the *raw* second moment, and the maximum is bias-corrected afterwards; the implementation in TensorFlow follows the same algorithm):

$$
\begin{aligned}
v_t^{\max} &= \max\!\left(v_{t-1}^{\max},\, v_t\right) \\
\theta_t &= \theta_{t-1} - \alpha\, \hat m_t \big/ \left(\sqrt{v_t^{\max} / (1-\beta_2^t)} + \epsilon\right)
\end{aligned}
$$

Because the bias-correction factor $1/(1-\beta_2^t)$ shrinks over time, taking the max before or after the correction produces different denominators, and hence different parameter trajectories.

Would it be possible to align the Optax implementation with the PyTorch/TensorFlow version? This would improve consistency across the different ML frameworks and possibly improve performance.
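To make the divergence concrete, here is a minimal scalar sketch of both variants (my own simplification, not the actual Optax or PyTorch code; exact `eps` placement differs slightly in the real libraries). A large early gradient followed by small ones makes the two trajectories separate, because Optax's max "remembers" the large bias-corrected early value while PyTorch's bias correction of the max decays toward 1:

```python
import math

def amsgrad_scalar(grads, variant, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """Toy scalar AMSGrad. 'optax' maxes the bias-corrected second moment;
    'pytorch' maxes the raw second moment and bias-corrects the maximum."""
    theta, m, v, v_max = 1.0, 0.0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g          # first moment
        v = b2 * v + (1 - b2) * g * g      # second moment
        m_hat = m / (1 - b1**t)            # bias-corrected first moment
        if variant == "optax":
            v_max = max(v_max, v / (1 - b2**t))       # max after correction
            denom = math.sqrt(v_max) + eps
        else:  # "pytorch"
            v_max = max(v_max, v)                     # max before correction
            denom = math.sqrt(v_max / (1 - b2**t)) + eps
        theta -= lr * m_hat / denom
    return theta

grads = [10.0, 0.1, 0.1, 0.1]
print(amsgrad_scalar(grads, "optax"))
print(amsgrad_scalar(grads, "pytorch"))
```

After the first step the two variants already take different step sizes on this gradient sequence (they coincide only on the very first update).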
