3535from  jax .typing  import  ArrayLike   # pylint: disable=g-importing-member 
3636import  jaxtyping 
3737import  optax 
38- from  tunix .generate  import  mappings 
3938# Internal placeholder for sglang_jax rollout worker stub, don't change this line. 
4039# Internal placeholder for vllm rollout worker stub, don't change this line. 
4140from  tunix .rl  import  reshard 
@@ -181,22 +180,7 @@ class ClusterConfig:
181180  rollout_config : (
182181      dict [Mode , base_rollout .RolloutConfig ] |  base_rollout .RolloutConfig 
183182  )
184-   rollout_mapping_config : mappings .MappingConfig  |  None  =  None 
185183
186-   rollout_vllm_server_mode : bool  =  False 
187-   rollout_vllm_model_version : str  =  "" 
188-   rollout_vllm_lora_config : dict [str , Any ] |  None  =  None 
189-   rollout_vllm_hbm_utilization : float  =  0.2 
190-   rollout_vllm_init_with_random_weights : bool  =  True 
191-   rollout_vllm_tpu_backend_type : str  |  None  =  None 
192-   rollout_vllm_swap_space_size_gb : float  =  4.0   # in GiB 
193- 
194-   rollout_sglang_jax_model_version : str  =  "" 
195-   rollout_sglang_jax_context_length : int  =  8192 
196-   rollout_sglang_jax_mem_fraction_static : float  =  0.2 
197-   rollout_sglang_jax_init_with_random_weights : bool  =  True 
198-   rollout_sglang_jax_disable_radix_cache : bool  =  True 
199-   rollout_sglang_jax_enable_deterministic_sampling : bool  =  False 
200184
201185
202186class  RLCluster :
@@ -403,29 +387,16 @@ def _init_cluster(self):
403387    elif  self .cluster_config .rollout_engine  ==  "vllm" :
404388      from  tunix .rl .rollout  import  vllm_rollout 
405389
406-       if  self .cluster_config .rollout_vllm_model_version  is  None :
390+       if  self .cluster_config .rollout_config . rollout_vllm_model_version  is  None :
407391        raise  ValueError ("Rollout vllm model version or path is missing!" )
408392
409-       backend  =  (
410-           self .cluster_config .rollout_engine 
411-           +  "_" 
412-           +  self .cluster_config .rollout_vllm_tpu_backend_type 
413-       )
414393      # TODO(linchai): maybe support offloading for vllm rollout. 
415394      self ._rollout  =  vllm_rollout .VllmRollout (
416395          self .rollout_actor ,
417396          self .tokenizer ,
418397          cache_config_or_size = max_kv_cache_size ,
419398          mesh = self .r2m [Role .ROLLOUT ],
420-           model_version = self .cluster_config .rollout_vllm_model_version ,
421-           hbm_utilization = self .cluster_config .rollout_vllm_hbm_utilization ,
422-           init_with_random_weights = self .cluster_config .rollout_vllm_init_with_random_weights ,
423-           tpu_backend_type = self .cluster_config .rollout_vllm_tpu_backend_type ,
424-           swap_space = self .cluster_config .rollout_vllm_swap_space_size_gb ,
425-           lora_config = self .cluster_config .rollout_vllm_lora_config ,
426-           rollout_engine = backend ,
427-           mapping_config = self .cluster_config .rollout_mapping_config ,
428-           server_mode = self .cluster_config .rollout_vllm_server_mode ,
399+           rollout_config = self .cluster_config .rollout_config ,
429400      )
430401    elif  self .cluster_config .rollout_engine  ==  "sglang_jax" :
431402      from  tunix .rl .rollout  import  sglang_jax_rollout 
@@ -434,13 +405,7 @@ def _init_cluster(self):
434405          self .rollout_actor ,
435406          self .tokenizer ,
436407          mesh = self .r2m [Role .ROLLOUT ],
437-           model_version = self .cluster_config .rollout_sglang_jax_model_version ,
438-           context_length = self .cluster_config .rollout_sglang_jax_context_length ,
439-           mem_fraction_static = self .cluster_config .rollout_sglang_jax_mem_fraction_static ,
440-           init_with_random_weights = self .cluster_config .rollout_sglang_jax_init_with_random_weights ,
441-           disable_radix_cache = self .cluster_config .rollout_sglang_jax_disable_radix_cache ,
442-           enable_deterministic_sampling = self .cluster_config .rollout_sglang_jax_enable_deterministic_sampling ,
443-           mapping_config = self .cluster_config .rollout_mapping_config ,
408+           rollout_config = self .cluster_config .rollout_config ,
444409      )
445410
446411    else :
0 commit comments