File tree Expand file tree Collapse file tree 2 files changed +27
-24
lines changed Expand file tree Collapse file tree 2 files changed +27
-24
lines changed Original file line number Diff line number Diff line change @@ -325,7 +325,7 @@ def wrapper(*args, **kwargs):
325325 else ('rewards/' + reward_fns .__name__ ,)
326326 )
327327 for metric_name in [
328- 'rewards/overall ' ,
328+ 'rewards/sum ' ,
329329 * rewards_metrics ,
330330 'completions/mean_length' ,
331331 'completions/max_length' ,
@@ -856,7 +856,7 @@ def test_trajectory_ids(self):
856856 def my_reward_fn (trajectories , prompts , ** kwargs ):
857857 for t_id , prompt in zip (kwargs ['trajectory_ids' ], prompts ):
858858 trajectories [kwargs ['mode' ]][t_id ] = prompt
859- return 1.0
859+ return [ 1.0 ] * len ( prompts )
860860
861861 vocab = tc .MockVocab ()
862862 model = tc .ToyTransformer (rngs = nnx .Rngs (0 ), vocab_size = vocab .GetPieceSize ())
Original file line number Diff line number Diff line change @@ -185,35 +185,38 @@ def _compute_rewards(
185185 f"Content of r: { r } "
186186 )
187187 rewards [:, i ] = np .array (r )
188+ for reward in r :
189+ self .rl_cluster .buffer_metrics (
190+ {
191+ f"rewards/{ reward_fn .__name__ } " : (
192+ reward ,
193+ np .mean ,
194+ ),
195+ },
196+ mode = mode ,
197+ )
198+
199+ rewards = np .nansum (rewards , axis = 1 )
200+ for trajectory_idx in range (len (prompts )):
201+ trajectory_rewards = rewards [trajectory_idx ]
188202 self .rl_cluster .buffer_metrics (
189203 {
190- f "rewards/{ reward_fn . __name__ } " : (
191- np .mean ( r ),
204+ "rewards/sum " : (
205+ np .sum ( trajectory_rewards ),
192206 np .mean ,
193207 ),
194208 },
195209 mode = mode ,
196210 )
197-
198- rewards = np .nansum (rewards , axis = 1 )
199- self .rl_cluster .buffer_metrics (
200- {
201- "rewards/overall" : (
202- np .mean (rewards ),
203- np .mean ,
204- ),
205- },
206- mode = mode ,
207- )
208- self .rl_cluster .buffer_metrics (
209- {
210- "rewards/min" : (
211- np .min (rewards ),
212- np .min ,
213- ),
214- },
215- mode = mode ,
216- )
211+ self .rl_cluster .buffer_metrics (
212+ {
213+ "rewards/min" : (
214+ np .min (trajectory_rewards ),
215+ np .min ,
216+ ),
217+ },
218+ mode = mode ,
219+ )
217220 for p , c in zip (prompts , completions ):
218221 self .rl_cluster .buffer_metrics (
219222 {
You can’t perform that action at this time.
0 commit comments