
@NDNM1408

This PR adds a Qwen3VLVideoGRPOTrainer.

Motivation

When I was training Qwen3-VL with GRPOTrainer, I hit IndexError: metadata = video_metadata[index].
While debugging, I found that for video inputs the original GRPOTrainer built messages in which video and fps fields leaked into "text"-type chunks as video=None and fps=None, which is what triggered the error.
To fix this, I created a new class that normalizes the messages before generation; an illustrative before/after is shown below.
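
For illustration, this is the kind of malformed chunk I was seeing versus the shape the trainer normalizes to (values are placeholders; field names follow the Qwen3-VL chat-template conventions):

# A chunk as produced for video inputs by the original code path:
# video/fps leak into a "text"-type chunk as None values.
broken_chunk = {"type": "text", "text": "Describe the clip.", "video": None, "fps": None}

# After normalization, the content is kept as separate, well-formed chunks:
normalized_content = [
    {"type": "video", "video": "path/to/clip.mp4", "fps": 2.0},
    {"type": "text", "text": "Describe the clip."},
]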

What this PR changes

  • Adds a new class: Qwen3VLVideoGRPOTrainer(GRPOTrainer) in grpo_trainer.py.

    • Overrides _generate_single_turn only; all GRPO/DAPO logic (rewards, advantages, clipping, logging, etc.) remains unchanged. A condensed sketch of the override is included below, just before the Usage section.
    • Cleans and normalizes the multi-modal conversation structure while preserving:
      • {"type": "video", "video": ..., "fps": ...}
      • {"type": "text", "text": ...}
    • For each conversation in the batch:
      • Calls self.processing_class.apply_chat_template(...) directly on the Qwen3-VL conversation.
      • Sends the resulting inputs to the correct device.
      • Runs model.generate(..., generation_config=self.generation_config).
      • Splits prompt_ids and completion_ids using the prompt length, and returns them in the format expected by the GRPO pipeline.
    • Currently supports only the standard transformers.generate path and explicitly errors if use_vllm=True or use_transformers_paged=True, to keep behavior simple and predictable for this first iteration.
  • Exports Qwen3VLVideoGRPOTrainer:

    • From src/trl/trainer/__init__.py.
    • From src/trl/__init__.py.

This design keeps the change isolated and backward compatible: existing users of GRPOTrainer are unaffected, while Qwen3-VL users can opt into the specialized trainer.
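
Below is a condensed sketch of the _generate_single_turn override. It is not the exact code in this PR: the helper name _normalize_video_messages is a stand-in for the cleanup step described above, and the method's signature and return structure are simplified relative to what GRPOTrainer's generation path expects.

# Condensed sketch only -- simplified signature, illustrative helper names.
def _generate_single_turn(self, prompts):
    if self.args.use_vllm or self.args.use_transformers_paged:
        raise NotImplementedError(
            "Qwen3VLVideoGRPOTrainer only supports the standard transformers.generate path."
        )

    prompt_ids, completion_ids = [], []
    for conversation in prompts:
        # Clean the multimodal chunks (drop stray video/fps keys from "text" chunks).
        conversation = self._normalize_video_messages(conversation)  # stand-in helper name

        # Process the Qwen3-VL conversation directly with the processor.
        inputs = self.processing_class.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Generate with the trainer's generation config.
        output = self.model.generate(**inputs, generation_config=self.generation_config)

        # Split prompt and completion ids by prompt length.
        prompt_len = inputs["input_ids"].shape[1]
        prompt_ids.append(output[:, :prompt_len])
        completion_ids.append(output[:, prompt_len:])

    return prompt_ids, completion_ids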

Usage

Example (simplified):

from trl import Qwen3VLVideoGRPOTrainer, GRPOConfig

config = GRPOConfig(
    output_dir="qwen3vl-video-grpo",
    loss_type="dapo",
    use_vllm=False,
    use_transformers_paged=False,
    num_generations=4,
    # other GRPOConfig parameters...
)

trainer = Qwen3VLVideoGRPOTrainer(
    model="Qwen/Qwen3-VL-8B-Instruct",
    reward_funcs=my_reward_fn,               # user-defined reward callable (see the sketch below)
    args=config,
    train_dataset=my_qwen3vl_video_dataset,  # dataset whose prompts include video chunks
)

trainer.train()
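
For completeness, my_reward_fn above is an ordinary GRPOTrainer reward callable (completions in, list of floats out). A trivial placeholder might look like this; replace it with a real metric:

def my_reward_fn(completions, **kwargs):
    # Toy reward: favor longer completions. A real reward would score content quality.
    return [float(len(completion)) for completion in completions]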

@qgallouedec (Member) left a comment


Thank you very much for sharing this code. However, we are not going to merge this change for two reasons:

  • Design: trainers are intended to be model-agnostic. Therefore, having a trainer for a single model class goes against this design choice.
  • Support: For now, our priority is to support text and text+image modalities. Video is outside the scope. However, I think it's still useful to know how we can hack this trainer to support video modality. Also, it would be best to have a notebook or script, so the link could be put in our documentation here.

