Commit 41efcae

[Feature] PD-Multiplexing Context and Scheduler, lazy import spatial. (#12275)

1 parent 7056296 commit 41efcae

9 files changed, +458 -24 lines changed

python/sglang/srt/layers/logits_processor.py

Lines changed: 3 additions & 6 deletions
@@ -134,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -384,8 +381,8 @@ def forward(
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            or logits_metadata.forward_mode.is_split_prefill()
-        ) and not logits_metadata.extend_return_logprob:
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1

python/sglang/srt/managers/schedule_batch.py

Lines changed: 5 additions & 1 deletion
@@ -72,7 +72,11 @@
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs, get_global_server_args

python/sglang/srt/managers/scheduler.py

Lines changed: 15 additions & 2 deletions
@@ -152,6 +152,7 @@
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -213,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):
@@ -252,6 +254,7 @@ def __init__(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (
@@ -285,6 +288,10 @@ def __init__(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)

+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()

@@ -424,6 +431,8 @@ def __init__(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
@@ -1952,7 +1961,6 @@ def run_batch(

         # Run forward
         if self.is_generation:
-
             batch_or_worker_batch = batch

             if self.enable_overlap or self.spec_algorithm.is_none():
@@ -2009,6 +2017,9 @@ def run_batch(
                     # The future value, usually for next batch preparation
                     # Current implementation strictly synchronizes the seq_lens
                     batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                future_indices_or_next_token_ids = batch_result.next_token_ids
             else:
                 batch_result = self.model_worker.forward_batch_generation(
                     batch_or_worker_batch
@@ -2791,7 +2802,9 @@ def run_scheduler_process(

     disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
     if disaggregation_mode == DisaggregationMode.NULL:
-        if server_args.pp_size > 1:
+        if scheduler.enable_pdmux:
+            scheduler.event_loop_pdmux()
+        elif server_args.pp_size > 1:
             scheduler.event_loop_pp()
         elif scheduler.enable_overlap:
             scheduler.event_loop_overlap()
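The new SchedulerMultiplexMixin reaches Scheduler the same way the other mixins do: methods such as init_pdmux() and event_loop_pdmux() are defined in sglang.srt.multiplex.multiplexing_mixin (not shown in this diff) and become available through multiple inheritance. A minimal stand-alone sketch of that pattern, with toy class names and bodies that are only placeholders:

class MultiplexMixinSketch:
    # Toy stand-ins for init_pdmux() / event_loop_pdmux(); the real bodies live
    # in the multiplexing mixin module, which is not part of this commit's hunks.
    def init_pdmux(self):
        self.pdmux_ready = True

    def event_loop_pdmux(self):
        assert self.pdmux_ready
        print("running pdmux event loop")

class SchedulerSketch(MultiplexMixinSketch):
    def __init__(self, enable_pdmux: bool):
        self.enable_pdmux = enable_pdmux
        if self.enable_pdmux:
            self.init_pdmux()

SchedulerSketch(enable_pdmux=True).event_loop_pdmux()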

python/sglang/srt/managers/tp_worker.py

Lines changed: 24 additions & 1 deletion
@@ -35,7 +35,7 @@
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -425,3 +425,26 @@ def sample_batch_func():
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )
+
+    def forward_batch_split_prefill(self, batch: ScheduleBatch):
+        if batch.split_index == 0:
+            model_worker_batch = batch.get_model_worker_batch()
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+            batch.split_forward_batch = forward_batch
+            batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+        else:
+            model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+        logits_output, can_run_cuda_graph = self.model_runner.forward(
+            batch.split_forward_batch, split_forward_count=batch.split_forward_count
+        )
+        if logits_output:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        else:
+            next_token_ids = None
+        batch_result = GenerationBatchResult(
+            logits_output=logits_output,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+        batch_result.next_token_ids = next_token_ids
+        return batch_result
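The calling contract visible in forward_batch_split_prefill() above: the first split builds and caches the ForwardBatch, each call runs one chunk of the prefill, and next_token_ids is populated only once a call produces logits. A hedged driver sketch; the split_index and split_forward_count bookkeeping shown here is an assumption, since the pdmux scheduler code that actually maintains it is not in this diff:

def drive_split_prefill(tp_worker, batch, num_splits: int):
    # Hypothetical driver loop: one call per split; only the call that finishes
    # the prefill returns sampled next_token_ids (logits_output is None before).
    next_token_ids = None
    for split_index in range(num_splits):
        batch.split_index = split_index      # assumed: maintained by the caller
        batch.split_forward_count = 1        # assumed: how far to advance per call
        result = tp_worker.forward_batch_split_prefill(batch)
        if result.next_token_ids is not None:
            next_token_ids = result.next_token_ids
    return next_token_ids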

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 4 additions & 1 deletion
@@ -509,6 +509,7 @@ def __init__(
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_alt_stream: bool = True,
         enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
@@ -527,7 +528,9 @@ def __init__(
         self._create_buffers()

         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if _is_cuda else None
+        self.alt_stream = (
+            self.device_module.Stream() if _is_cuda and enable_alt_stream else None
+        )

         if enable_kv_cache_copy:
             self._init_kv_copy_and_warmup()
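The new enable_alt_stream flag only gates whether the KV pool creates its alternate stream; the model_runner hunk further below passes enable_alt_stream=not server_args.enable_pdmux, so the side stream is skipped under pdmux. A self-contained sketch of the same gating pattern using the public CUDA stream API (function name and flag here are illustrative, not from the codebase):

import torch

def make_alt_stream(enable_alt_stream: bool = True):
    # Create the side stream only when CUDA is available and the feature is on,
    # mirroring the gated conditional expression in the hunk above.
    if torch.cuda.is_available() and enable_alt_stream:
        return torch.cuda.Stream()
    return None

alt_stream = make_alt_stream(enable_alt_stream=False)
if alt_stream is not None:
    with torch.cuda.stream(alt_stream):
        pass  # kernels launched here may overlap with work on the default stream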

python/sglang/srt/model_executor/forward_batch_info.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def is_extend(self, include_draft_extend_v2: bool = False):
                 else False
             )
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )

     def is_decode(self):
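This one-line change is what lets the logits_processor.py hunks above drop their explicit "or is_split_prefill()" checks: once SPLIT_PREFILL counts as an extend mode, the combined predicate collapses to is_extend() alone. A toy demonstration with a simplified enum (not the real ForwardMode, which has more members):

from enum import Enum, auto

class ForwardModeSketch(Enum):
    EXTEND = auto()
    TARGET_VERIFY = auto()
    SPLIT_PREFILL = auto()
    DECODE = auto()

    def is_extend(self):
        # SPLIT_PREFILL newly included, mirroring the hunk above.
        return self in (
            ForwardModeSketch.EXTEND,
            ForwardModeSketch.TARGET_VERIFY,
            ForwardModeSketch.SPLIT_PREFILL,
        )

    def is_split_prefill(self):
        return self == ForwardModeSketch.SPLIT_PREFILL

for mode in ForwardModeSketch:
    assert (mode.is_extend() or mode.is_split_prefill()) == mode.is_extend()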

python/sglang/srt/model_executor/model_runner.py

Lines changed: 33 additions & 13 deletions
@@ -1765,6 +1765,7 @@ def init_memory_pool(
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
             end_layer=self.end_layer,
+            enable_alt_stream=not self.server_args.enable_pdmux,
             enable_kv_cache_copy=(
                 self.server_args.speculative_algorithm is not None
             ),
@@ -1833,12 +1834,18 @@ def init_cublas(self):

     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()

-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
@@ -1852,10 +1859,12 @@ def _get_attention_backend(self):
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
@@ -1869,7 +1878,8 @@ def _get_attention_backend(self):
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )

         (
@@ -1878,9 +1888,12 @@ def _get_attention_backend(self):
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend

-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)

@@ -1978,14 +1991,21 @@ def apply_torch_tp(self):
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)

+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -2123,18 +2143,18 @@ def _forward_raw(
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
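With these hunks, init_attention_backend() builds one decode attention backend per SM group (server_args.sm_group_num) and forward_decode() routes through whichever backend decode_attn_backend currently points at. A hedged usage sketch of the switch; the caller and the way stream_idx is chosen are assumptions, since the pdmux component that owns the SM-group schedule is not part of this diff:

def run_decode_on_sm_group(model_runner, forward_batch, stream_idx: int):
    # Hypothetical caller: select the SM-group-specific backend before the step,
    # so forward_decode() initializes metadata on the matching backend
    # (see the forward_decode hunk above).
    model_runner.update_decode_attn_backend(stream_idx)
    return model_runner.forward_decode(forward_batch)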
