9 changes: 9 additions & 0 deletions trl/trainer/grpo_config.py
@@ -102,6 +102,8 @@ class GRPOConfig(TrainingArguments):
parameter is only effective when `use_vllm` is set to `False`.
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether to skip special tokens when decoding completions. This affects both reward computation and logging.

> Parameters that control generation acceleration powered by vLLM

@@ -451,6 +453,13 @@ class GRPOConfig(TrainingArguments):
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)
skip_special_tokens: bool = field(
default=True,
metadata={
"help": "Whether to skip special tokens when decoding completions. This affects both reward computation "
"and logging."
},
)

# Parameters that control generation acceleration powered by vLLM
use_vllm: bool = field(
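For reference, a minimal usage sketch (not part of the diff): the new flag is set on the config like any other generation option. The output directory is a placeholder.

```python
from trl import GRPOConfig

# Keep special tokens (e.g. EOS or chat-template markers) in the decoded
# completions passed to reward functions and to the completion logs.
# Defaults to True, which preserves the previous behaviour.
training_args = GRPOConfig(
    output_dir="grpo-output",   # placeholder
    skip_special_tokens=False,
)
```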
5 changes: 3 additions & 2 deletions trl/trainer/grpo_trainer.py
@@ -534,6 +534,7 @@ def cast_outputs_to_original_dtype(module, args, output):
self.log_completions = args.log_completions
self.log_unique_prompts = args.log_unique_prompts
self.num_completions_to_print = args.num_completions_to_print
self.skip_special_tokens = args.skip_special_tokens
# Keep logs sized to the generation batch to record only outputs from the latest model update.
self._logs = {
"images": deque(maxlen=args.generation_batch_size),
@@ -1569,8 +1570,8 @@ def _generate_and_score_completions(
ref_per_token_logps = None

# Decode
-        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
-        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=self.skip_special_tokens)
+        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=self.skip_special_tokens)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text, strict=True):
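For context, a small sketch of what the forwarded flag changes at the tokenizer level; the checkpoint name is a placeholder and the exact EOS string depends on the tokenizer in use.

```python
import torch
from transformers import AutoTokenizer

# The trainer simply forwards the config flag to batch_decode; this shows the
# difference it makes on a sequence that ends with the EOS token.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint

ids = tok("Hello", return_tensors="pt").input_ids
ids = torch.cat([ids, torch.tensor([[tok.eos_token_id]])], dim=1)  # append EOS

print(tok.batch_decode(ids, skip_special_tokens=True)[0])   # "Hello"
print(tok.batch_decode(ids, skip_special_tokens=False)[0])  # "Hello<|im_end|>" for this tokenizer
```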
9 changes: 9 additions & 0 deletions trl/trainer/rloo_config.py
@@ -98,6 +98,8 @@ class RLOOConfig(TrainingArguments):
parameter is only effective when `use_vllm` is set to `False`.
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether to skip special tokens when decoding completions. This affects both reward computation and logging.

> Parameters that control generation acceleration powered by vLLM

@@ -377,6 +379,13 @@ class RLOOConfig(TrainingArguments):
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)
skip_special_tokens: bool = field(
default=True,
metadata={
"help": "Whether to skip special tokens when decoding completions. This affects both reward computation "
"and logging."
},
)

# Parameters that control generation acceleration powered by vLLM
use_vllm: bool = field(
5 changes: 3 additions & 2 deletions trl/trainer/rloo_trainer.py
@@ -451,6 +451,7 @@ def __init__(
self.log_completions = args.log_completions
self.log_unique_prompts = args.log_unique_prompts
self.num_completions_to_print = args.num_completions_to_print
self.skip_special_tokens = args.skip_special_tokens
# Keep logs sized to the generation batch to record only outputs from the latest model update.
self._logs = {
"images": deque(maxlen=args.generation_batch_size),
@@ -1336,8 +1337,8 @@ def _generate_and_score_completions(
ref_per_token_logps = None

# Decode
-        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
-        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=self.skip_special_tokens)
+        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=self.skip_special_tokens)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text, strict=True):
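To illustrate why reward computation is affected, here is a hypothetical reward function (not part of this PR) that only works when `skip_special_tokens=False`: with the default `True`, the EOS marker is stripped before the completions reach the reward function, so this reward would always be 0. It assumes a standard-format dataset where completions are plain strings; the EOS string is an assumption.

```python
EOS_MARKER = "<|im_end|>"  # adjust to the tokenizer actually in use

def ends_with_eos_reward(completions, **kwargs):
    # Reward completions that explicitly terminate with the EOS marker; this
    # signal is only visible when special tokens are kept during decoding.
    return [1.0 if completion.endswith(EOS_MARKER) else 0.0 for completion in completions]
```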