Replies: 1 comment
Hi @JavierJJR, great question! The reduce-on-plateau transform works essentially the same way in Flax NNX as in Linen: you chain it after your base optimizer and pass the metric you want to monitor to every update call. The only NNX-specific part is that `nnx.Optimizer` manages the optimizer state for you. Here's a practical example:

**Basic Setup**

```python
import optax
from optax.contrib import reduce_on_plateau
import jax
import jax.numpy as jnp
from flax import nnx

# Create your NNX model
model = nnx.Linear(10, 1, rngs=nnx.Rngs(0))

# reduce_on_plateau is a gradient transformation, not a schedule, so it is
# chained after the base optimizer rather than passed as learning_rate.
tx = optax.chain(
    optax.adam(learning_rate=1e-3),
    reduce_on_plateau(
        factor=0.5,      # multiply the scale by 0.5 on each plateau
        patience=10,     # evaluations without improvement before reducing
        cooldown=5,      # evaluations to wait after each reduction
        min_scale=1e-3,  # effective LR never drops below 1e-3 * 1e-3 = 1e-6
    ),
)

# Create the optimizer; it holds both Adam's and the plateau transform's state
optimizer = nnx.Optimizer(model, tx)
```

**Training Loop with Plateau Detection**
```python
for epoch in range(num_epochs):
    # Your training step
    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Validation
    val_loss = validate(model, val_data)

    # Pass the metric you are tracking via `value=`; it is routed to the
    # chained reduce_on_plateau transform, which tracks the best value and
    # patience internally, so no manual bookkeeping is needed.
    # (Recent Flax releases forward extra keyword arguments from
    # nnx.Optimizer.update to the underlying optax transformation.)
    optimizer.update(grads, value=val_loss)

    print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")
```

**Complete Working Example**
```python
import optax
from optax.contrib import reduce_on_plateau
import jax
import jax.numpy as jnp
from flax import nnx

# Simple model
class MLP(nnx.Module):
    def __init__(self, rngs):
        self.linear1 = nnx.Linear(10, 64, rngs=rngs)
        self.linear2 = nnx.Linear(64, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.linear1(x))
        return self.linear2(x)

# Initialize
model = MLP(rngs=nnx.Rngs(0))

# Chain the base optimizer with the plateau transform
tx = optax.chain(
    optax.adam(learning_rate=1e-2),
    reduce_on_plateau(factor=0.5, patience=5),
)
optimizer = nnx.Optimizer(model, tx)

def loss_fn(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2)

# Training loop
for step in range(100):
    # Dummy data
    x = jax.random.normal(jax.random.PRNGKey(step), (32, 10))
    y = jax.random.normal(jax.random.PRNGKey(step + 1000), (32, 1))

    # Compute loss and gradients
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)

    # Update the model; `value` is what reduce_on_plateau monitors.
    # In practice, pass your validation loss here instead of the train loss.
    optimizer.update(grads, value=loss)
```

**Key Points**

- `reduce_on_plateau` is a gradient transformation (an `optax.GradientTransformationExtraArgs`), not a schedule, so chain it after the base optimizer instead of passing it as `learning_rate`.
- Pass the metric you want to monitor as `value=` on every `optimizer.update` call; the `accumulation_size` argument lets you average a noisy per-step loss over several steps.
- The transform multiplies the updates by an internal scale factor rather than changing Adam's `learning_rate`; see the sketch below for one way to inspect the current scale.
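If you want to log the effective learning rate, the plateau transform keeps a multiplicative `scale` in its own state. The following is only a minimal sketch: it assumes the plateau state is the second element of the chained optimizer state and that NNX wraps each state leaf in a variable, both of which can differ between Flax and Optax versions:

```python
# Hypothetical helper -- the attribute path below is an assumption and may
# differ between Flax/Optax versions.
def current_effective_lr(optimizer, base_lr):
    # opt_state mirrors the chain: (adam state, reduce_on_plateau state)
    plateau_state = optimizer.opt_state[1]
    # NNX wraps state leaves in variables, hence .value; the plateau state
    # stores the current multiplicative factor as `scale`.
    return base_lr * float(plateau_state.scale.value)

print(f"effective lr: {current_effective_lr(optimizer, 1e-2):.6f}")
```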
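For completeness, the training step can also be jit-compiled. This is only a sketch, and it assumes your Flax version's `nnx.Optimizer.update` forwards extra keyword arguments such as `value=` to Optax:

```python
# Sketch of a jitted train step; assumes `value=` is forwarded to optax
# by nnx.Optimizer.update (the case in recent Flax releases).
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return jnp.mean((pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # Here the train loss is monitored; swap in a validation metric if you have one.
    optimizer.update(grads, value=loss)
    return loss
```

Called as `train_step(model, optimizer, x, y)` inside the loop; `nnx.jit` propagates the in-place updates to `model` and `optimizer`.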
Hope this helps! Let me know if you need clarification on any part.
Hello all,
I am quite new to Optax and Flax. To learn Optax, I have started porting models I previously built in PyTorch to Flax NNX. However, I'm running into difficulties with the reduce-on-plateau learning rate scheduler. The example at https://optax.readthedocs.io/en/latest/_collections/examples/contrib/reduce_on_plateau.html uses Flax Linen, but I haven't found any examples for Flax NNX. Does anyone have a simple example of how to use this scheduler with Flax NNX?
Thanks a lot in advance! 🙏