@@ -1765,6 +1765,7 @@ def init_memory_pool(
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),
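Note: the new `enable_alt_stream` kwarg is turned off whenever PD multiplexing (`enable_pdmux`) is enabled, presumably because pdmux already partitions SMs into its own stream groups, so the KV pool should not spin up a private side stream of its own. Below is a minimal sketch of how a pool might consume such a flag; `SimpleKVPool` is a hypothetical class, not sglang's actual KV pool.

```python
# Sketch only: how a KV-cache pool could honor an `enable_alt_stream` flag.
import torch

class SimpleKVPool:
    def __init__(self, size: int, enable_alt_stream: bool = True, device: str = "cuda"):
        self.buf = torch.empty(size, device=device)
        # Under pdmux the runner owns the stream groups, so the pool skips
        # creating its own alternate CUDA stream and reuses the current one.
        self.alt_stream = torch.cuda.Stream() if enable_alt_stream else None

    def copy_async(self, src: torch.Tensor, dst_idx: torch.Tensor):
        stream = self.alt_stream or torch.cuda.current_stream()
        with torch.cuda.stream(stream):
            self.buf[dst_idx] = src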
@@ -1833,12 +1834,18 @@ def init_cublas(self):
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
 
-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
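Under `enable_pdmux` the runner now keeps one prefill backend (built with a fresh workspace) plus `sm_group_num` decode backends, one per SM group, and points `decode_attn_backend` at the first of them. A rough sketch of that one-backend-per-stream-group layout is shown below; `DummyBackend` and `MiniRunner` are placeholders, not sglang classes.

```python
# Rough sketch of the "one decode backend per SM/stream group" pattern.
class DummyBackend:
    def __init__(self, group_idx: int = -1):
        self.group_idx = group_idx

class MiniRunner:
    def __init__(self, sm_group_num: int):
        # One backend reserved for prefill, with its own workspace ...
        self.attn_backend = DummyBackend()
        # ... and one decode backend per SM group, so concurrent decode
        # streams never share mutable forward-metadata buffers.
        self.decode_attn_backend_group = [DummyBackend(i) for i in range(sm_group_num)]
        self.decode_attn_backend = self.decode_attn_backend_group[0]
```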
@@ -1852,10 +1859,12 @@ def _get_attention_backend(self):
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
@@ -1869,7 +1878,8 @@ def _get_attention_backend(self):
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
 
         (
@@ -1878,9 +1888,12 @@ def _get_attention_backend(self):
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
 
-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)
 
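`_get_attention_backend_from_str` now stashes `init_new_workspace` on the runner just before instantiating the backend from the `ATTENTION_BACKENDS` registry, so a backend constructor that receives the runner can decide whether to allocate a private scratch workspace or reuse a shared one. A hedged sketch of a constructor consulting that flag follows; `ToyWorkspaceBackend` and `SHARED_WORKSPACE` are illustrative names, not part of sglang or any real backend API.

```python
# Hedged sketch: a registry-built backend reading the flag off the runner.
import torch

SHARED_WORKSPACE = None

class ToyWorkspaceBackend:
    def __init__(self, model_runner):
        global SHARED_WORKSPACE
        if getattr(model_runner, "init_new_workspace", False):
            # Prefill path under pdmux: dedicated scratch buffer.
            self.workspace = torch.empty(
                128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
            )
        else:
            # Decode paths: lazily create and reuse one shared buffer.
            if SHARED_WORKSPACE is None:
                SHARED_WORKSPACE = torch.empty(
                    128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
                )
            self.workspace = SHARED_WORKSPACE
```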
@@ -1978,14 +1991,21 @@ def apply_torch_tp(self):
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
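With pdmux, the caller is expected to select the per-group backend via `update_decode_attn_backend(stream_idx)` before launching a decode batch on a given SM group; `forward_decode` then initializes that backend's metadata and attaches it to the `ForwardBatch` in place of the shared `attn_backend`. A hedged usage sketch follows; the driver loop, `runner`, `batches`, and `streams` are assumptions for illustration, not the actual sglang scheduler.

```python
# Illustrative only: driving the new hooks from a pdmux-style decode loop.
import torch

def run_decode_round(runner, batches, streams):
    for stream_idx, (batch, stream) in enumerate(zip(batches, streams)):
        # Point the runner at the backend owned by this SM group ...
        runner.update_decode_attn_backend(stream_idx)
        with torch.cuda.stream(stream):
            # ... so forward_decode initializes that backend's metadata and
            # attaches it to the batch instead of the shared prefill backend.
            runner.forward_decode(batch)
```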
@@ -2123,18 +2143,18 @@ def _forward_raw(
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
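The reordering here likely matters because, depending on how `ForwardMode` classifies split prefill, `is_extend()` may also return True for a split-prefill batch; checking `is_split_prefill()` first keeps those batches on the `forward_split_prefill` path instead of being swallowed by the extend branch. The toy dispatcher below illustrates that more-specific-predicate-first pattern; the `Mode` enum is a stand-in, not sglang's `ForwardMode`.

```python
# Minimal illustration of why the more specific predicate is tested first.
from enum import Enum, auto

class Mode(Enum):
    EXTEND = auto()
    SPLIT_PREFILL = auto()
    DECODE = auto()

    def is_extend(self):
        # Assumed for illustration: split prefill counts as a kind of extend.
        return self in (Mode.EXTEND, Mode.SPLIT_PREFILL)

    def is_split_prefill(self):
        return self is Mode.SPLIT_PREFILL

def dispatch(mode: Mode) -> str:
    if mode.is_split_prefill():   # checked before is_extend(), as in the diff
        return "forward_split_prefill"
    elif mode.is_extend():
        return "forward_extend"
    return "forward_decode"

assert dispatch(Mode.SPLIT_PREFILL) == "forward_split_prefill"
```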