Skip to content

Commit 782d423

Browse files
lc5211 authored and
The tunix Authors committed
[Tunix] Enable micro_batch_size for rollout and reference models in PPO learner.
PiperOrigin-RevId: 827589436
1 parent 46ea95e commit 782d423

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tunix/rl/ppo/ppo_learner.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,7 @@ def _generate_and_compute_advantage(
238238
# "experiences".
239239
completion_output = self.rl_cluster.generate(
240240
prompts=training_input["prompts"],
241+
micro_batch_size=self._rollout_micro_batch_size,
241242
)
242243
completion_ids = completion_output.tokens
243244
prompt_ids = completion_output.left_padded_prompt_tokens
@@ -261,6 +262,7 @@ def _generate_and_compute_advantage(
261262
completion_tokens=completion_ids,
262263
pad_id=pad_value,
263264
eos_id=eos_value,
265+
micro_batch_size=self._compute_logps_micro_batch_size,
264266
)
265267
else:
266268
ref_per_token_logps = None
@@ -272,6 +274,7 @@ def _generate_and_compute_advantage(
272274
old_per_token_logps = self.rl_cluster.get_old_per_token_logps(
273275
prompt_tokens=prompt_ids,
274276
completion_tokens=completion_ids,
277+
micro_batch_size=self._compute_logps_micro_batch_size,
275278
)
276279

277280
# ===== Value computation ======

0 commit comments

Comments
 (0)