Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions trl/experimental/online_dpo/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ class may differ from those in [`~transformers.TrainingArguments`].
"after the timeout, a `ConnectionError` is raised.",
},
)
# Port used to form the NCCL/stateless weight-update process group with the
# vLLM server (separate from the HTTP server port).
vllm_group_port: int = field(
    default=51216,
    metadata={
        # NOTE: the two adjacent literals are concatenated; the leading space on
        # the second one is required or the rendered help reads "group.This".
        "help": "Port number for the weight update group. This is used to communicate with the vLLM "
        "server. Unless the port is already occupied, there is no need to modify it."
    },
)
vllm_tensor_parallel_size: int = field(
default=1,
metadata={
Expand Down
4 changes: 3 additions & 1 deletion trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,9 @@ def __init__(
base_url = args.vllm_server_base_url
else:
base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
self.vllm_client = VLLMClient(
base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
)

# Determine device type (supports cuda, xpu, etc.)
accelerator_type = torch.accelerator.current_accelerator().type
Expand Down
7 changes: 7 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ class GRPOConfig(TrainingArguments):
"after the timeout, a `ConnectionError` is raised."
},
)
# Port used to form the NCCL/stateless weight-update process group with the
# vLLM server (separate from the HTTP server port).
vllm_group_port: int = field(
    default=51216,
    metadata={
        # NOTE: the two adjacent literals are concatenated; the leading space on
        # the second one is required or the rendered help reads "group.This".
        "help": "Port number for the weight update group. This is used to communicate with the vLLM "
        "server. Unless the port is already occupied, there is no need to modify it."
    },
)

# Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
vllm_gpu_memory_utilization: float = field(
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,9 @@ def cast_outputs_to_original_dtype(module, args, output):
base_url = args.vllm_server_base_url
else:
base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
self.vllm_client = VLLMClient(
base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
)
self.vllm_client.init_communicator(device=torch.cuda.current_device())

elif self.vllm_mode == "colocate":
Expand Down
7 changes: 7 additions & 0 deletions trl/trainer/rloo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,13 @@ class RLOOConfig(TrainingArguments):
"after the timeout, a `ConnectionError` is raised."
},
)
# Port used to form the NCCL/stateless weight-update process group with the
# vLLM server (separate from the HTTP server port).
vllm_group_port: int = field(
    default=51216,
    metadata={
        # NOTE: the two adjacent literals are concatenated; the leading space on
        # the second one is required or the rendered help reads "group.This".
        "help": "Port number for the weight update group. This is used to communicate with the vLLM "
        "server. Unless the port is already occupied, there is no need to modify it."
    },
)

# Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
vllm_gpu_memory_utilization: float = field(
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ def __init__(
base_url = args.vllm_server_base_url
else:
base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
self.vllm_client = VLLMClient(
base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout
)
self.vllm_client.init_communicator(device=torch.cuda.current_device())

elif self.vllm_mode == "colocate":
Expand Down