contrib.momo crashes when loss value is a Python float #1500

@T90REAL

Description

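The update function returned by contrib.momo (update_fn in optax/contrib/_momo.py) declares value: Optional[jax.Array] = None, but it never enforces that type. It calls value.astype(state.barf.dtype) unconditionally, so passing a plain Python scalar such as a float for the loss raises AttributeError before the update is computed. (NumPy scalars define .astype, so they do not trigger this particular error.)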

import jax.numpy as jnp
from optax import contrib

optimizer = contrib.momo(
    learning_rate=1.0,
    beta=0.0,
    lower_bound=0.0,
    weight_decay=0.0,
    adapt_lower_bound=False,
)

params = [jnp.array([0.0])]
grads = [jnp.array([0.0])]
opt_state = optimizer.init(params)

optimizer.update(grads, opt_state, params, value=0.0)

Traceback (most recent call last):
  File "/data/src/test.py", line 16, in <module>
    optimizer.update(grads, opt_state, params, value=0.0)
  File "/home/hdd/miniconda3/envs/py312/lib/python3.12/site-packages/optax/contrib/_momo.py", line 136, in update_fn
    barf = bt * state.barf + (1 - bt) * value.astype(state.barf.dtype)
                                        ^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'astype'
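
A possible caller-side workaround, sketched under assumptions: convert the loss to a jax.Array before handing it to the optimizer, so that .astype is available. The name as_loss_array is a hypothetical helper for illustration and is not part of optax, and the float32 default is an assumption (ideally it would match the dtype of state.barf).

```python
import jax.numpy as jnp


def as_loss_array(value, dtype=jnp.float32):
    """Hypothetical helper (not an optax API): coerce a loss value to a jax.Array.

    Accepts Python scalars, NumPy scalars, and existing arrays.
    The float32 default is an assumption for this sketch.
    """
    return jnp.asarray(value, dtype=dtype)


# A Python float has no .astype, which is what triggers the AttributeError.
print(hasattr(0.0, "astype"))  # False

# After conversion the value exposes .astype like any other jax.Array.
loss = as_loss_array(0.0)
print(hasattr(loss, "astype"))  # True
```

In the reproduction above, this would mean calling optimizer.update(grads, opt_state, params, value=as_loss_array(0.0)). The same jnp.asarray conversion could presumably be applied inside update_fn before the .astype call, but that is a suggestion rather than a description of any existing fix.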

Note: This issue was identified by an automated testing tool for academic research and manually verified. If you have any concerns about this type of reporting, please let me know, and I will adjust my workflow accordingly.

Labels: type:bug (Something isn't working)