diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py
index 9f4833555db..4107afdfab5 100644
--- a/vllm_ascend/distributed/kvpool/ascend_store_connector.py
+++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py
@@ -43,8 +43,6 @@ def __init__(self,
 
         self.kv_caches: dict[str, torch.Tensor] = {}
 
-        self._block_size = vllm_config.cache_config.block_size
-
         self.sended_but_unfinished_reqs: set[str] = set()
 
         if role == KVConnectorRole.SCHEDULER:
diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py
index e3b0873d686..0d89021bb3a 100644
--- a/vllm_ascend/distributed/kvpool/config_data.py
+++ b/vllm_ascend/distributed/kvpool/config_data.py
@@ -17,6 +17,10 @@ class KeyMetadata:
     model_name: str
     """ worker id when running under a distributed setting """
     head_or_tp_rank: int
+    """ current prefill context model parallel rank """
+    pcp_rank: int
+    """ current decode context model parallel rank """
+    dcp_rank: int
 
 
 @dataclass(order=True)
@@ -28,12 +32,15 @@ def __hash__(self):
         return hash((
             self.key_metadata.model_name,
             self.key_metadata.head_or_tp_rank,
+            self.key_metadata.pcp_rank,
+            self.key_metadata.dcp_rank,
             self.chunk_hash,
         ))
 
     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
+            f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
             f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
         )
 
@@ -60,6 +67,8 @@ def __hash__(self):
         return hash((
             self.key_metadata.model_name,
             self.key_metadata.head_or_tp_rank,
+            self.key_metadata.pcp_rank,
+            self.key_metadata.dcp_rank,
             self.chunk_hash,
             self.layer_id,
         ))
@@ -67,6 +76,7 @@ def __hash__(self):
     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
+            f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
             f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
         )
 
diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py
index b30158ae8c2..cd5245364c4 100644
--- a/vllm_ascend/distributed/kvpool/kv_transfer.py
+++ b/vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -19,11 +19,13 @@ class KVTransferThread(threading.Thread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event, name: str):
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
+                 name: str):
         super().__init__(daemon=True, name=name)
         self.m_store = m_store
         self.ready_event = ready_event
         self.tp_rank = tp_rank
+        self.dcp_size = dcp_size
         self.token_database = token_database
         self.done_task_lock = threading.Lock()
         self.request_queue: queue.Queue[Any] = queue.Queue()
@@ -87,10 +89,12 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, put_step: int, ready_event: threading.Event):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheSendingThread")
         self.put_step = put_step
@@ -112,12 +116,18 @@ def _handle_request(self, req_meta: dict[str, Any]):
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-        addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-        size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
-        if key_list_tp:
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
-            self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
+            self.m_store.put(key_list, addr_list, size_list)
+        else:
+            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            if key_list_tp:
+                torch.npu.current_stream().synchronize()
+                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if is_last_chunk:
             self.set_finished_request(req_id)
         self.request_queue.task_done()
@@ -126,10 +136,11 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreRecvingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event):
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreRecvingThread")
 
@@ -166,11 +177,12 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreLayerSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, put_step: int, ready_event: threading.Event,
-                 num_layers: int):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event, num_layers: int):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreLayerSendingThread")
         self.final_layer_id = num_layers - 1
@@ -192,12 +204,18 @@ def _handle_request(  # type: ignore[override]
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-        addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-        size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
-        if key_list_tp:
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
-            self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
+            self.m_store.put(key_list, addr_list, size_list)
+        else:
+            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            if key_list_tp:
+                torch.npu.current_stream().synchronize()
+                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
             self.set_finished_request(req_meta.req_id)
         self.request_queue.task_done()
@@ -206,11 +224,12 @@ def _handle_request(  # type: ignore[override]
 class KVCacheStoreLayerRecvingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event,
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
                  get_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreLayerRecvingThread")
         self.get_event = get_event
diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py
index 06041b5a6e5..d1564ce7ec0 100644
--- a/vllm_ascend/distributed/kvpool/pool_scheduler.py
+++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py
@@ -29,7 +29,14 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
             "load_async", False)
 
         # request_id -> (vllm cached tokes, kvpool cached tokens)
         self.load_specs: dict[str, LoadSpec] = {}
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+        self._block_size = vllm_config.cache_config.block_size
+        if self.pcp_size > 1:
+            self._block_size *= self.pcp_size
+        if self.dcp_size > 1:
+            self._block_size *= self.dcp_size
         # request_id -> full_token_ids
         self._request_trackers: dict[str, RequestTracker] = {}
         # Whether to discard partial chunks
diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index b03d2808928..25322c5f75d 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -1,11 +1,13 @@
-# Standard
 import math
 import threading
 from typing import Dict, Generator, Optional, Type
 
-# Third Party
 import torch
 from vllm.config import VllmConfig
+from vllm.distributed import (get_decode_context_model_parallel_rank,
+                              get_decode_context_model_parallel_world_size,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.utils import logger
 from vllm.v1.core.kv_cache_utils import BlockHash
 
@@ -20,6 +22,14 @@
 from vllm_ascend.distributed.kvpool.kv_transfer import (
     KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
     KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
+from vllm_ascend.utils import prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    # isort: off
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+    # isort: on
 
 backend_map: Dict[str, Type[Backend]] = {
     "mooncake": MooncakeBackend,
@@ -44,17 +54,30 @@ def __init__(
                 and model_config.use_mla):
             self.use_mla = True
         self.use_layerwise = use_layerwize
-        self.tp_rank = parallel_config.rank
-        self.tp_size = parallel_config.tensor_parallel_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "backend", "mooncake")
         self.block_size = vllm_config.cache_config.block_size
+
+        if self.pcp_size > 1:
+            self.block_size *= self.pcp_size
+        if self.dcp_size > 1:
+            self.block_size *= self.dcp_size
         self.current_layer = 0
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.block_size = vllm_config.cache_config.block_size
 
         if self.use_mla:
             self.num_kv_head = 1
@@ -69,8 +92,10 @@ def __init__(
             self.put_step = 1
 
         self.metadata = KeyMetadata(
-            model_config.model,
+            model_config.model.split('/')[-1],
             self.head_or_tp_rank,
+            self.pcp_rank,
+            self.dcp_rank,
         )
 
         self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -147,12 +172,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             ready_event_sending = threading.Event()
             self.kv_send_thread = KVCacheStoreLayerSendingThread(
                 self.m_store, self.token_database, self.tp_rank,
-                self.put_step, ready_event_sending, self.num_layers)
+                self.dcp_size, self.put_step, ready_event_sending,
+                self.num_layers)
             self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
-                self.m_store, self.token_database, self.tp_rank, ready_event,
-                self.get_event)
+                self.m_store, self.token_database, self.tp_rank, self.dcp_size,
+                ready_event, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
@@ -160,13 +186,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             ready_event_sending = threading.Event()
             self.kv_send_thread = KVCacheStoreSendingThread(
                 self.m_store, self.token_database, self.tp_rank,
-                self.put_step, ready_event_sending)
+                self.dcp_size, self.put_step, ready_event_sending)
             self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()
                 self.kv_recv_thread = KVCacheStoreRecvingThread(
                     self.m_store, self.token_database, self.tp_rank,
-                    ready_event)
+                    self.dcp_size, ready_event)
                 self.kv_recv_thread.start()
                 ready_event.wait()
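For reference, a minimal standalone sketch of how the new pcp_rank/dcp_rank fields end up in the serialized cache key. The simplified dataclasses below only mirror the KeyMetadata/ChunkKey fields touched by this change in config_data.py, and the model name, ranks, and chunk hash are hypothetical values chosen for illustration:

# Standalone illustration only -- simplified stand-ins for the real
# KeyMetadata/ChunkKey classes in config_data.py.
from dataclasses import dataclass


@dataclass
class KeyMetadata:
    model_name: str
    head_or_tp_rank: int
    pcp_rank: int
    dcp_rank: int


@dataclass
class ChunkKey:
    key_metadata: KeyMetadata
    chunk_hash: str

    def to_string(self) -> str:
        # Same key layout as ChunkKey.to_string() after this change.
        return (
            f"{self.key_metadata.model_name}"
            f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
        )


# Hypothetical values for illustration.
meta = KeyMetadata(model_name="Qwen2.5-7B-Instruct",
                   head_or_tp_rank=2,
                   pcp_rank=0,
                   dcp_rank=1)
print(ChunkKey(meta, "c0ffee").to_string())
# -> Qwen2.5-7B-Instruct@pcp0@dcp1@head_or_tp_rank:2@c0ffee

Because the pcp/dcp ranks are now part of both __hash__ and to_string, entries written under different context-parallel layouts map to distinct keys in the KV pool rather than colliding.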