Commit 128f21a

Author: The tunix Authors

Log individual trajectory rewards, rather than the average across the microbatch

PiperOrigin-RevId: 816502498
1 parent: 763fcd5
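
What changed, in brief: previously the learner averaged rewards across a microbatch and buffered that single value, so aggregators such as np.min could only ever see per-microbatch means. The learner now buffers every trajectory's reward individually, which leaves the reported mean unchanged but makes order statistics exact. A minimal sketch of the effect, with made-up numbers and plain lists standing in for the cluster's metric buffer:

```python
import numpy as np

rewards = np.array([0.1, 0.9, 0.2, 0.8])  # per-trajectory rewards, one microbatch

# Old behavior: average first, then buffer one value per microbatch.
# Downstream aggregation can only ever see the microbatch mean.
old_buffer = [np.mean(rewards)]
print(np.min(old_buffer))   # 0.5 -- the true minimum (0.1) is lost

# New behavior: buffer each trajectory's reward individually.
new_buffer = list(rewards)
print(np.mean(new_buffer))  # 0.5 -- the mean is unchanged
print(np.min(new_buffer))   # 0.1 -- the minimum is now exact
```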

File tree: 2 files changed (+27, -24 lines)

tests/rl/grpo/grpo_learner_test.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -325,7 +325,7 @@ def wrapper(*args, **kwargs):
         else ('rewards/' + reward_fns.__name__,)
     )
     for metric_name in [
-        'rewards/overall',
+        'rewards/sum',
         *rewards_metrics,
         'completions/mean_length',
         'completions/max_length',
@@ -856,7 +856,7 @@ def test_trajectory_ids(self):
     def my_reward_fn(trajectories, prompts, **kwargs):
       for t_id, prompt in zip(kwargs['trajectory_ids'], prompts):
         trajectories[kwargs['mode']][t_id] = prompt
-      return 1.0
+      return [1.0] * len(prompts)

     vocab = tc.MockVocab()
     model = tc.ToyTransformer(rngs=nnx.Rngs(0), vocab_size=vocab.GetPieceSize())
```
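
The test fix above reflects the contract the learner now depends on: a reward function returns one reward per prompt rather than a single scalar, since each trajectory's reward is buffered separately. An illustrative conforming reward function (the signature and the length-based scoring are invented for the example; only the return shape matters):

```python
# Illustrative only: a reward function must return one value per prompt,
# i.e. a sequence of length len(prompts). The scoring logic is made up.
def my_length_reward(prompts, completions, **kwargs):
  return [1.0 / (1.0 + len(c)) for c in completions]

assert len(my_length_reward(['p1', 'p2'], ['abc', 'defgh'])) == 2
```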

tunix/rl/rl_learner.py
Lines changed: 25 additions & 22 deletions

```diff
@@ -185,35 +185,38 @@ def _compute_rewards(
           f"Content of r: {r}"
       )
       rewards[:, i] = np.array(r)
+      for reward in r:
+        self.rl_cluster.buffer_metrics(
+            {
+                f"rewards/{reward_fn.__name__}": (
+                    reward,
+                    np.mean,
+                ),
+            },
+            mode=mode,
+        )
+
+    rewards = np.nansum(rewards, axis=1)
+    for trajectory_idx in range(len(prompts)):
+      trajectory_rewards = rewards[trajectory_idx]
       self.rl_cluster.buffer_metrics(
           {
-              f"rewards/{reward_fn.__name__}": (
-                  np.mean(r),
+              "rewards/sum": (
+                  np.sum(trajectory_rewards),
                   np.mean,
               ),
           },
           mode=mode,
       )
-
-    rewards = np.nansum(rewards, axis=1)
-    self.rl_cluster.buffer_metrics(
-        {
-            "rewards/overall": (
-                np.mean(rewards),
-                np.mean,
-            ),
-        },
-        mode=mode,
-    )
-    self.rl_cluster.buffer_metrics(
-        {
-            "rewards/min": (
-                np.min(rewards),
-                np.min,
-            ),
-        },
-        mode=mode,
-    )
+      self.rl_cluster.buffer_metrics(
+          {
+              "rewards/min": (
+                  np.min(trajectory_rewards),
+                  np.min,
+              ),
+          },
+          mode=mode,
+      )
       for p, c in zip(prompts, completions):
         self.rl_cluster.buffer_metrics(
             {
```
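
Taken together, the new _compute_rewards flow buffers three kinds of values: each raw per-trajectory, per-reward-function entry under rewards/<fn name>; each trajectory's NaN-safe sum across reward functions under rewards/sum (aggregated with np.mean); and the same per-trajectory sums under rewards/min (aggregated with np.min). A worked example with made-up numbers, 3 trajectories scored by 2 reward functions:

```python
import numpy as np

# rewards has shape (num_trajectories, num_reward_fns); a reward function
# may yield NaN for a trajectory, which np.nansum treats as 0.
rewards = np.array([
    [0.2, 1.0],
    [0.5, np.nan],
    [0.9, 0.0],
])

# Per-trajectory totals across reward functions: these are the values
# buffered under rewards/sum and rewards/min, one per trajectory.
totals = np.nansum(rewards, axis=1)
print(totals)           # [1.2 0.5 0.9]
print(np.mean(totals))  # 0.8666... -> reported rewards/sum
print(np.min(totals))   # 0.5       -> reported rewards/min
```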
