Merged

Changes from 3 commits
42 changes: 31 additions & 11 deletions vllm_ascend/distributed/mooncake/config_data.py
@@ -23,8 +23,12 @@ class MooncakeEngineMetadata:
model_name: str
""" world size when running under a distributed setting """
world_size: int
""" worker id when running under a distributed setting """
worker_id: int
""" Initialize the current PCP's rank """
pcp_rank: int
""" Initialize the current DCP's rank """
dcp_rank: int
Contributor: dcp_rank might be redundant with tp_rank, as their logic is similar. We can probably use tp_rank directly and remove dcp_rank. (See the key-layout sketch after this file's diff.)

""" Initialize the current TP's rank """
tp_rank: int
""" the format of kv tensors """
kv_dtype: torch.dtype
""" the shape of kv tensors """
@@ -39,20 +43,25 @@ class MooncakeEngineKey:
class MooncakeEngineKey:
model_name: str
world_size: int
worker_id: int
pcp_rank: int
dcp_rank: int
tp_rank: int
chunk_hash: str

def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.pcp_rank,
self.dcp_rank,
self.tp_rank,
self.chunk_hash,
))

def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}")
f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}"
f"@tp{self.tp_rank}@{self.chunk_hash}")

def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
"""Split the key into multiple keys for each layer"""
@@ -62,7 +71,9 @@ def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
LayerMooncakeEngineKey(
self.model_name,
self.world_size,
self.worker_id,
self.pcp_rank,
self.dcp_rank,
self.tp_rank,
self.chunk_hash,
layer_id,
))
@@ -74,7 +85,9 @@ def to_dict(self):
"__type__": "CacheEngineKey",
"model_name": self.model_name,
"world_size": self.world_size,
"worker_id": self.worker_id,
"pcp_rank": self.pcp_rank,
"dcp_rank": self.dcp_rank,
"tp_rank": self.tp_rank,
"chunk_hash": self.chunk_hash,
}

@@ -83,7 +96,9 @@ def from_dict(d):
return MooncakeEngineKey(
model_name=d["model_name"],
world_size=d["world_size"],
worker_id=d["worker_id"],
pcp_rank=d["pcp_rank"],
dcp_rank=d["dcp_rank"],
tp_rank=d["tp_rank"],
chunk_hash=d["chunk_hash"],
)

@@ -98,14 +113,17 @@ def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.pcp_rank,
self.dcp_rank,
self.tp_rank,
self.chunk_hash,
self.layer_id,
))

def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}"
f"@tp{self.tp_rank}@{self.chunk_hash}@{self.layer_id}")


class ChunkedTokenDatabase():
@@ -123,7 +141,9 @@ def _make_key_by_hash(self,
return MooncakeEngineKey(
self.metadata.model_name,
self.metadata.world_size,
self.metadata.worker_id,
self.metadata.pcp_rank,
self.metadata.dcp_rank,
self.metadata.tp_rank,
chunk_hash,
)

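For reference, a minimal standalone sketch of the cache-key layout introduced by this change. The field names and the to_string() format follow the diff above; the class name _KeySketch and the example values are invented for illustration and are not part of the PR:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class _KeySketch:
    """Illustrative stand-in for MooncakeEngineKey with the new rank fields."""
    model_name: str
    world_size: int
    pcp_rank: int
    dcp_rank: int
    tp_rank: int
    chunk_hash: str

    def to_string(self) -> str:
        # Same layout as MooncakeEngineKey.to_string() in the diff:
        # model@world_size@pcp{..}@dcp{..}@tp{..}@chunk_hash
        return (f"{self.model_name}@{self.world_size}"
                f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}"
                f"@tp{self.tp_rank}@{self.chunk_hash}")


# Made-up example values:
print(_KeySketch("deepseek-v2", 8, 0, 0, 1, "abc123").to_string())
# -> deepseek-v2@8@pcp0@dcp0@tp1@abc123
```

If the reviewer's suggestion were adopted and dcp_rank folded into tp_rank, the dcp segment would simply drop out of this layout.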
35 changes: 30 additions & 5 deletions vllm_ascend/distributed/mooncake/mooncake_engine.py
@@ -7,6 +7,10 @@
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
Contributor: These get_xxx methods are from a private repository, not from the main branch. They need to be intercepted.

Contributor (author): This has been incorporated; the PCP imports are now gated behind prefill_context_parallel_enable() (see the sketch after this file's diff).

get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.utils import logger

from vllm_ascend.distributed.mooncake.config_data import (
@@ -16,13 +20,18 @@
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is

if vllm_version_is("0.11.0"):
from vllm.utils import get_kv_cache_torch_dtype
else:
from vllm.utils.torch_utils import get_kv_cache_torch_dtype

if prefill_context_parallel_enable():
from vllm.distributed import (
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size)


class MooncakeEngine:
#The main class for the cache engine.
@@ -40,19 +49,33 @@ def __init__(
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()

self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0

self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"register_buffer", False)
self.block_size = vllm_config.cache_config.block_size

if self.pcp_size > 1:
self.block_size *= self.pcp_size
if self.dcp_size > 1:
self.block_size *= self.dcp_size

self.current_layer = 0
# self.use_mla = first_kv_cache_tuple[0].size(
# -1) != first_kv_cache_tuple[1].size(-1)
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
num_kv_head = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
kv_dtype = get_kv_cache_torch_dtype(
Expand All @@ -66,7 +89,9 @@ def __init__(
self.metadata = MooncakeEngineMetadata(
model_config.model,
parallel_config.world_size,
parallel_config.rank,
self.pcp_rank,
self.dcp_rank,
self.tp_rank,
kv_dtype,
kv_shape,
self.block_size,
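To make the pattern from the review thread above concrete, here is a minimal sketch of feature-gated rank resolution. It uses only the symbols that appear in this diff (the tensor/decode context parallel helpers from vllm.distributed, the PCP variants, and prefill_context_parallel_enable() from vllm_ascend.utils); the helper name resolve_cp_ranks is invented, and this is a sketch of the logic rather than the actual implementation:

```python
from vllm.distributed import (get_decode_context_model_parallel_rank,
                              get_decode_context_model_parallel_world_size,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)

from vllm_ascend.utils import prefill_context_parallel_enable


def resolve_cp_ranks():
    """Sketch of the rank/size resolution done in MooncakeEngine.__init__."""
    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    # The PCP helpers only exist on builds that ship prefill context
    # parallelism, so the import itself is gated behind the feature check.
    if prefill_context_parallel_enable():
        from vllm.distributed import (
            get_prefill_context_model_parallel_rank,
            get_prefill_context_model_parallel_world_size)
        pcp_size = get_prefill_context_model_parallel_world_size()
        pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size > 1 else 0
    else:
        pcp_size, pcp_rank = 1, 0

    dcp_size = get_decode_context_model_parallel_world_size()
    dcp_rank = get_decode_context_model_parallel_rank() if dcp_size > 1 else 0
    return tp_rank, tp_size, pcp_rank, pcp_size, dcp_rank, dcp_size
```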
16 changes: 16 additions & 0 deletions vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py
@@ -31,8 +31,16 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):

self.kv_caches: dict[str, torch.Tensor] = {}

self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size

self._block_size = vllm_config.cache_config.block_size

if self.pcp_size > 1:
self._block_size *= self.pcp_size
if self.dcp_size > 1:
self._block_size *= self.dcp_size

self.sended_but_unfinished_reqs: set[str] = set()

if role == KVConnectorRole.SCHEDULER:
@@ -169,7 +177,15 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
"load_async", False)
# request_id -> (vllm cached tokes, mooncake cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size

self._block_size = vllm_config.cache_config.block_size

if self.pcp_size > 1:
self._block_size *= self.pcp_size
if self.dcp_size > 1:
self._block_size *= self.dcp_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks
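Both connector roles apply the same block-size scaling as the engine. As a rough worked example with invented numbers: cache_config.block_size = 128, pcp_size = 2 and dcp_size = 4 give an effective block size of 128 * 2 * 4 = 1024. A minimal sketch of the scaling follows; the helper name is invented, and the rationale in the comment is an assumption, since the diff only shows the multiplication:

```python
def effective_block_size(block_size: int, pcp_size: int, dcp_size: int) -> int:
    # Presumably each context-parallel rank holds only a slice of a logical
    # chunk, so the key granularity is scaled by the CP world sizes.
    if pcp_size > 1:
        block_size *= pcp_size
    if dcp_size > 1:
        block_size *= dcp_size
    return block_size


assert effective_block_size(128, 2, 4) == 1024  # made-up example values
assert effective_block_size(128, 1, 1) == 128   # no context parallelism
```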