Hello,
We noticed that models trained with the AMSGrad optimizer in Optax tend to yield slightly worse results than the same models trained with PyTorch. The two implementations differ as follows:
Current Optax implementation:

PyTorch implementation (the implementation in TensorFlow follows the same algorithm):

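For context, the PyTorch/TensorFlow convention can be sketched roughly as below, in plain NumPy. This is a simplified textbook version of the AMSGrad update (no weight decay, illustrative names), not the actual code from either library; the key point is that the running maximum `v_max` is taken over the raw second-moment estimate, with bias correction applied when forming the denominator:

```python
import numpy as np

def amsgrad_step(param, grad, m, v, v_max, t,
                 lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One AMSGrad step in the PyTorch/TensorFlow style (simplified sketch).

    m, v    -- first and second moment estimates
    v_max   -- elementwise running maximum of v
    t       -- step count, starting at 1 (used for bias correction)
    """
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    v_max = np.maximum(v_max, v)            # max over the *uncorrected* v
    m_hat = m / (1 - b1 ** t)               # bias-corrected first moment
    denom = np.sqrt(v_max / (1 - b2 ** t)) + eps
    param = param - lr * m_hat / denom
    return param, m, v, v_max
```

With this formulation, `v_max` is nondecreasing elementwise, so the effective per-parameter step size can only shrink over time, which is the convergence fix AMSGrad makes to Adam.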
Would it be possible to align the Optax implementation with the PyTorch/TensorFlow version? This would improve consistency across the different ML frameworks and possibly improve performance.