[feature] Pooling Features and PCP Adaptation #4143
The changes below thread the PCP (prefill context parallel), DCP (decode context parallel), and TP ranks through the Mooncake cache engine, replacing the single `worker_id` field so that KV chunks held by different parallel ranks get distinct cache keys.

**File 1 — key and metadata definitions** (evidently `vllm_ascend/distributed/mooncake/config_data.py`, given the import in the second file):

```diff
@@ -23,8 +23,12 @@ class MooncakeEngineMetadata:
     model_name: str
     """ world size when running under a distributed setting """
    world_size: int
-    """ worker id when running under a distributed setting """
-    worker_id: int
+    """ the current PCP rank """
+    pcp_rank: int
+    """ the current DCP rank """
+    dcp_rank: int
+    """ the current TP rank """
+    tp_rank: int
     """ the format of kv tensors """
     kv_dtype: torch.dtype
     """ the shape of kv tensors """
```
```diff
@@ -39,20 +43,25 @@ class MooncakeEngineMetadata:
 class MooncakeEngineKey:
     model_name: str
     world_size: int
-    worker_id: int
+    pcp_rank: int
+    dcp_rank: int
+    tp_rank: int
     chunk_hash: str

     def __hash__(self):
         return hash((
             self.model_name,
             self.world_size,
-            self.worker_id,
+            self.pcp_rank,
+            self.dcp_rank,
+            self.tp_rank,
             self.chunk_hash,
         ))

     def to_string(self):
         return (f"{self.model_name}@{self.world_size}"
-                f"@{self.worker_id}@{self.chunk_hash}")
+                f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}"
+                f"@tp{self.tp_rank}@{self.chunk_hash}")

     def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
         """Split the key into multiple keys for each layer"""
```
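With the rank triple folded into both `__hash__` and `to_string`, keys for the same chunk held by different parallel ranks stay distinct in the store. A minimal usage sketch (the import path is taken from the second file's own imports; the model name and chunk hash are made-up values):

```python
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey

# Two keys that differ only in tp_rank: same model, world size, and chunk.
k0 = MooncakeEngineKey("demo-model", 8, 0, 0, 0, "ab12cd")  # tp_rank 0
k1 = MooncakeEngineKey("demo-model", 8, 0, 0, 1, "ab12cd")  # tp_rank 1

print(k0.to_string())  # demo-model@8@pcp0@dcp0@tp0@ab12cd
print(k1.to_string())  # demo-model@8@pcp0@dcp0@tp1@ab12cd
assert k0.to_string() != k1.to_string()  # TP shards can no longer collide
```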
```diff
@@ -62,7 +71,9 @@ def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
             LayerMooncakeEngineKey(
                 self.model_name,
                 self.world_size,
-                self.worker_id,
+                self.pcp_rank,
+                self.dcp_rank,
+                self.tp_rank,
                 self.chunk_hash,
                 layer_id,
             ))
```
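For the layerwise path, `split_layers` fans one chunk-level key out into per-layer keys that differ only in `layer_id`, each inheriting the same rank triple. A usage sketch, assuming the fields shown above are the complete constructor (values illustrative):

```python
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey

key = MooncakeEngineKey("demo-model", 8, 0, 0, 1, "ab12cd")
for layer_key in key.split_layers(num_layers=2):
    print(layer_key.to_string())
# demo-model@8@pcp0@dcp0@tp1@ab12cd@0
# demo-model@8@pcp0@dcp0@tp1@ab12cd@1
```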
```diff
@@ -74,7 +85,9 @@ def to_dict(self):
             "__type__": "CacheEngineKey",
             "model_name": self.model_name,
             "world_size": self.world_size,
-            "worker_id": self.worker_id,
+            "pcp_rank": self.pcp_rank,
+            "dcp_rank": self.dcp_rank,
+            "tp_rank": self.tp_rank,
             "chunk_hash": self.chunk_hash,
         }
```
```diff
@@ -83,7 +96,9 @@ def from_dict(d):
         return MooncakeEngineKey(
             model_name=d["model_name"],
             world_size=d["world_size"],
-            worker_id=d["worker_id"],
+            pcp_rank=d["pcp_rank"],
+            dcp_rank=d["dcp_rank"],
+            tp_rank=d["tp_rank"],
             chunk_hash=d["chunk_hash"],
         )
```
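`to_dict`/`from_dict` keep the wire format in step with the new fields, so a key round-trips through JSON; note the `__type__` tag is still `"CacheEngineKey"`. A sketch, assuming `from_dict` is the static factory the diff shows:

```python
import json

from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey

key = MooncakeEngineKey("demo-model", 8, 0, 0, 1, "ab12cd")
wire = json.dumps(key.to_dict())   # {"__type__": "CacheEngineKey", ...}
restored = MooncakeEngineKey.from_dict(json.loads(wire))
assert restored.to_string() == key.to_string()  # lossless round trip
```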
```diff
@@ -98,14 +113,17 @@ def __hash__(self):
         return hash((
             self.model_name,
             self.world_size,
-            self.worker_id,
+            self.pcp_rank,
+            self.dcp_rank,
+            self.tp_rank,
             self.chunk_hash,
             self.layer_id,
         ))

     def to_string(self):
         return (f"{self.model_name}@{self.world_size}"
-                f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
+                f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}"
+                f"@tp{self.tp_rank}@{self.chunk_hash}@{self.layer_id}")


 class ChunkedTokenDatabase():
```
```diff
@@ -123,7 +141,9 @@ def _make_key_by_hash(self,
         return MooncakeEngineKey(
             self.metadata.model_name,
             self.metadata.world_size,
-            self.metadata.worker_id,
+            self.metadata.pcp_rank,
+            self.metadata.dcp_rank,
+            self.metadata.tp_rank,
             chunk_hash,
         )
```
**File 2 — the `MooncakeEngine` module** (engine initialization; it imports the key definitions above):

```diff
@@ -7,6 +7,10 @@
 # Third Party
 import torch
 from vllm.config import VllmConfig
+from vllm.distributed import (get_decode_context_model_parallel_rank,
+                              get_decode_context_model_parallel_world_size,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.utils import logger

 from vllm_ascend.distributed.mooncake.config_data import (
```

> **Contributor** (on the added `vllm.distributed` import): These `get_xxx` methods are from a private repository, not from the main branch. They need to be intercepted.
>
> **Contributor (author):** This method has been incorporated.
```diff
@@ -16,13 +20,18 @@
     KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
     KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
 from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is

 if vllm_version_is("0.11.0"):
     from vllm.utils import get_kv_cache_torch_dtype
 else:
     from vllm.utils.torch_utils import get_kv_cache_torch_dtype

+if prefill_context_parallel_enable():
+    from vllm.distributed import (
+        get_prefill_context_model_parallel_rank,
+        get_prefill_context_model_parallel_world_size)
+

 class MooncakeEngine:
     #The main class for the cache engine.
```
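The PCP accessors only exist in vLLM builds that ship the feature, so the import is gated behind `prefill_context_parallel_enable()` rather than done unconditionally. A generic sketch of the same guard pattern with the fallback spelled out (`resolve_pcp` is a hypothetical helper, not part of this PR; the gated names are the ones imported above):

```python
# Hypothetical helper illustrating the guarded-import pattern: only
# touch the PCP accessors when the feature gate says they exist,
# otherwise report a disabled group of size 1, rank 0.
from vllm_ascend.utils import prefill_context_parallel_enable


def resolve_pcp() -> tuple[int, int]:
    """Return (pcp_size, pcp_rank); (1, 0) means PCP is disabled."""
    if not prefill_context_parallel_enable():
        return 1, 0
    # Deferred import: safe because the gate guarantees availability.
    from vllm.distributed import (
        get_prefill_context_model_parallel_rank,
        get_prefill_context_model_parallel_world_size)
    return (get_prefill_context_model_parallel_world_size(),
            get_prefill_context_model_parallel_rank())
```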
```diff
@@ -40,19 +49,33 @@ def __init__(
                 and model_config.use_mla):
             self.use_mla = True
         self.use_layerwise = use_layerwize
-        self.tp_rank = parallel_config.rank
-        self.tp_size = parallel_config.tensor_parallel_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
         self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "register_buffer", False)
+        self.block_size = vllm_config.cache_config.block_size
+
+        if self.pcp_size > 1:
+            self.block_size *= self.pcp_size
+        if self.dcp_size > 1:
+            self.block_size *= self.dcp_size

         self.current_layer = 0
         # self.use_mla = first_kv_cache_tuple[0].size(
         #     -1) != first_kv_cache_tuple[1].size(-1)
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.block_size = vllm_config.cache_config.block_size
         num_kv_head = model_config.get_num_kv_heads(parallel_config)
         head_size = model_config.get_head_size()
         kv_dtype = get_kv_cache_torch_dtype(
```
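Two things happen in this hunk: the ranks are resolved with safe defaults (`pcp_rank`/`dcp_rank` fall back to 0 whenever the corresponding group has size 1), and the engine-level block size is scaled so that one logical block covers the tokens PCP/DCP split across ranks. A worked example of the scaling, with assumed sizes:

```python
# Assumed configuration: 128-token device blocks, PCP=2, DCP=4.
block_size = 128           # vllm_config.cache_config.block_size (assumed)
pcp_size, dcp_size = 2, 4  # assumed parallel group sizes

# Mirrors the scaling in the diff: each factor multiplies in only when
# that form of context parallelism is actually enabled.
if pcp_size > 1:
    block_size *= pcp_size
if dcp_size > 1:
    block_size *= dcp_size

print(block_size)  # 1024 -> one engine block spans 8 device blocks
```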
The scaled block size and the three ranks then feed the `MooncakeEngineMetadata` from File 1:

```diff
@@ -66,7 +89,9 @@ def __init__(
         self.metadata = MooncakeEngineMetadata(
             model_config.model,
             parallel_config.world_size,
-            parallel_config.rank,
+            self.pcp_rank,
+            self.dcp_rank,
+            self.tp_rank,
             kv_dtype,
             kv_shape,
             self.block_size,
```