Status: Closed
Labels: type:bug (Something isn't working)
Description
`contrib.momo`'s `update` declares `value: Optional[jax.Array] = None` but never coerces its input to an array: it calls `value.astype(state.barf.dtype)` unconditionally. As a result, passing a plain Python `float` (or `int`) as the loss raises `AttributeError` before the optimizer can run.
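The failure comes down to Python `float` and `int` objects having no `.astype` method, whereas JAX arrays and NumPy scalars do. A minimal sketch of one possible library-side fix, assuming the surrounding code only needs `value` as an array, is to coerce it with `jnp.asarray` before the cast. `smoothed_loss` is a hypothetical helper that mirrors the `value.astype(state.barf.dtype)` call described above; it is not Optax's actual code.

```python
import jax.numpy as jnp
import numpy as np


def smoothed_loss(barf, value, bt):
    """Hypothetical stand-in for the loss EMA in optax/contrib/_momo.py.

    Coercing `value` with jnp.asarray first means Python floats and ints,
    NumPy scalars, and JAX arrays are all accepted before the dtype cast.
    """
    value = jnp.asarray(value)  # a plain float becomes a 0-d jax.Array
    return bt * barf + (1 - bt) * value.astype(barf.dtype)


barf = jnp.zeros([], dtype=jnp.float32)
from_float = smoothed_loss(barf, 2.0, 0.5)             # Python float
from_numpy = smoothed_loss(barf, np.float64(2.0), 0.5)  # NumPy scalar
```

Because `jnp.asarray` returns JAX arrays unchanged, the coercion costs nothing for callers who already pass an array.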
import jax.numpy as jnp
from optax import contrib

optimizer = contrib.momo(
    learning_rate=1.0,
    beta=0.0,
    lower_bound=0.0,
    weight_decay=0.0,
    adapt_lower_bound=False,
)
params = [jnp.array([0.0])]
grads = [jnp.array([0.0])]
opt_state = optimizer.init(params)
optimizer.update(grads, opt_state, params, value=0.0)

Traceback (most recent call last):
  File "/data/src/test.py", line 16, in <module>
    optimizer.update(grads, opt_state, params, value=0.0)
  File "/home/hdd/miniconda3/envs/py312/lib/python3.12/site-packages/optax/contrib/_momo.py", line 136, in update_fn
    barf = bt * state.barf + (1 - bt) * value.astype(state.barf.dtype)
                                        ^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'astype'
Note: This issue was identified by an automated testing tool for academic research and manually verified. If you have any concerns about this type of reporting, please let me know, and I will adjust my workflow accordingly.