Skip to content

Commit 8f75b23

Browse files
committed
Move rollout related configs from cluster config to rollout_config.
1 parent 504b6a9 commit 8f75b23

File tree

5 files changed

+73
-74
lines changed

5 files changed

+73
-74
lines changed

scripts/grpo_demo_llama3_qwen2.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -796,11 +796,12 @@ def evaluate(
796796
temperature=TEMPERATURE,
797797
top_p=TOP_P,
798798
top_k=TOP_K,
799+
rollout_vllm_model_version=VLLM_MODEL_VERSION,
800+
rollout_vllm_hbm_utilization=0.2,
801+
rollout_vllm_tpu_backend_type="jax",
802+
rollout_vllm_server_mode=args.rollout_server_mode,
799803
),
800-
rollout_vllm_model_version=VLLM_MODEL_VERSION,
801-
rollout_vllm_hbm_utilization=0.2,
802-
rollout_vllm_tpu_backend_type="jax",
803-
rollout_vllm_server_mode=args.rollout_server_mode,
804+
804805
)
805806

806807
grpo_config = grpo_learner.GRPOConfig(

tunix/rl/rl_cluster.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from jax.typing import ArrayLike # pylint: disable=g-importing-member
3636
import jaxtyping
3737
import optax
38-
from tunix.generate import mappings
3938
# Internal placeholder for sglang_jax rollout worker stub, don't change this line.
4039
# Internal placeholder for vllm rollout worker stub, don't change this line.
4140
from tunix.rl import reshard
@@ -181,22 +180,7 @@ class ClusterConfig:
181180
rollout_config: (
182181
dict[Mode, base_rollout.RolloutConfig] | base_rollout.RolloutConfig
183182
)
184-
rollout_mapping_config: mappings.MappingConfig | None = None
185183

186-
rollout_vllm_server_mode: bool = False
187-
rollout_vllm_model_version: str = ""
188-
rollout_vllm_lora_config: dict[str, Any] | None = None
189-
rollout_vllm_hbm_utilization: float = 0.2
190-
rollout_vllm_init_with_random_weights: bool = True
191-
rollout_vllm_tpu_backend_type: str | None = None
192-
rollout_vllm_swap_space_size_gb: float = 4.0 # in GiB
193-
194-
rollout_sglang_jax_model_version: str = ""
195-
rollout_sglang_jax_context_length: int = 8192
196-
rollout_sglang_jax_mem_fraction_static: float = 0.2
197-
rollout_sglang_jax_init_with_random_weights: bool = True
198-
rollout_sglang_jax_disable_radix_cache: bool = True
199-
rollout_sglang_jax_enable_deterministic_sampling: bool = False
200184

201185

202186
class RLCluster:
@@ -403,29 +387,16 @@ def _init_cluster(self):
403387
elif self.cluster_config.rollout_engine == "vllm":
404388
from tunix.rl.rollout import vllm_rollout
405389

406-
if self.cluster_config.rollout_vllm_model_version is None:
390+
if self.cluster_config.rollout_config.rollout_vllm_model_version is None:
407391
raise ValueError("Rollout vllm model version or path is missing!")
408392

409-
backend = (
410-
self.cluster_config.rollout_engine
411-
+ "_"
412-
+ self.cluster_config.rollout_vllm_tpu_backend_type
413-
)
414393
# TODO(linchai): maybe support offloading for vllm rollout.
415394
self._rollout = vllm_rollout.VllmRollout(
416395
self.rollout_actor,
417396
self.tokenizer,
418397
cache_config_or_size=max_kv_cache_size,
419398
mesh=self.r2m[Role.ROLLOUT],
420-
model_version=self.cluster_config.rollout_vllm_model_version,
421-
hbm_utilization=self.cluster_config.rollout_vllm_hbm_utilization,
422-
init_with_random_weights=self.cluster_config.rollout_vllm_init_with_random_weights,
423-
tpu_backend_type=self.cluster_config.rollout_vllm_tpu_backend_type,
424-
swap_space=self.cluster_config.rollout_vllm_swap_space_size_gb,
425-
lora_config=self.cluster_config.rollout_vllm_lora_config,
426-
rollout_engine=backend,
427-
mapping_config=self.cluster_config.rollout_mapping_config,
428-
server_mode=self.cluster_config.rollout_vllm_server_mode,
399+
rollout_config=self.cluster_config.rollout_config,
429400
)
430401
elif self.cluster_config.rollout_engine == "sglang_jax":
431402
from tunix.rl.rollout import sglang_jax_rollout
@@ -434,13 +405,7 @@ def _init_cluster(self):
434405
self.rollout_actor,
435406
self.tokenizer,
436407
mesh=self.r2m[Role.ROLLOUT],
437-
model_version=self.cluster_config.rollout_sglang_jax_model_version,
438-
context_length=self.cluster_config.rollout_sglang_jax_context_length,
439-
mem_fraction_static=self.cluster_config.rollout_sglang_jax_mem_fraction_static,
440-
init_with_random_weights=self.cluster_config.rollout_sglang_jax_init_with_random_weights,
441-
disable_radix_cache=self.cluster_config.rollout_sglang_jax_disable_radix_cache,
442-
enable_deterministic_sampling=self.cluster_config.rollout_sglang_jax_enable_deterministic_sampling,
443-
mapping_config=self.cluster_config.rollout_mapping_config,
408+
rollout_config=self.cluster_config.rollout_config,
444409
)
445410

446411
else:

tunix/rl/rollout/base_rollout.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from jax import numpy as jnp
2323
import jaxtyping
2424

25+
from tunix.generate import mappings
2526

2627
@dataclasses.dataclass(frozen=True)
2728
class CacheConfig:
@@ -96,6 +97,53 @@ class RolloutConfig:
9697
# will be used.
9798
eos_tokens: list[int] | None = None
9899

100+
# Weights mapping config for the rollout model.
101+
rollout_mapping_config: mappings.MappingConfig | None = None
102+
103+
# vLLM specific rollout configs.
104+
105+
# Whether to run rollout in vLLM server mode or batch inference mode.
106+
rollout_vllm_server_mode: bool = False
107+
108+
# Model version for vLLM rollout engine.
109+
rollout_vllm_model_version: str = ""
110+
111+
# LoRA config for vLLM rollout engine.
112+
rollout_vllm_lora_config: dict[str, Any] | None = None
113+
114+
# Allocated HBM fraction for vLLM rollout engine.
115+
rollout_vllm_hbm_utilization: float = 0.2
116+
117+
# Whether to initialize vLLM model with random weights or huggingface weights.
118+
rollout_vllm_init_with_random_weights: bool = True
119+
120+
# TPU backend type for vLLM rollout engine, "jax" or "torchax", default to "jax".
121+
rollout_vllm_tpu_backend_type: str | None = None
122+
123+
# Swap space size for vLLM rollout engine, in GiB.
124+
rollout_vllm_swap_space_size_gb: float = 4.0
125+
126+
127+
# SG-Lang JAX specific rollout configs.
128+
129+
# Model version for SG-Lang JAX rollout engine.
130+
rollout_sglang_jax_model_version: str = ""
131+
132+
# Context length for SG-Lang JAX rollout engine.
133+
rollout_sglang_jax_context_length: int = 8192
134+
135+
# Allocated HBM fraction for SG-Lang JAX rollout engine.
136+
rollout_sglang_jax_mem_fraction_static: float = 0.2
137+
138+
# Whether to initialize SG-Lang JAX model with random weights.
139+
rollout_sglang_jax_init_with_random_weights: bool = True
140+
141+
# Radix cache disabling flag for SG-Lang JAX rollout engine. Default to True for RL.
142+
rollout_sglang_jax_disable_radix_cache: bool = True
143+
144+
# Whether to enable deterministic sampling for SG-Lang JAX rollout engine.
145+
rollout_sglang_jax_enable_deterministic_sampling: bool = False
146+
99147

100148
class BaseRollout(abc.ABC):
101149
"""Base RolloutWorker."""

tunix/rl/rollout/sglang_jax_rollout.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,22 @@ def __init__(
3333
model: Any,
3434
tokenizer: Any,
3535
mesh: jax.sharding.Mesh,
36-
model_version: str,
37-
context_length: int,
38-
mem_fraction_static: float,
39-
init_with_random_weights: bool,
40-
disable_radix_cache: bool,
41-
enable_deterministic_sampling: bool,
42-
mapping_config: Optional[mappings.MappingConfig] = None,
43-
rollout_engine: str = "sglang_jax",
36+
rollout_config: base_rollout.RolloutConfig,
4437
):
4538
self.mesh = mesh
4639
mapping_config = mappings.MappingConfig.build(
47-
mapping_obj=mapping_config, model=model, backend=rollout_engine
40+
mapping_obj=rollout_config.rollout_mapping_config, model=model, backend="sglang_jax",
4841
)
4942
self._sampler = sglang_jax_sampler.SglangJaxSampler(
5043
tokenizer=tokenizer,
5144
config=sglang_jax_sampler.SglangJaxConfig(
5245
mesh=mesh,
53-
context_length=context_length,
54-
model_version=model_version,
55-
mem_fraction_static=mem_fraction_static,
56-
init_with_random_weights=init_with_random_weights,
57-
disable_radix_cache=disable_radix_cache,
58-
enable_deterministic_sampling=enable_deterministic_sampling,
46+
context_length=rollout_config.rollout_sglang_jax_context_length,
47+
model_version=rollout_config.rollout_sglang_jax_model_version,
48+
mem_fraction_static=rollout_config.rollout_sglang_jax_mem_fraction_static,
49+
init_with_random_weights=rollout_config.rollout_sglang_jax_init_with_random_weights,
50+
disable_radix_cache=rollout_config.rollout_sglang_jax_disable_radix_cache,
51+
enable_deterministic_sampling=rollout_config.rollout_sglang_jax_enable_deterministic_sampling,
5952
mapping_config=mapping_config,
6053
),
6154
)

tunix/rl/rollout/vllm_rollout.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,33 +33,25 @@ def __init__(
3333
tokenizer: Any,
3434
cache_config_or_size: base_rollout.CacheConfig | int,
3535
mesh: jax.sharding.Mesh,
36-
model_version: str,
37-
hbm_utilization: float,
38-
init_with_random_weights: bool,
39-
tpu_backend_type: str,
40-
swap_space: float = 4.0, # in GiB
41-
server_mode: bool = False,
42-
lora_config: Optional[Dict[str, str]] = None,
43-
mapping_config: Optional[mappings.MappingConfig] = None,
44-
rollout_engine: str = "vllm_jax",
36+
rollout_config: base_rollout.RolloutConfig,
4537
):
4638
self.mesh = mesh
4739
mapping_config = mappings.MappingConfig.build(
48-
mapping_obj=mapping_config, model=model, backend=rollout_engine
40+
mapping_obj=rollout_config.rollout_mapping_config, model=model, backend="vllm_jax",
4941
)
5042
self._sampler = vllm_sampler.VllmSampler(
5143
tokenizer=tokenizer,
5244
config=vllm_sampler.VllmConfig(
5345
max_model_len=cache_config_or_size,
5446
mesh=mesh,
55-
model_version=model_version,
56-
hbm_utilization=hbm_utilization,
57-
init_with_random_weights=init_with_random_weights,
58-
tpu_backend_type=tpu_backend_type,
47+
model_version=rollout_config.rollout_vllm_model_version,
48+
hbm_utilization=rollout_config.rollout_vllm_hbm_utilization,
49+
init_with_random_weights=rollout_config.rollout_vllm_init_with_random_weights,
50+
tpu_backend_type=rollout_config.rollout_vllm_tpu_backend_type,
5951
mapping_config=mapping_config,
60-
lora_config=lora_config,
61-
swap_space=swap_space,
62-
server_mode=server_mode,
52+
lora_config=rollout_config.rollout_vllm_lora_config,
53+
swap_space=rollout_config.rollout_vllm_swap_space_size_gb,
54+
server_mode=rollout_config.rollout_vllm_server_mode,
6355
),
6456
)
6557
state = nnx.state(model)

0 commit comments

Comments
 (0)