From f1062beb4e62026d38ca3b1ae8c2b76924479feb Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Wed, 12 Nov 2025 11:57:15 +0800 Subject: [PATCH 01/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- .../distributed/mooncake/config_data.py | 42 ++++++++++++++----- .../distributed/mooncake/mooncake_engine.py | 37 +++++++++++++--- .../mooncake/mooncake_store_connector_v1.py | 19 +++++++++ 3 files changed, 82 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 36c820b0890..97b1d76694f 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -23,8 +23,12 @@ class MooncakeEngineMetadata: model_name: str """ world size when running under a distributed setting """ world_size: int - """ worker id when running under a distributed setting """ - worker_id: int + """ Initialize the current PCP's rank """ + pcp_rank: int + """ Initialize the current DCP's rank """ + dcp_rank: int + """ Initialize the current TP's rank """ + tp_rank: int """ the format of kv tensors """ kv_dtype: torch.dtype """ the shape of kv tensors """ @@ -39,20 +43,25 @@ class MooncakeEngineMetadata: class MooncakeEngineKey: model_name: str world_size: int - worker_id: int + pcp_rank: int + dcp_rank: int + tp_rank: int chunk_hash: str def __hash__(self): return hash(( self.model_name, self.world_size, - self.worker_id, + self.pcp_rank, + self.dcp_rank, + self.tp_rank, self.chunk_hash, )) def to_string(self): return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}") + f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}" + f"@tp{self.tp_rank}@{self.chunk_hash}") def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: """Split the key into multiple keys for each layer""" @@ -62,7 +71,9 @@ def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: LayerMooncakeEngineKey( self.model_name, self.world_size, - self.worker_id, + self.pcp_rank, + self.dcp_rank, + self.tp_rank, self.chunk_hash, layer_id, )) @@ -74,7 +85,9 @@ def to_dict(self): "__type__": "CacheEngineKey", "model_name": self.model_name, "world_size": self.world_size, - "worker_id": self.worker_id, + "pcp_rank": self.pcp_rank, + "dcp_rank": self.dcp_rank, + "tp_rank": self.tp_rank, "chunk_hash": self.chunk_hash, } @@ -83,7 +96,9 @@ def from_dict(d): return MooncakeEngineKey( model_name=d["model_name"], world_size=d["world_size"], - worker_id=d["worker_id"], + pcp_rank=d["pcp_rank"], + dcp_rank=d["dcp_rank"], + tp_rank=d["tp_rank"], chunk_hash=d["chunk_hash"], ) @@ -98,14 +113,17 @@ def __hash__(self): return hash(( self.model_name, self.world_size, - self.worker_id, + self.pcp_rank, + self.dcp_rank, + self.tp_rank, self.chunk_hash, self.layer_id, )) def to_string(self): return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") + f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}" + f"@tp{self.tp_rank}@{self.chunk_hash}@{self.layer_id}") class ChunkedTokenDatabase(): @@ -123,7 +141,9 @@ def _make_key_by_hash(self, return MooncakeEngineKey( self.metadata.model_name, self.metadata.world_size, - self.metadata.worker_id, + self.metadata.pcp_rank, + self.metadata.dcp_rank, + self.metadata.tp_rank, chunk_hash, ) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index ac00e22cbfe..7a8bc70b4b5 100644 --- 
a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -8,6 +8,10 @@ import torch from vllm.config import VllmConfig from vllm.utils import logger +from vllm.distributed import (get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm_ascend.distributed.mooncake.config_data import ( ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, @@ -16,13 +20,18 @@ KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore -from vllm_ascend.utils import vllm_version_is +from vllm_ascend.utils import (vllm_version_is, prefill_context_parallel_enable) if vllm_version_is("0.11.0"): from vllm.utils import get_kv_cache_torch_dtype else: from vllm.utils.torch_utils import get_kv_cache_torch_dtype +if prefill_context_parallel_enable(): + from vllm.distributed import (get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size + ) + class MooncakeEngine: #The main class for the cache engine. @@ -40,19 +49,35 @@ def __init__( and model_config.use_mla): self.use_mla = True self.use_layerwise = use_layerwize - self.tp_rank = parallel_config.rank - self.tp_size = parallel_config.tensor_parallel_size + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + self.pcp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "register_buffer", False) self.block_size = vllm_config.cache_config.block_size + + if self.pcp_size > 1: + self.block_size *= self.pcp_size + + if self.dcp_size > 1: + self.block_size *= self.dcp_size + self.current_layer = 0 # self.use_mla = first_kv_cache_tuple[0].size( # -1) != first_kv_cache_tuple[1].size(-1) self.num_layers = model_config.get_num_layers(parallel_config) - self.block_size = vllm_config.cache_config.block_size num_kv_head = model_config.get_num_kv_heads(parallel_config) head_size = model_config.get_head_size() kv_dtype = get_kv_cache_torch_dtype( @@ -66,7 +91,9 @@ def __init__( self.metadata = MooncakeEngineMetadata( model_config.model, parallel_config.world_size, - parallel_config.rank, + self.pcp_rank, + self.dcp_rank, + self.tp_rank, kv_dtype, kv_shape, self.block_size, diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index f55dd03bbe5..c0f1a45c169 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -31,8 +31,17 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.kv_caches: dict[str, torch.Tensor] = {} + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.dcp_size = 
vllm_config.parallel_config.decode_context_parallel_size + self._block_size = vllm_config.cache_config.block_size + if self.pcp_size > 1: + self._block_size *= self.pcp_size + + if self.dcp_size > 1: + self._block_size *= self.dcp_size + self.sended_but_unfinished_reqs: set[str] = set() if role == KVConnectorRole.SCHEDULER: @@ -169,7 +178,17 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): "load_async", False) # request_id -> (vllm cached tokes, mooncake cached tokens) self.load_specs: dict[str, LoadSpec] = {} + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size + self._block_size = vllm_config.cache_config.block_size + + if self.pcp_size > 1: + self._block_size *= self.pcp_size + + if self.dcp_size > 1: + self._block_size *= self.dcp_size + # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} # Whether to discard partial chunks From bbee9bae2ccd7d58c4d6d9b86bc0616e7244548f Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Wed, 12 Nov 2025 14:25:56 +0800 Subject: [PATCH 02/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/distributed/mooncake/config_data.py | 4 ++-- vllm_ascend/distributed/mooncake/mooncake_engine.py | 12 +++++------- .../mooncake/mooncake_store_connector_v1.py | 2 -- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 97b1d76694f..7fba17eb9be 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -54,7 +54,7 @@ def __hash__(self): self.world_size, self.pcp_rank, self.dcp_rank, - self.tp_rank, + self.tp_rank, self.chunk_hash, )) @@ -73,7 +73,7 @@ def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: self.world_size, self.pcp_rank, self.dcp_rank, - self.tp_rank, + self.tp_rank, self.chunk_hash, layer_id, )) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 7a8bc70b4b5..ff2cbfd1894 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -7,11 +7,11 @@ # Third Party import torch from vllm.config import VllmConfig -from vllm.utils import logger from vllm.distributed import (get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.utils import logger from vllm_ascend.distributed.mooncake.config_data import ( ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, @@ -20,7 +20,7 @@ KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore -from vllm_ascend.utils import (vllm_version_is, prefill_context_parallel_enable) +from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is if vllm_version_is("0.11.0"): from vllm.utils import get_kv_cache_torch_dtype @@ -28,9 +28,9 @@ from vllm.utils.torch_utils import get_kv_cache_torch_dtype if prefill_context_parallel_enable(): - from vllm.distributed import (get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) + from vllm.distributed import ( + 
get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size) class MooncakeEngine: @@ -49,7 +49,6 @@ def __init__( and model_config.use_mla): self.use_mla = True self.use_layerwise = use_layerwize - self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -70,7 +69,6 @@ def __init__( if self.pcp_size > 1: self.block_size *= self.pcp_size - if self.dcp_size > 1: self.block_size *= self.dcp_size diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index c0f1a45c169..aa75168e4f1 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -38,7 +38,6 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): if self.pcp_size > 1: self._block_size *= self.pcp_size - if self.dcp_size > 1: self._block_size *= self.dcp_size @@ -185,7 +184,6 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): if self.pcp_size > 1: self._block_size *= self.pcp_size - if self.dcp_size > 1: self._block_size *= self.dcp_size From 3d276f38076d678f7d663b1760d0277f53d410dd Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Wed, 12 Nov 2025 14:35:22 +0800 Subject: [PATCH 03/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index aa75168e4f1..bf41180afc4 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -186,7 +186,6 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): self._block_size *= self.pcp_size if self.dcp_size > 1: self._block_size *= self.dcp_size - # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} # Whether to discard partial chunks From fdd3863f7a538014c4d00c7a8eefc9856a8e445d Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Wed, 12 Nov 2025 15:01:33 +0800 Subject: [PATCH 04/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/distributed/mooncake/mooncake_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index ff2cbfd1894..f3c41f809a5 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -28,9 +28,9 @@ from vllm.utils.torch_utils import get_kv_cache_torch_dtype if prefill_context_parallel_enable(): - from vllm.distributed import ( - get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size) + from vllm.distributed import (get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size + ) class MooncakeEngine: From 51d784539390136b026cbea64377bdc257d48b9c Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Wed, 12 Nov 2025 15:28:12 +0800 Subject: [PATCH 05/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/distributed/mooncake/config_data.py | 6 +++--- vllm_ascend/distributed/mooncake/mooncake_engine.py | 6 +++--- 2 files changed, 6 
insertions(+), 6 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 7fba17eb9be..d67a141433e 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -23,11 +23,11 @@ class MooncakeEngineMetadata: model_name: str """ world size when running under a distributed setting """ world_size: int - """ Initialize the current PCP's rank """ + """ Initialize the current prefill context model parallel rank """ pcp_rank: int - """ Initialize the current DCP's rank """ + """ Initialize the current decode context model parallel rank """ dcp_rank: int - """ Initialize the current TP's rank """ + """ Initialize the current tensor parallel rank """ tp_rank: int """ the format of kv tensors """ kv_dtype: torch.dtype diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index f3c41f809a5..ff2cbfd1894 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -28,9 +28,9 @@ from vllm.utils.torch_utils import get_kv_cache_torch_dtype if prefill_context_parallel_enable(): - from vllm.distributed import (get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) + from vllm.distributed import ( + get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size) class MooncakeEngine: From 10baba34f8d135d52d3074b3706adf2402741a50 Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Sat, 29 Nov 2025 09:56:39 +0800 Subject: [PATCH 06/13] Synchronous code Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/ascend_config.py | 16 + vllm_ascend/ascend_forward_context.py | 27 +- vllm_ascend/attention/attention_mask.py | 2 - vllm_ascend/attention/attention_v1.py | 80 ++- vllm_ascend/attention/mla_v1.py | 176 ++++- vllm_ascend/attention/sfa_v1.py | 8 +- vllm_ascend/compilation/acl_graph.py | 55 +- vllm_ascend/core/recompute_scheduler.py | 20 +- vllm_ascend/core/scheduler.py | 44 +- vllm_ascend/core/scheduler_dynamic_batch.py | 37 +- vllm_ascend/distributed/__init__.py | 9 +- .../distributed/cpu_offload_connector.py | 6 +- .../cpu_offload_manager/metadata.py | 9 +- .../llmdatadist_c_mgr_connector.py | 20 +- vllm_ascend/distributed/mooncake/__init__.py | 0 .../distributed/mooncake/config_data.py | 561 --------------- .../distributed/mooncake/kv_transfer.py | 282 -------- .../distributed/mooncake/mooncake_engine.py | 652 ----------------- .../distributed/mooncake/mooncake_store.py | 126 ---- .../mooncake/mooncake_store_connector_v1.py | 514 ------------- .../distributed/mooncake/transfer_engine.py | 28 - vllm_ascend/distributed/mooncake_connector.py | 246 +++++-- .../mooncake_layerwise_connector.py | 13 +- vllm_ascend/envs.py | 16 +- vllm_ascend/eplb/core/eplb_utils.py | 34 +- vllm_ascend/kv_offload/cpu_npu.py | 8 +- vllm_ascend/lora/punica_npu.py | 65 +- .../model_loader/netloader/netloader.py | 14 +- vllm_ascend/models/__init__.py | 19 +- vllm_ascend/models/layers/mla.py | 104 +-- vllm_ascend/models/qwen2_5_vl.py | 572 --------------- .../models/qwen2_5_vl_without_padding.py | 605 ---------------- vllm_ascend/models/qwen2_vl.py | 12 +- vllm_ascend/models/qwen3_next.py | 395 ++++++++-- vllm_ascend/ops/activation.py | 4 +- vllm_ascend/ops/casual_conv1d.py | 539 -------------- vllm_ascend/ops/expert_load_balancer.py | 8 +- vllm_ascend/ops/fla.py | 299 -------- 
vllm_ascend/ops/fused_moe/experts_selector.py | 79 +- vllm_ascend/ops/fused_moe/fused_moe.py | 115 +-- vllm_ascend/ops/fused_moe/moe_comm_method.py | 33 +- vllm_ascend/ops/fused_moe/moe_mlp.py | 5 +- vllm_ascend/ops/fused_moe/prepare_finalize.py | 114 +-- vllm_ascend/ops/fused_moe/token_dispatcher.py | 6 +- vllm_ascend/ops/layernorm.py | 11 +- vllm_ascend/ops/linear.py | 3 +- vllm_ascend/ops/register_custom_ops.py | 9 +- vllm_ascend/ops/rotary_embedding.py | 16 +- vllm_ascend/ops/sigmoid_gating.py | 300 -------- vllm_ascend/patch/__init__.py | 24 +- vllm_ascend/patch/platform/__init__.py | 3 +- .../patch/platform/patch_distributed.py | 4 +- .../patch/platform/patch_mamba_config.py | 17 +- .../platform/patch_multiproc_executor.py | 103 +-- vllm_ascend/patch/worker/__init__.py | 9 +- .../patch/worker/patch_deepseek_mtp.py | 54 -- .../patch/worker/patch_deepseek_v3_2.py | 108 --- vllm_ascend/patch/worker/patch_logits.py | 26 - vllm_ascend/patch/worker/patch_triton.py | 23 +- .../patch/worker/patch_weight_loader.py | 8 +- vllm_ascend/platform.py | 222 +++--- vllm_ascend/quantization/quant_config.py | 49 +- vllm_ascend/quantization/utils.py | 27 +- vllm_ascend/quantization/w4a8_dynamic.py | 5 +- vllm_ascend/quantization/w8a8.py | 27 +- vllm_ascend/quantization/w8a8_dynamic.py | 31 +- vllm_ascend/sample/rejection_sampler.py | 27 +- vllm_ascend/sample/sampler.py | 5 +- vllm_ascend/spec_decode/eagle_proposer.py | 45 +- vllm_ascend/spec_decode/interface.py | 3 +- vllm_ascend/spec_decode/mtp_proposer.py | 229 ++++-- vllm_ascend/spec_decode/ngram_proposer.py | 5 +- vllm_ascend/torchair/models/qwen2.py | 6 +- vllm_ascend/torchair/models/qwen3_moe.py | 24 +- .../torchair/models/torchair_deepseek_v2.py | 182 ++--- .../torchair/models/torchair_pangu_moe.py | 17 +- .../torchair/ops/torchair_activation.py | 4 +- .../torchair/ops/torchair_fused_moe.py | 42 +- .../torchair/ops/torchair_layernorm.py | 4 +- .../torchair/ops/torchair_rotary_embedding.py | 8 +- .../quantization/torchair_w8a8_dynamic.py | 12 +- vllm_ascend/torchair/torchair_attention.py | 21 +- vllm_ascend/torchair/torchair_mla.py | 18 +- vllm_ascend/torchair/torchair_model_runner.py | 28 +- vllm_ascend/torchair/torchair_mtp_proposer.py | 24 +- vllm_ascend/torchair/torchair_sfa.py | 16 +- vllm_ascend/utils.py | 148 ++-- vllm_ascend/worker/block_table.py | 49 +- vllm_ascend/worker/model_runner_v1.py | 678 ++++++++---------- vllm_ascend/worker/npu_input_batch.py | 9 +- vllm_ascend/worker/worker_v1.py | 49 +- 91 files changed, 2067 insertions(+), 6642 deletions(-) delete mode 100644 vllm_ascend/distributed/mooncake/__init__.py delete mode 100644 vllm_ascend/distributed/mooncake/config_data.py delete mode 100644 vllm_ascend/distributed/mooncake/kv_transfer.py delete mode 100644 vllm_ascend/distributed/mooncake/mooncake_engine.py delete mode 100644 vllm_ascend/distributed/mooncake/mooncake_store.py delete mode 100644 vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py delete mode 100644 vllm_ascend/distributed/mooncake/transfer_engine.py delete mode 100644 vllm_ascend/models/qwen2_5_vl.py delete mode 100644 vllm_ascend/models/qwen2_5_vl_without_padding.py delete mode 100644 vllm_ascend/ops/casual_conv1d.py delete mode 100644 vllm_ascend/ops/fla.py delete mode 100644 vllm_ascend/ops/sigmoid_gating.py delete mode 100644 vllm_ascend/patch/worker/patch_deepseek_mtp.py delete mode 100644 vllm_ascend/patch/worker/patch_deepseek_v3_2.py delete mode 100644 vllm_ascend/patch/worker/patch_logits.py diff --git a/vllm_ascend/ascend_config.py 
b/vllm_ascend/ascend_config.py index 1fd1c67cdf7..16d16a4d7c8 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -44,6 +44,10 @@ def __init__(self, vllm_config): self.ascend_scheduler_config = AscendSchedulerConfig( ascend_scheduler_config) + # Dump / PrecisionDebugger configuration + dump_config_path = additional_config.get("dump_config", None) + self.dump_config = DumpConfig(dump_config_path) + weight_prefetch_config = additional_config.get( "weight_prefetch_config", {}) self.weight_prefetch_config = WeightPrefetchConfig( @@ -230,6 +234,18 @@ def __init__(self, ascend_scheduler_config: dict): setattr(self, k, v) +class DumpConfig: + """ + Configuration object for dump/PrecisionDebugger settings. + """ + + def __init__(self, dump_config_path: Optional[str] = None): + # enable_dump is True when dump_cfg exists and config_path is not empty + self.enable_dump: bool = bool(dump_config_path) + # Path to msprobe config json; may be None. + self.config_path: Optional[str] = dump_config_path + + class WeightPrefetchConfig: """ Configuration Object for weight_prefetch_config from additional_config diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 65bc5a472c0..11c1d3a0373 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -29,16 +29,8 @@ class FusedMoEState(Enum): All2AllSeq = 5 -class MoECommType(Enum): - ALLGATHER = 0 - MC2 = 1 - ALLTOALL = 2 - NAIVE_MULTICAST = 3 - - -# TODO(zzzzwwjj): add soc_version to choose branch -def _get_fused_moe_state(ep_size: int, with_prefill: bool, - is_deepseek_v3_r1: bool): +def get_fused_moe_state(ep_size: int, with_prefill: bool, + is_deepseek_v3_r1: bool): # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep # only supports deepseek v3/r1 if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 @@ -56,6 +48,12 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.MC2 +class MoECommType(Enum): + ALLGATHER = 0 + MC2 = 1 + ALLTOALL = 2 + + @contextmanager def set_ascend_forward_context( attn_metadata: Any, @@ -72,7 +70,8 @@ def set_ascend_forward_context( batch_descriptor: Optional[BatchDescriptor] = None, prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, - weight_prefetch_method: Optional[WeightPrefetchMethod] = None): + weight_prefetch_method: Optional[WeightPrefetchMethod] = None, + is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. 
@@ -98,11 +97,12 @@ def set_ascend_forward_context( ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) + # fused_moe_state is used in torchair, it will be deleted along with torchair is_deepseek_v3_r1 = hasattr( vllm_config.model_config.hf_config, 'n_routed_experts' ) and vllm_config.model_config.hf_config.n_routed_experts == 256 - fused_moe_state = _get_fused_moe_state(ep_size, with_prefill, - is_deepseek_v3_r1) + fused_moe_state = get_fused_moe_state(ep_size, with_prefill, + is_deepseek_v3_r1) forward_context.fused_moe_state = fused_moe_state forward_context.in_profile_run = in_profile_run @@ -157,6 +157,7 @@ def set_ascend_forward_context( forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method + forward_context.is_mtp_model = is_mtp_model # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. # It will be improved later by implementing operator fusion through the FX graph. diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 3514984d826..2c963b5ce28 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -67,8 +67,6 @@ def get_mask_scale_factor(dtype: torch.dtype = torch.float16): def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, device: torch.device): - if max_seq_len == 2048: - return self.chunked_prefill_attn_mask.to(torch.bool) self._update_attn_cache(max_seq_len, dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( ).to(device, non_blocking=True) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 32c2dc033e6..1d9139c5113 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -31,14 +31,7 @@ get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec @@ -49,9 +42,9 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.attention import vanilla_chunked_prefill -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, - nd_to_nz_2d, nd_to_nz_spec, - prefill_context_parallel_enable, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, + aligned_16, get_ascend_device_type, nd_to_nz_2d, + nd_to_nz_spec, prefill_context_parallel_enable, weak_ref_tensors) # isort: off @@ -63,22 +56,22 @@ # isort: on +from vllm.attention.backends.registry import (AttentionBackendEnum, + register_backend) + +@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND") class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod def get_name() -> str: - return "ASCEND" + return "CUSTOM" @staticmethod def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: return AscendAttentionBackendImpl - @staticmethod - def get_metadata_cls() -> Type["AscendMetadata"]: - return AscendMetadata - @staticmethod def 
get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: return AscendAttentionMetadataBuilder @@ -90,7 +83,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: return (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16) return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -170,10 +163,12 @@ class ChunkedContextMetadata: actual_chunk_seq_lengths: list[int] actual_seq_lengths_kv: list[int] starts: torch.Tensor + chunk_seq_mask_filtered_indices: torch.Tensor chunked_req_mask: Optional[list[bool]] = None local_context_lens_allranks: Optional[list[list[int]]] = None cp_kv_recover_idx_for_chunk: Optional[list[int]] = None kv_inverse_idx_for_chunk: Optional[list[int]] = None + batch_chunk_seq_mask: Optional[list[bool]] = None """ Prefill Specific Metadata for Ascend""" pcp_metadata: Optional[AscendPCPMetadata] = None @@ -356,7 +351,7 @@ def build( query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: if attn_state == AscendAttentionState.PrefillNoCache: mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), @@ -405,6 +400,14 @@ def build( cp_kv_recover_idx_for_chunk.to(torch.float32) ) if cp_kv_recover_idx_for_chunk is not None else None + batch_chunk_seq_mask = ( + local_context_lens_allranks[:, self.pcp_rank, + self.dcp_rank] == 0) + batch_chunk_seq_mask = torch.repeat_interleave( + batch_chunk_seq_mask, + repeats=(query_lens * self.pcp_size).to(self.device)) + chunk_seq_mask_filtered_indices = filter_chunked_req_indices( + query_lens, chunked_req_mask).to(self.device) chunked_context_metadata = \ AscendMetadataForPrefill.ChunkedContextMetadata( actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), @@ -413,7 +416,9 @@ def build( starts=local_chunk_starts, local_context_lens_allranks=local_context_lens_allranks, cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, - kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk + kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, + batch_chunk_seq_mask=batch_chunk_seq_mask, + chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices ) attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens @@ -575,10 +580,15 @@ def full_graph_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + kv_cache: Tuple[torch.Tensor], attn_metadata: AscendMetadata, output: torch.Tensor, num_tokens=0): - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if self.pcp_size * self.dcp_size > 1: + intermediate_output = self._forward_pcp_dcp( + query, key, value, kv_cache, attn_metadata, output) + return intermediate_output, query.shape[0] + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: block_size = 128 block_table = None actual_seq_lengths_kv = attn_metadata.query_start_loc_list @@ -692,7 +702,7 @@ def _forward_prefill_no_cache( mask = attn_metadata.attn_mask - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: # align q k v output tensors query = aligned_16(query) key = aligned_16(key) @@ -773,7 +783,7 @@ def _forward_decode_only( attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: # seq_lens_tensor needs to be transferred to the 
device for 310P. attn_metadata.seq_lens = \ attn_metadata.seq_lens.to(device=query.device) @@ -847,7 +857,7 @@ def _forward_v1_style( assert attn_metadata is not None assert attn_metadata.attn_mask is not None - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: # Do reformat in case of broadcasted tensors. attn_metadata.attn_mask = \ torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), @@ -1280,9 +1290,7 @@ def _update_chunk_attn_out_lse_with_current_attn_out_lse( self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape - seq_len = attn_metadata.query_lens.detach().clone() - filtered_indices = filter_chunked_req_indices( - seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask) + filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices attn_output_prefill_filtered = current_attn_output_prefill[ filtered_indices, :, :] @@ -1326,9 +1334,11 @@ def _compute_prefill_context(self, query: torch.Tensor, local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, self.dcp_rank] + total_toks = local_chunked_kv_lens_rank.sum() key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, - local_chunked_kv_lens_rank, query) + local_chunked_kv_lens_rank, query, + total_toks) if self.dcp_size > 1: num_heads = self.num_heads * self.dcp_size else: @@ -1344,7 +1354,7 @@ def _compute_prefill_context(self, query: torch.Tensor, dtype=torch.float32, device=query.device) - if not torch.all(local_chunked_kv_lens_rank == 0).item(): + if total_toks > 0: prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score( query, key, @@ -1362,6 +1372,14 @@ def _compute_prefill_context(self, query: torch.Tensor, actual_seq_lengths_kv, actual_seq_lengths=attn_metadata.prefill.chunked_context. 
actual_chunk_seq_lengths) + batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask + out_mask = batch_chunk_seq_mask[:, None, None].expand_as( + prefix_chunk_output) + prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output) + lse_mask = batch_chunk_seq_mask[:, None, + None].expand_as(prefix_chunk_lse) + prefix_chunk_lse = torch.where(lse_mask, -torch.inf, + prefix_chunk_lse) prefix_output, prefix_lse = self._update_chunk_attn_out_lse( prefix_chunk_output, prefix_chunk_lse) @@ -1417,14 +1435,12 @@ def _update_chunk_attn_out_lse(self, prefix_chunk_output, return prefix_output, prefix_lse def _load_kv_for_chunk(self, attn_metadata, kv_cache, - local_chunked_kv_lens_rank, query): + local_chunked_kv_lens_rank, query, total_toks): cache_key = kv_cache[0] cache_value = kv_cache[1] num_heads = cache_key.size(2) head_size = kv_cache[0].size(-1) - total_toks = local_chunked_kv_lens_rank.sum() - key = torch.empty(total_toks, num_heads, head_size, @@ -1583,7 +1599,7 @@ def forward( query, attn_metadata, output) else: intermediate_output, num_tokens = self.full_graph_attention( - query, key, value, attn_metadata, output) + query, key, value, kv_cache, attn_metadata, output) output[:num_tokens] = intermediate_output[:num_tokens] return output diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 62ed95c5938..188e66a5948 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -7,9 +7,8 @@ import torch.distributed as dist import torch_npu from torch import nn -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, @@ -21,14 +20,7 @@ from vllm.logger import logger from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv, round_down -else: - from vllm.utils.math_utils import cdiv, round_down - +from vllm.utils.math_utils import cdiv, round_down from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm_ascend import envs @@ -40,6 +32,7 @@ trans_rope_weight, transdata, wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, + get_mtp_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod @@ -69,10 +62,6 @@ class AscendMLABackend(AttentionBackend): def get_name() -> str: return "ASCEND_MLA" - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AscendMLAMetadata - @staticmethod def get_builder_cls(): return AscendMLAMetadataBuilder @@ -343,6 +332,74 @@ def reorder_batch(self, input_batch: "InputBatch", # better way of doing this return modified_batch + def pad_actual_seq_len_q_mtp_enable_pad(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, + common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In Torchair scenario, the lengths of the queries must be padded to the same length. 
+ And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should have inference 2 token, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. + + However, mtp torchair + PD scenario, the actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. + In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + return actual_seq_lengths_q + + def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q): + """ + Only use for acl full graph mode. + Pad the last element of the actual_seq_lengths_q equal to the TND(T) and + the num of dimensions equal to the batch_size of main model. + + For example: + batch_size = 8, num_reqs = 4, num_speculative_tokens = 1 + input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token) + After padding the actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8] + """ + need_padding = num_reqs_pad_size > 0 + if need_padding: + start_val = actual_seq_lengths_q[-1] + end_val = num_reqs + num_reqs_pad_size + num_step = num_reqs_pad_size + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == num_reqs_pad_size + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + return actual_seq_lengths_q + def build( self, common_prefix_len: int, @@ -368,11 +425,25 @@ def build( # it blocks on all previous kernels. device = self.device - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - + # If graph_pad_size > -1, mean is running in fullgraph mode. + graph_pad_size = common_attn_metadata.graph_pad_size + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. 
+ if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch: + block_table = ( + common_attn_metadata.block_table_tensor[:graph_pad_size]) + else: + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + # NOTE: Currently, MTP-fullgraph is incompatibility pcp + if self.pcp_size > 1: + num_decodes_flatten = num_decodes * self.decode_threshold + block_table = common_attn_metadata.block_table_tensor[: + num_decodes_flatten + + + num_prefills] if num_actual_tokens_pcp_padded is None: num_actual_tokens_pcp_padded = num_actual_tokens + # NOTE: Currently, MTP-fullgraph is incompatibility pcp slot_mapping = common_attn_metadata.slot_mapping[: num_actual_tokens_pcp_padded] input_positions = common_attn_metadata.positions[: @@ -546,6 +617,9 @@ def build( cos=cos, pcp_metadata=pcp_metadata, ) + if self.pcp_size > 1: + prefill_metadata.block_table = block_table[ + num_decodes_flatten:, ...] decode_metadata = None if num_decodes > 0: @@ -556,12 +630,17 @@ def build( max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes, ...] - # For pcp + spec decode, we flatten seq_lens and block_table - # to avoid irregular spec_attn_mask shape - if self.pcp_size > 1 and self.decode_threshold > 1: - block_table = block_table.repeat_interleave( - self.decode_threshold, dim=0) + if self.pcp_size > 1: + # For pcp + spec decode, we flatten seq_lens and block_table + # to avoid irregular spec_attn_mask shape + block_table = block_table[:num_decodes_flatten, ...] + else: + block_table = block_table[:num_decodes, ...] + # NOTE: Currently, MTP-fullgraph is incompatibility pcp + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. + if graph_pad_size > num_decodes and \ + self.speculative_config.disable_padded_drafter_batch: + block_table = block_table[:graph_pad_size, ...] 
seq_lens_list = seq_lens.tolist() if num_computed_tokens_of_pcp_dcp is not None: @@ -583,6 +662,52 @@ def build( else: cp_seq_len, batch_seq_mask = None, None + if graph_pad_size > num_reqs: + if self.speculative_config.disable_padded_drafter_batch: + num_reqs_pad_size = graph_pad_size - num_reqs + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q) + seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \ + num_decodes) + num_block_pad_size = graph_pad_size - block_table.shape[0] + if num_block_pad_size > 0: + block_table_padding = torch.zeros( + (num_block_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat( + [block_table, block_table_padding], dim=0) + else: + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + num_block_table_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - + num_decodes) + seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_block_table_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + position_padding = torch.zeros( + num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + common_attn_metadata) + # TODO: After the fullgraph supports MTP, the if branch needs to deleted assert self.cos_cache is not None assert self.sin_cache is not None @@ -1264,8 +1389,11 @@ def _forward_decode( "actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths_kv": decode_meta.seq_lens_list, } - graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() if forward_context.capturing: stream = torch_npu.npu.current_stream() diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 9747c2d1bc2..874ee39286e 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -4,9 +4,7 @@ import torch import torch_npu from torch import nn -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, @@ -35,10 +33,6 @@ class AscendSFABackend(AttentionBackend): def get_name() -> str: return "ASCEND_SFA" - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AscendSFAMetadata - @staticmethod def get_builder_cls(): return AscendSFAMetadataBuilder diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 3cb0613f4e5..025ff3c12ca 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -62,11 +62,9 @@ def 
__init__(self, runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - graph_pool: Any = None, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config - self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -76,8 +74,7 @@ def __init__(self, # assert runtime_mode is not NONE(no aclgraph), otherwise, we don't # need to initialize a ACLGraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE - if self.graph_pool is None: - self.graph_pool = current_platform.get_global_graph_pool() + self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() @@ -186,6 +183,12 @@ def __call__(self, *args, **kwargs): f"got {new_input_addresses}") logger.info_once("Replaying aclgraph") + # In async scheduling or multi-threaded (MT) scenarios, it is possible that + # the CPU's record event (from update_attn_params) for the iteration i completes + # before the grph replay of iteration i-1. + # To ensure proper ordering, we must call synchronize here before replaying, + # so that update_attn_params only executes after the previous graph replay has fully completed. + torch.npu.synchronize() entry.aclgraph.replay() return entry.output @@ -235,7 +238,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape): def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config): - graph_params = get_graph_params() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. 
with torch.npu.stream(update_stream): @@ -251,7 +257,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, softmax_lse) = param seq_lens_list = forward_context.attn_metadata[ key].decode.seq_lens_list - if speculative_config and speculative_config.method == "deepseek_mtp": + if speculative_config and speculative_config.method == "deepseek_mtp" \ + and not forward_context.is_mtp_model: actual_seq_lengths = forward_context.attn_metadata[ key].decode.actual_seq_lengths_q spec_multiple = speculative_config.num_speculative_tokens + 1 @@ -261,6 +268,13 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple) ] + elif forward_context.is_mtp_model: + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + block_table = forward_context.attn_metadata[ + key].decode.block_table + seq_lens_list = seq_lens_list + [0] * ( + len(actual_seq_lengths) - len(seq_lens_list)) else: seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list)) @@ -437,3 +451,32 @@ def update_graph_params_workspaces(num_tokens: int, workspace: int): def get_graph_params(): return _graph_params + + +_mtp_graph_params: Optional[GraphParams] = None + + +def set_mtp_graph_params(aclgraph_capture_sizes: set[int]): + global _mtp_graph_params + if _mtp_graph_params is not None: + raise ValueError("MTPGraph parameters have already been set!") + _mtp_graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any): + global _mtp_graph_params + if _mtp_graph_params is not None: + _mtp_graph_params.workspaces[num_tokens] = workspace + + +def get_mtp_graph_params(): + return _mtp_graph_params diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index 14a5d273959..49fd41da682 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -55,8 +55,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import ConstantList -from vllm_ascend.utils import vllm_version_is - class RecomputeScheduler(SchedulerInterface): """This Scheduler extends vllm's original v1 scheduler of version 0.11 @@ -94,7 +92,7 @@ def __init__( self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_scheduled_tokens = \ self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.scheduler_config.max_model_len + self.max_model_len = self.vllm_config.model_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events) @@ -587,14 +585,9 @@ def schedule(self) -> RecomputeSchedulerOutput: self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] - if vllm_version_is("0.11.0"): - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) - else: - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id)) # Construct the scheduler output. 
new_reqs_data = [ @@ -935,8 +928,9 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids: list[int] = ( + sampled_token_ids[req_index].tolist() + if sampled_token_ids else []) scheduled_spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id)) diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 5f02567f7ff..800536d1568 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -22,14 +22,7 @@ from vllm.distributed.kv_events import KVEventBatch from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - +from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler @@ -39,8 +32,6 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from vllm_ascend.utils import vllm_version_is - class AscendScheduler(Scheduler): """This Scheduler extends vllm's original v1 scheduler @@ -71,14 +62,9 @@ def __init__( log_stats: bool = False, ) -> None: # Call the parent class's __init__ method - if vllm_version_is("0.11.0"): - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, mm_registry, - include_finished_set, log_stats) - else: - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, block_size, - mm_registry, include_finished_set, log_stats) + super().__init__(vllm_config, kv_cache_config, + structured_output_manager, block_size, mm_registry, + include_finished_set, log_stats) # Initialize common attributes self._initialize_common() @@ -233,7 +219,8 @@ def skip_cur_request(): # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( + new_encoder_budget, + _) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, encoder_budget) if num_new_tokens == 0 or len( @@ -462,14 +449,9 @@ def skip_cur_request(): self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] - if vllm_version_is("0.11.0"): - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) - else: - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id)) # Construct the scheduler output. new_reqs_data = [ @@ -483,7 +465,6 @@ def skip_cur_request(): num_scheduled_tokens, scheduled_spec_decode_tokens, req_to_new_blocks) scheduled_cached_reqs = cached_reqs_data - scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=scheduled_cached_reqs, @@ -499,10 +480,7 @@ def skip_cur_request(): finished_req_ids=self.finished_req_ids, # type: ignore free_encoder_mm_hashes=self.encoder_cache_manager. get_freed_mm_hashes(), - structured_output_request_ids={}, - grammar_bitmask=None, ) - # NOTE(Kuntai): this function is designed for multiple purposes: # 1. 
Plan the KV cache store # 2. Wrap up all the KV cache load / save ops into an opaque object @@ -558,10 +536,10 @@ def _check_watermark_for_prefill(self, def _get_prompt_limit(self, request: Request) -> int: if (self.scheduler_config.chunked_prefill_enabled and not self.scheduler_config.is_multi_step): - prompt_limit = self.scheduler_config.max_model_len + prompt_limit = self.vllm_config.model_config.max_model_len else: prompt_limit = min( - self.scheduler_config.max_model_len, + self.vllm_config.model_config.max_model_len, self.scheduler_config.max_num_batched_tokens, ) diff --git a/vllm_ascend/core/scheduler_dynamic_batch.py b/vllm_ascend/core/scheduler_dynamic_batch.py index 6e984a22976..e731bb21eb1 100644 --- a/vllm_ascend/core/scheduler_dynamic_batch.py +++ b/vllm_ascend/core/scheduler_dynamic_batch.py @@ -33,11 +33,9 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from vllm_ascend.utils import vllm_version_is - class BudgetRefiner: - """This budget refiner can make dynamic adjustment to the token budget + """This budget refiner can make dynamic adjustment to the token budget in the chunked prefill scheduling strategy.""" def __init__(self, default_budget, slo_limit=-1) -> None: @@ -130,14 +128,9 @@ def __init__( include_finished_set: bool = False, log_stats: bool = False, ) -> None: - if vllm_version_is("0.11.0"): - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, mm_registry, - include_finished_set, log_stats) - else: - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, block_size, - mm_registry, include_finished_set, log_stats) + super().__init__(vllm_config, kv_cache_config, + structured_output_manager, block_size, mm_registry, + include_finished_set, log_stats) self.running: list[Request] = [] self.budget_refiner = BudgetRefiner( default_budget=self.scheduler_config.max_num_batched_tokens, @@ -423,8 +416,8 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( + new_encoder_compute_budget, + _) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, encoder_compute_budget) if num_new_tokens == 0: @@ -540,14 +533,9 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] - if vllm_version_is("0.11.0"): - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) - else: - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id)) # Construct the scheduler output. 
new_reqs_data = [ NewRequestData.from_request( @@ -561,11 +549,6 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -581,8 +564,6 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager. get_freed_mm_hashes(), - structured_output_request_ids=structured_output_request_ids, - grammar_bitmask=grammar_bitmask, ) # NOTE(Kuntai): this function is designed for multiple purposes: diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 0915b38a519..04195d1cc5b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -31,8 +31,13 @@ def register_connector(): KVConnectorFactory.register_connector( "MooncakeConnectorStoreV1", - "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", - "MooncakeConnectorV1") + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") + + KVConnectorFactory.register_connector( + "AscendStoreConnector", + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") KVConnectorFactory.register_connector( "MooncakeLayerwiseConnector", diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py index 2e91f715232..c6983b69e23 100644 --- a/vllm_ascend/distributed/cpu_offload_connector.py +++ b/vllm_ascend/distributed/cpu_offload_connector.py @@ -29,6 +29,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -58,7 +59,10 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata): class CPUOffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): if not vllm_config.cache_config.enable_prefix_caching: self.connector_scheduler: Optional[ CPUOffloadingConnectorScheduler] = None diff --git a/vllm_ascend/distributed/cpu_offload_manager/metadata.py b/vllm_ascend/distributed/cpu_offload_manager/metadata.py index 7f07a624238..b89659e2a1d 100644 --- a/vllm_ascend/distributed/cpu_offload_manager/metadata.py +++ b/vllm_ascend/distributed/cpu_offload_manager/metadata.py @@ -10,17 +10,12 @@ import zmq from vllm.config import KVTransferConfig, VllmConfig from vllm.utils import logger +from vllm.utils.network_utils import make_zmq_socket +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \ CPUKVCacheManager -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import get_dtype_size, make_zmq_socket -else: - from vllm.utils.network_utils import make_zmq_socket - from vllm.utils.torch_utils import get_dtype_size @dataclass diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py 
b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 3aa491317cc..5c5a0a5bef3 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -28,22 +28,19 @@ from vllm.utils import logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.utils import get_transfer_timeout_value -from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version, - prefill_context_parallel_enable, - vllm_version_is) +from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, + prefill_context_parallel_enable) if prefill_context_parallel_enable(): from vllm.distributed.parallel_state import \ get_prefill_context_model_parallel_rank -if vllm_version_is("0.11.0"): - from vllm.utils import get_ip -else: - from vllm.utils.network_utils import get_ip +from vllm.utils.network_utils import get_ip TORCH_DTYPE_TO_NPU_DTYPE = { torch.half: llm_datadist.DataType.DT_FLOAT16, @@ -104,7 +101,10 @@ def add_new_req(self, request_id: str, local_block_ids: list[int], class LLMDataDistCMgrConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: @@ -380,7 +380,7 @@ def __init__(self, vllm_config: VllmConfig): self.local_agent_metadata.cluster_id) self.init_llm_datadist() self.finished_reqs: set[str] = set() - self.soc_info = get_ascend_soc_version() + self.soc_info = get_ascend_device_type() # Set hccl deterministic for model execute os.environ["HCCL_DETERMINISTIC"] = "true" self.done_receiving_counts: defaultdict[str, @@ -765,7 +765,7 @@ def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: rank_table["server_list"].append( # type: ignore[attr-defined] decode_server_device_info) - if self.soc_info == AscendSocVersion.A3: + if self.soc_info == AscendDeviceType._910_93: # generate super_pod_list for rank table super_pod_list = [] prefill_super_pod_info = { diff --git a/vllm_ascend/distributed/mooncake/__init__.py b/vllm_ascend/distributed/mooncake/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py deleted file mode 100644 index 2d2505aba14..00000000000 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ /dev/null @@ -1,561 +0,0 @@ -import array -import hashlib -import json -import os -import re -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union - -import torch -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata -from vllm.utils import logger - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - -from vllm.v1.core.sched.output import NewRequestData - -DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB - - -@dataclass -class MooncakeEngineMetadata: - """name of the LLM model""" - - model_name: str - """ world 
size when running under a distributed setting """ - world_size: int - """ Initialize the current prefill context model parallel rank """ - pcp_rank: int - """ Initialize the current decode context model parallel rank """ - dcp_rank: int - """ Initialize the current tensor parallel rank """ - tp_rank: int - """ the format of kv tensors """ - kv_dtype: torch.dtype - """ the shape of kv tensors """ - """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ - kv_shape: tuple[int, int, int, int, int] - block_size: int = 128 - """ whether use MLA""" - use_mla: bool = False - - -@dataclass(order=True) -class MooncakeEngineKey: - model_name: str - world_size: int - pcp_rank: int - dcp_rank: int - tp_rank: int - chunk_hash: str - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.pcp_rank, - self.dcp_rank, - self.tp_rank, - self.chunk_hash, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}" - f"@tp{self.tp_rank}@{self.chunk_hash}") - - def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: - """Split the key into multiple keys for each layer""" - keys = [] - for layer_id in range(num_layers): - keys.append( - LayerMooncakeEngineKey( - self.model_name, - self.world_size, - self.pcp_rank, - self.dcp_rank, - self.tp_rank, - self.chunk_hash, - layer_id, - )) - return keys - - def to_dict(self): - # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. - return { - "__type__": "CacheEngineKey", - "model_name": self.model_name, - "world_size": self.world_size, - "pcp_rank": self.pcp_rank, - "dcp_rank": self.dcp_rank, - "tp_rank": self.tp_rank, - "chunk_hash": self.chunk_hash, - } - - @staticmethod - def from_dict(d): - return MooncakeEngineKey( - model_name=d["model_name"], - world_size=d["world_size"], - pcp_rank=d["pcp_rank"], - dcp_rank=d["dcp_rank"], - tp_rank=d["tp_rank"], - chunk_hash=d["chunk_hash"], - ) - - -@dataclass(order=True) -class LayerMooncakeEngineKey(MooncakeEngineKey): - """A key for the layer cache engine""" - - layer_id: int - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.pcp_rank, - self.dcp_rank, - self.tp_rank, - self.chunk_hash, - self.layer_id, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@pcp{self.pcp_rank}@dcp{self.dcp_rank}" - f"@tp{self.tp_rank}@{self.chunk_hash}@{self.layer_id}") - - -class ChunkedTokenDatabase(): - - def __init__( - self, - metadata: MooncakeEngineMetadata, - ): - self.metadata = metadata - - def _make_key_by_hash(self, - chunk_hash: str, - layer_id: Optional[int] = None): - assert self.metadata is not None - return MooncakeEngineKey( - self.metadata.model_name, - self.metadata.world_size, - self.metadata.pcp_rank, - self.metadata.dcp_rank, - self.metadata.tp_rank, - chunk_hash, - ) - - def _hash( - self, - tokens: Union[torch.Tensor, List[int]], - prefix_hash: str, - ) -> str: - # TODO: change it to a more efficient hash function - if isinstance(tokens, torch.Tensor): - tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() - elif isinstance(tokens, list): - tokens_bytes = array.array("I", tokens).tobytes() - return hashlib.sha256(prefix_hash.encode("ascii") + - tokens_bytes).hexdigest() - - def _chunk_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - ) -> Iterable[Union[torch.Tensor, List[int]]]: - """ - Chunk the tokens into chunks of size self.metadata.block_size. 
- - :param tokens: the input tokens, with shape [seq_len] - device: the target device after chunking - - :return: a generator of chunks of tokens, each with - shape [metadata.block_size] - """ - for i in range(0, len(tokens), self.metadata.block_size): - yield tokens[i:i + self.metadata.block_size] - - def _prefix_hash( - self, - token_chunks: Iterable[Union[torch.Tensor, List[int]]], - ) -> Iterable[str]: - prefix_hash = '' - for token_chunk in token_chunks: - prefix_hash = self._hash(token_chunk, prefix_hash) - yield prefix_hash - - def process_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - mask: Optional[torch.Tensor] = None, - ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: - """Process the tokens and return the corresponding cache engine keys. - - :param Union[torch.Tensor, List[int]] tokens: The tokens to process. - - :param Optional[torch.Tensor] mask: The mask for the tokens. Should - have the same length as tokens. And the mask should ALWAYS be like - FFFFFTTTTTTT, where True means the tokens needs to be matched, - and the Falses will ALWAYS be at the PREFIX of the tensor. - - :param bool make_key: Whether to make the cache engine key or not. - If False, the hash value will be returned instead. - - :returns: A iterable of tuples with three elements. The first element - is the start index of the tokens for the key. The second element - is the end index of the tokens for the key. The third element is - the cache engine key (or hash) for the tokens. - - :raises: ValueError if the number of Falses in the mask is not a - multiple of the chunk size. - """ - if mask is not None: - num_falses = mask.numel() - mask.long().sum().item() - else: - num_falses = 0 - - if num_falses % self.metadata.block_size != 0: - raise ValueError( - "The number of Falses in the mask is not a multiple of the chunk size." - ) - total_len = len(tokens) - - token_chunks = self._chunk_tokens(tokens) - prefix_hashes = self._prefix_hash(token_chunks) - - start_idx = 0 - for chunk_id, hash_val in enumerate(prefix_hashes): - start_idx = chunk_id * self.metadata.block_size - end_idx = min(start_idx + self.metadata.block_size, total_len) - if start_idx < num_falses: - continue - else: - yield start_idx, end_idx, self._make_key_by_hash(hash_val) - - -@dataclass -class LoadSpec: - # Number of tokens cached in vLLM - vllm_cached_tokens: int - # Number of tokens that are cached in mooncake - mooncake_cached_tokens: int - # Whether the scheduler allow us to load the tokens - can_load: bool - - -@dataclass -class SaveSpec: - # Skip already saved tokens - skip_leading_tokens: int - # Whether the scheduler allow us to save the tokens - can_save: bool - - -@dataclass -class RequestTracker: - # Request id - req_id: str - - # The token ids that has been scheduled so far - token_ids: list[int] - - # The block ids that has been allocated so far - # NOTE: allocated blocks could be more than the number of tokens - # FIXME: need to check whether the block ids will be changed after - # preemption - allocated_block_ids: list[int] - - # The number of tokens that has been savd - num_saved_tokens: int = 0 - - @staticmethod - def from_new_request( - new_request: "NewRequestData", - num_tokens_to_compute: int, - ) -> "RequestTracker": - """Create the request tracker from a new request. - - Args: - new_request (NewRequestData): the new request data. - num_tokens_to_compute (int): the number of tokens that will - be 'computed', including the `num_computed_tokens` (vLLM's - local cache hit) and new tokens that will be scheduled. 
- - """ - # vLLM 0.9.0 update: request.block_ids changed from list[int] to - # list[list[int]] - # Need to check the type of request.block_ids - - unfolded_block_ids = [] - - if not isinstance(new_request.block_ids[0], list): - unfolded_block_ids = new_request.block_ids.copy() - else: - unfolded_block_ids = new_request.block_ids[0].copy() - - return RequestTracker( - req_id=new_request.req_id, - token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. - copy(), - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) - - def update( - self, - new_token_ids: list[int], - new_block_ids: Union[tuple[list[int], ...], list[int]], - ) -> None: - """Update the request tracker when a running request is - scheduled again - """ - - self.token_ids.extend(new_token_ids) - - if len(new_block_ids) == 0: - new_block_ids = [] - elif isinstance(new_block_ids, tuple): - new_block_ids = new_block_ids[0] - elif isinstance(new_block_ids, list): - pass - else: - raise ValueError( - f"Unsupported new_block_ids type {type(new_block_ids)}") - self.allocated_block_ids.extend(new_block_ids) - - -@dataclass -class ReqMeta: - # Request id - req_id: str - # Request tokens - token_ids: torch.Tensor - - block_ids: list[int] - # # Slot mapping if exchange for block_id - # slot_mapping: torch.Tensor - # Skip save or not - save_spec: Optional[SaveSpec] = None - # load_spec - load_spec: Optional[LoadSpec] = None - - is_last_chunk: Optional[bool] = None - - @staticmethod - def from_request_tracker( - tracker: RequestTracker, - block_size: int, - load_spec: Optional[LoadSpec] = None, - skip_save: Optional[bool] = False, - is_last_chunk: Optional[bool] = None, - discard_partial_chunks: bool = True, - ) -> Optional["ReqMeta"]: - """Create the request metadata from a request tracker. - - Args: - tracker (RequestTracker): the request tracker. - block_size (int): the block size in vLLM. - load_spec (Optional[LoadSpec]): the load spec for KV cache loading. - skip_save (bool): whether to skip the save operation. - discard_partial_chunks (bool): whether to discard partial chunks. - - Returns: - the request metadata if we need to perform load/save - operations, None otherwise. - """ - input_token_ids = tracker.token_ids - input_token_len = len(input_token_ids) - - # For save operation: do not save if the following condition is met - # 1. has already been saved before (num_saved_tokens > 0) - # 2. 
number of unsaved tokens is not reached the chunk boundary - skip_leading_tokens = tracker.num_saved_tokens - chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * - block_size if discard_partial_chunks else 0) - # Calculate number of tokens to save based on discard_partial_chunks - # setting - num_tokens_to_save = ((input_token_len // block_size * block_size) - if discard_partial_chunks else input_token_len) - - skip_save = skip_save or num_tokens_to_save < chunk_boundary - if skip_save and load_spec is None: - return None - - # If we need to save, update the number of saved tokens - if not skip_save: - tracker.num_saved_tokens = num_tokens_to_save - save_spec = SaveSpec(skip_leading_tokens, not skip_save) - - # Calculate the token ids and slot mappings for load and save - # OPTIMIZATION: pre-allocate the buffer for token ids and block ids - token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] - - # # For load operation: check whether the request is scheduled to load - if load_spec is not None and load_spec.can_load: - logger.debug( - "Scheduled to load %d tokens for request %s", - load_spec.mooncake_cached_tokens, - tracker.req_id, - ) - else: - # Do not load if not in `can_load` state - load_spec = None - logger.debug( - f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" - ) - return ReqMeta( - req_id=tracker.req_id, - token_ids=token_ids, - block_ids=tracker.allocated_block_ids, - save_spec=save_spec, - load_spec=load_spec, - is_last_chunk=is_last_chunk, - ) - - -class MooncakeConnectorMetadata(KVConnectorMetadata): - - def __init__(self, unfinished_request_ids): - self.requests = [] - self.unfinished_request_ids = unfinished_request_ids - - def add_request(self, req_meta: ReqMeta) -> None: - """Add a request to the metadata. - - Args: - req_meta (ReqMeta): the request metadata. 
- """ - self.requests.append(req_meta) - - -@dataclass -class LasyerMultiBlockReqMeta: - req_id: str - keys: List[LayerMooncakeEngineKey] - starts: List[int] - ends: list[int] - block_ids: list[int] - layer_id: int - - -@dataclass -class MooncakeStoreConfig: - local_hostname: str - metadata_server: str - global_segment_size: Union[int, str] - local_buffer_size: int - protocol: str - device_name: str - master_server_address: str - use_ascend_direct: bool - - @staticmethod - def from_file(file_path: str) -> "MooncakeStoreConfig": - with open(file_path) as file: - config = json.load(file) - return MooncakeStoreConfig( - local_hostname=config.get("local_hostname"), - metadata_server=config.get("metadata_server"), - global_segment_size=_parse_global_segment_size( - config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE)), - local_buffer_size=(config.get("local_buffer_size", - DEFAULT_LOCAL_BUFFER_SIZE)), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address"), - use_ascend_direct=config.get("use_ascend_direct", False)) - - @staticmethod - def load_from_env() -> "MooncakeStoreConfig": - config_path = os.getenv("MOONCAKE_CONFIG_PATH") - if not config_path: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeStoreConfig.from_file(config_path) - - -def _parse_global_segment_size(value) -> int: - """ - Parse storage size strings with support for units: GB, MB, KB, B - - Args: - value: Input value (int, str, or other convertible types) - - Returns: - int: Size in bytes - - Raises: - ValueError: For invalid format, missing number, or negative values - TypeError: For unsupported input types - """ - - if isinstance(value, int): - return value - elif not isinstance(value, str): - try: - return int(value) - except (TypeError, ValueError) as e: - raise TypeError( - f"Unsupported type for global_segment_size: {type(value)}" - ) from e - - cleaned_input = value.strip().lower() - if not cleaned_input: - raise ValueError("global segment size cannot be empty.") - - UNIT_MULTIPLIERS = { - 'gb': 1024**3, # 1 GB = 1024^3 bytes - 'mb': 1024**2, # 1 MB = 1024^2 bytes - 'kb': 1024, # 1 KB = 1024 bytes - 'b': 1 # 1 B = 1 byte - } - pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' - match = re.match(pattern, cleaned_input) - - if not match: - raise ValueError(f"Invalid format: '{value}'") - - number_str = match.group(1) - unit = match.group(2) or 'b' - - multiplier = UNIT_MULTIPLIERS[unit] - return _convert_to_bytes(number_str, multiplier, value) - - -def _convert_to_bytes(number_str: str, multiplier: int, - original_input: str) -> int: - """ - Convert numeric string to byte count - - Args: - number_str: Numeric portion of input - multiplier: Unit conversion factor - original_input: Original input string (for error messages) - - Returns: - int: Byte count - - Raises: - ValueError: For invalid numbers or negative results - """ - try: - numeric_value = float(number_str) - except ValueError: - raise ValueError( - f"Invalid numeric value '{number_str}' in: '{original_input}'") - # Calculate byte count - try: - byte_count = int(numeric_value * multiplier) - except OverflowError: - raise ValueError(f"Storage size too large: '{original_input}'") - return byte_count diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py deleted file mode 100644 index 4472f678ddd..00000000000 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py 
+++ /dev/null @@ -1,282 +0,0 @@ -import queue -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional - -import torch -from vllm.utils import logger - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore - - -class KVTransferThread(threading.Thread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, name: str): - super().__init__(daemon=True, name=name) - self.tp_rank = tp_rank - self.tp_size = tp_size - self.m_store = m_store - self.ready_event = ready_event - self.kv_caches_base_addr = local_kv_caches_base_addr - self.block_len = block_len - self.token_database = token_database - self.block_size = block_size - self.done_task_lock = threading.Lock() - # TODO(jianzs): find a better way to detect MLA. - self.use_mla = len(block_len) == 2 - - self.request_queue: queue.Queue[Any] = queue.Queue() - # TODO(jianzs): make this configurable - self.executor = ThreadPoolExecutor(max_workers=32) - self.finished_requests: set[str] = set() - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id - - def prepare_value_layer(self, start: int, end: int, block_ids: list[int], - layer_id: int): - block_id = block_ids[start // self.block_size] - if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[1] - length_k = int(self.block_len[0] / self.block_size * (end - start)) - length_v = int(self.block_len[1] / self.block_size * (end - start)) - size_list = [length_k, length_v] - else: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[0] - length = int(self.block_len[0] / self.block_size * (end - start)) - size_list = [length, length] - addr_list = [addr_k, addr_v] - return addr_list, size_list - - def add_request( - self, - req_id: str, - tokens: torch.Tensor, - block_ids: list[int], - mask: Optional[torch.Tensor] = None, - is_last_chunk: Optional[bool] = None, - ) -> torch.Tensor: - req = ({ - "req_id": req_id, - "tokens": tokens, - "block_ids": block_ids, - "mask": mask, - "is_last_chunk": is_last_chunk, - }) - self.request_queue.put(req) - - def get_and_clear_finished_requests(self) -> set[str]: - """ - Get and clear the requests that have been completed. - Returns: - A set of request IDs that have been completed. 
- """ - with self.done_task_lock: - finished_requests = self.finished_requests.copy() - self.finished_requests.clear() - return finished_requests - - def set_finished_request(self, req_id): - with self.done_task_lock: - self.finished_requests.add(req_id) - - def run(self): - """Run the thread to handle KV cache transfer requests.""" - self.ready_event.set() - while True: - try: - request_data = self.request_queue.get() - if request_data is None: - logger.warning("Received a None request!") - self.request_queue.task_done() - continue - self._handle_request(request_data) - except Exception as e: - logger.error(f"Error in KVCacheTransferThread: {e}") - - def _handle_request(self, req_meta: dict[str, Any]): - pass - - -class KVCacheStoreSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheSendingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - is_last_chunk = req_meta["is_last_chunk"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - torch.npu.current_stream().synchronize() - self.m_store.put_batch(key_list, addr_list, size_list, blockIds) - else: - torch.npu.current_stream().synchronize() - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.put(key, addr, size) - if is_last_chunk: - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class KVCacheStoreRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreRecvingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.get(key, addr, size) - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class 
KVCacheStoreLayerSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - num_layers: int): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerSendingThread") - self.final_layer_id = num_layers - 1 - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - torch.npu.current_stream().synchronize() - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.put(key, addr, size) - if req_meta.layer_id == self.final_layer_id: - self.set_finished_request(req_meta.req_id) - self.request_queue.task_done() - - -class KVCacheStoreLayerRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - get_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerRecvingThread") - self.get_event = get_event - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.get(key, addr, size) - self.request_queue.task_done() - self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py deleted file mode 100644 index ff2cbfd1894..00000000000 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ /dev/null @@ -1,652 +0,0 @@ -# Standard -import math -import threading -import time -from typing import Generator, List, Optional, Union - -# Third Party -import torch -from vllm.config import VllmConfig -from vllm.distributed import (get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.utils import logger - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, - MooncakeEngineMetadata) -from vllm_ascend.distributed.mooncake.kv_transfer import ( - KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, - KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore -from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import get_kv_cache_torch_dtype -else: - from vllm.utils.torch_utils import get_kv_cache_torch_dtype - -if prefill_context_parallel_enable(): - from 
vllm.distributed import ( - get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size) - - -class MooncakeEngine: - #The main class for the cache engine. - - def __init__( - self, - vllm_config: VllmConfig, - use_layerwize: bool, - ): - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - self.use_mla = False - if (hasattr(model_config, "use_mla") - and isinstance(model_config.use_mla, bool) - and model_config.use_mla): - self.use_mla = True - self.use_layerwise = use_layerwize - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 - self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 - - self.kv_role = vllm_config.kv_transfer_config.kv_role - self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "load_async", False) - self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "register_buffer", False) - self.block_size = vllm_config.cache_config.block_size - - if self.pcp_size > 1: - self.block_size *= self.pcp_size - if self.dcp_size > 1: - self.block_size *= self.dcp_size - - self.current_layer = 0 - # self.use_mla = first_kv_cache_tuple[0].size( - # -1) != first_kv_cache_tuple[1].size(-1) - self.num_layers = model_config.get_num_layers(parallel_config) - num_kv_head = model_config.get_num_kv_heads(parallel_config) - head_size = model_config.get_head_size() - kv_dtype = get_kv_cache_torch_dtype( - vllm_config.cache_config.cache_dtype, model_config.dtype) - self.hidden_dim_size = num_kv_head * head_size - if self.use_mla: - kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) - else: - kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, - head_size) - self.metadata = MooncakeEngineMetadata( - model_config.model, - parallel_config.world_size, - self.pcp_rank, - self.dcp_rank, - self.tp_rank, - kv_dtype, - kv_shape, - self.block_size, - self.use_mla, - ) - - self.token_database = ChunkedTokenDatabase(self.metadata) - - self.m_store = Mooncakestore(parallel_config) - - self.kv_send_thread: Optional[KVTransferThread] = None - self.kv_recv_thread: Optional[KVTransferThread] = None - - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - _, first_kv_cache_tuple = next(iter(kv_caches.items())) - first_kv_cache = first_kv_cache_tuple[0] - - # TODO(tms): Find a more robust way to detect and handle MLA - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe) - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, block_shape_norm, block_shape_pe) - else: - # [num_block, block_size, num_head, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = 
first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, - block_shape) - - logger.info("Registering KV_Caches. use_mla: %s, shape %s", - self.use_mla, first_kv_cache.shape) - - self.kv_caches = kv_caches - self.kv_caches_base_addr = [] - for cache_or_caches in kv_caches.values(): - # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[i % 2] - self._register(base_addr, region_len) - else: - cache_list = [cache_or_caches - ] if self.use_mla else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[0] - self._register(base_addr, region_len) - - if self.use_layerwise: - self.get_event = threading.Event() - if self.kv_role in ['kv_producer', 'kv_both']: - ready_event_sending = threading.Event() - self.kv_send_thread = KVCacheStoreLayerSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending, - self.num_layers) - self.kv_send_thread.start() - ready_event = threading.Event() - self.kv_recv_thread = KVCacheStoreLayerRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, - self.block_size, ready_event, self.get_event) - self.kv_recv_thread.start() - ready_event.wait() - else: - if self.kv_role in ['kv_producer', 'kv_both']: - ready_event_sending = threading.Event() - self.kv_send_thread = KVCacheStoreSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending) - self.kv_send_thread.start() - if self.load_async: - ready_event = threading.Event() - self.kv_recv_thread = KVCacheStoreRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event) - self.kv_recv_thread.start() - ready_event.wait() - - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - try: - self.m_store.register_buffer(ptr, length) - except Exception as e: - raise RuntimeError( - f"Mooncake memory registration failed. 
Error is: {e}") - - def start_load_kv(self, metadata: MooncakeConnectorMetadata): - self.current_layer = 0 - self.layerwise_retrievers = [] - for request in metadata.requests: - load_spec = request.load_spec - if load_spec is None or not load_spec.can_load: #load =0 - continue - tokens = request.token_ids - req_id = request.req_id - if (load_spec.mooncake_cached_tokens % self.block_size - != 0) and (load_spec.mooncake_cached_tokens - == tokens.shape[0] - 1): - tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] - else: - tokens = tokens[:request.load_spec.mooncake_cached_tokens] - masked_token_count = (request.load_spec.vllm_cached_tokens // - self.block_size * self.block_size) - token_mask = torch.ones_like(tokens, dtype=torch.bool) - token_mask[:masked_token_count] = False - if self.use_layerwise: - layerwise_retriever = self.retrieve_layer( - req_id, - tokens, - request.block_ids, - token_mask, - ) - next(layerwise_retriever) # first layer load - self.layerwise_retrievers.append(layerwise_retriever) - else: - if self.load_async: - self.kv_recv_thread.add_request( # type: ignore[union-attr] - req_id, - tokens, - request.block_ids, - token_mask, - ) - else: - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, block_id = self.prepare_value( - start, end, request.block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, - blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, _ = self.prepare_value( - start, end, request.block_ids) - self.m_store.get(key, addr, size) - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id - - def wait_for_layer_load(self) -> None: - """MooncakeConnector does not do layerwise saving.""" - for layerwise_retriever in self.layerwise_retrievers: - ret_token_mask = next(layerwise_retriever) - if self.current_layer == self.num_layers - 1: - assert ret_token_mask is not None - num_retrieved_tokens = ret_token_mask.sum().item() - logger.info(f"Retrieved {num_retrieved_tokens} tokens") - - def save_kv_layer(self, - connector_metadata: MooncakeConnectorMetadata) -> None: - """MooncakeConnector does not save explicitly.""" - if self.current_layer == 0: - self.layerwise_storers = [] - for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: - continue - - token_ids = request.token_ids - req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu - - # TODO: whether need to remov saveThread - # no lookup, skipmask - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): - if request.is_last_chunk: - self.kv_send_thread.set_finished_request( # type: ignore[union-attr] - req_id) - continue # skip 
this request - - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) - - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False - logger.info( - "Storing KV cache for %d out of %d tokens " - "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), - skip_leading_tokens, - request.req_id, - ) - - layerwise_storer = self.store_layer( - req_id, - token_ids, - mask=store_mask, - block_ids=request.block_ids, - ) - self.layerwise_storers.append(layerwise_storer) - for layerwise_storer in self.layerwise_storers: - try: - next(layerwise_storer) - except Exception: - raise - self.current_layer = self.current_layer + 1 - - def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): - """MooncakeConnector does not save explicitly.""" - for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: - continue - - token_ids = request.token_ids - req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu - - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): - if request.is_last_chunk: - self.kv_send_thread.set_finished_request( # type: ignore[union-attr] - req_id) - continue # skip this request - - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) - - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False - - logger.info( - "Storing KV cache for %d out of %d tokens " - "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), - skip_leading_tokens, - request.req_id, - ) - - self.kv_send_thread.add_request( # type: ignore[union-attr] - req_id, - token_ids, - request.block_ids, - store_mask, - request.is_last_chunk, - ) - - def retrieve_layer( - self, - req_id: str, - tokens: torch.Tensor, - block_ids: list[int], - mask: Optional[torch.Tensor] = None, - ) -> Generator[Optional[torch.Tensor], None, None]: - """ - Retrieve the KV cache in a layerwise manner. - - :param torch.Tensor tokens: The tokens of the corresponding KV caches. - - :param Optional[torch.Tensor] mask: The mask for the tokens. Should - have the same length as tokens. And the mask should ALWAYS be like - FFFFFTTTTTTT, where True means the tokens needs to be matched. - - :param **kwargs: The additional arguments for the KV transfer which - will be passed into the npu_transfer. - - return: A generator that yields Optional[torch.Tensor]. The tensor will - be the boolean mask indicating which tokens are retrieved and will - only be returned in the last iteration. 
- """ - - if mask is not None: - num_required_tokens = torch.sum(mask).item() - else: - num_required_tokens = len(tokens) - - ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") - - starts = [] - ends = [] - keys = [] - first_flag = True - for start, end, key in self.token_database.process_tokens( - tokens, mask): - keys_multi_layer = key.split_layers(self.num_layers) - starts.append(start) - ends.append(end) - keys.append(keys_multi_layer) - ret_mask[start:end] = True - - if keys: - # Transpose the keys into layer major format - keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] - for layer_id, keys_multi_chunk in enumerate(keys): - if not first_flag: - is_finish = self.get_event.wait(timeout=3) #try---cache - if not is_finish: - logger.info("Layerwise get failed") - self.get_event.clear() - req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, - starts, ends, block_ids, - layer_id) - self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] - req_meta) # type: ignore[union-attr, call-arg, arg-type] - first_flag = False - yield None - else: - # If no cache are found, we still need to yield to avoid - # `StopIteration` - for layer_id in range(self.num_layers): - yield None - - retrieved_tokens = torch.sum(ret_mask) - logger.debug(f"Retrieved {retrieved_tokens} " - f"out of {num_required_tokens} " - f"out of total {len(tokens)} tokens") - - yield ret_mask - - def store_layer( - self, - req_id: str, - tokens: torch.Tensor, - block_ids: list[int], - mask: Optional[torch.Tensor] = None, - ) -> Generator[None, None, None]: - """ - Store the KV cache in a layerwise manner. - - :param torch.Tensor tokens: The tokens of the corresponding KV caches. - - :param Optional[torch.Tensor] mask: The mask for the tokens. Should - have the same length as tokens. And the mask should ALWAYS be like - FFFFFTTTTTTT, where True means the tokens needs to be matched. - - :param **kwargs: The additional arguments for the storage backend which - will be passed into the gpu_connector. - - return: A generator that yields None. In the first iteration, the - generator allocates the memory objects for all layers and moves - the KV cache of the first layer from GPU to CPU. In the next - iterations, it moves the KV cache of layer i from GPU to the memory - objects (on CPU) and puts the memory objects of layer i-1 to the - storage backends. In the last iteration, it puts the memory objects - of the last layer to the storage backends. - """ - - if mask is not None: - num_stored_tokens = torch.sum(mask).item() - else: - num_stored_tokens = len(tokens) - - starts = [] - ends = [] - keys = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - keys_multi_layer = key.split_layers(self.num_layers) - starts.append(start) - ends.append(end) - keys.append(keys_multi_layer) #[block_num,layer_num] - - if keys: - keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] - for layer_id, keys_multi_chunk in enumerate(keys): - req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, - starts, ends, block_ids, - layer_id) - self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] - req_meta) # type: ignore[union-attr, call-arg, arg-type] - yield - else: - for layer_id in range(self.num_layers): - yield - logger.debug( - f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") - - def get_finished(self) -> tuple[set[str], set[str]]: - done_sending = ( - self.kv_send_thread. 
- get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.kv_role in ['kv_producer', 'kv_both'] else set()) - - done_recving = ( - self.kv_recv_thread. - get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.load_async else set()) - - logger.debug( - "Number of completed KV cache send requests: %d, receive " - "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), - self.tp_rank) - return done_sending, done_recving - - def wait_layer_transfer_finish(self): - time.sleep(10) - pass - - def lookup( - self, - tokens: Union[torch.Tensor, List[int]], - use_layerwise: bool, - ) -> int: - """ - Checks the existence of KV cache of the tokens from the cache engine. - :param tokens: the input tokens, with shape [seq_len] - :return: An int indicating how many prefix tokens are cached. - """ - end = 0 - keys = [] - try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): - keys_multi_layer = key.split_layers(self.num_layers) - for item in keys_multi_layer: - keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): - keys.append(key.to_string()) - starts.append(start) - res = self.m_store.batch_exists( - keys) # type: ignore[assignment] - for index, value in enumerate(res): # type: ignore[arg-type] - if value != 1: - return starts[index] - # all tokens where found, return the maximal end - except Exception as e: - logger.error(f"Remote connection failed in contains: {e}") - return start - return end - - def lookup_scheduler( - self, - tokens: Union[torch.Tensor, List[int]], - use_layerwise: bool, - ) -> int: - """ - Checks the existence of KV cache of the tokens from the cache engine. - :param tokens: the input tokens, with shape [seq_len] - :return: An int indicating how many prefix tokens are cached. 
- """ - end = 0 - keys = [] - try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): - keys_multi_layer = key.split_layers(self.num_layers) - for item in keys_multi_layer: - keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): - keys.append(key.to_string()) - starts.append(start) - multi_tp_keys = keys[:] - for i in range(1, self.tp_size): - for item in keys: - new_str = item.replace( # type: ignore[attr-defined] - "@0", f"@{i}", 1) - multi_tp_keys.append(new_str) - res = self.m_store.batch_exists( - multi_tp_keys) # type: ignore[assignment] - num_block = len(keys) - multi_tp_values = [ - res[i * num_block:(i + 1) * - num_block] # type: ignore[index] - for i in range(self.tp_size) - ] - index = self.find_min_first_non_one_index(multi_tp_values) - if index != -1: - return starts[index] - # all tokens where found, return the maximal end - except Exception as e: - logger.error(f"Remote connection failed in contains: {e}") - return start - return end - - def find_min_first_non_one_index(self, arr): - try: - return min(idx for row in arr for idx, val in enumerate(row) - if val != 1) - except ValueError: - return -1 - - def close(self) -> None: - """Close the cache engine and free all the resources""" - self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py deleted file mode 100644 index cee07c6a251..00000000000 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ /dev/null @@ -1,126 +0,0 @@ -# Standard -import os - -# Third Party -from mooncake.store import ReplicateConfig # type: ignore -from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.utils import get_ip, logger - -from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te - -from .config_data import MooncakeStoreConfig - -METADATA_BYTES_LEN = 24 -BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) - - -class Mooncakestore(): - - def __init__(self, parallel_config: ParallelConfig): - try: - from mooncake.store import MooncakeDistributedStore # type: ignore - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e - tp_rank = get_tensor_model_parallel_rank() - tp_size = parallel_config.tensor_parallel_size - dp_rank = parallel_config.data_parallel_rank_local - all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) - if not all_device_ids: - device_ids_list = list( - range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) - else: - device_ids_list = list(map(int, all_device_ids.split(','))) - assert len(device_ids_list) > tp_rank - device_id = device_ids_list[tp_rank] - self.config = MooncakeStoreConfig.load_from_env() - self.store = MooncakeDistributedStore() - if self.config.protocol == "ascend" and not self.config.use_ascend_direct: - local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \ - ":npu_" + str(device_id) - ret = self.store.setup(local_hostname, self.config.metadata_server, - 
self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address) - else: - local_hostname = get_ip() - transfer_engine = get_global_te(local_hostname, device_name=None) - self.local_seg = local_hostname + ":" + str( - transfer_engine.get_rpc_port()) - ret = self.store.setup(self.local_seg, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - transfer_engine.get_engine()) - if ret != 0: - msg = "Initialize mooncake failed." - logger.error(msg) - raise RuntimeError(msg) - - def exists(self, key: MooncakeEngineKey) -> bool: - return self.store.is_exist(key.to_string()) == 1 - - def batch_exists(self, keys: list[str]) -> list[int]: - return self.store.batch_is_exist(keys) - - def register_buffer(self, ptr, length): - return self.store.register_buffer(ptr, length) - - def get_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - res = self.store.batch_get_into_multi_buffers( - keys, addrs, sizes, True) - for value in res: - if value < 0: - logger.error(f"Failed to get key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to get key {keys}. {e}") - - def put_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - config = ReplicateConfig() - config.preferred_segment = self.local_seg - config.prefer_alloc_in_same_node = True - res = self.store.batch_put_from_multi_buffers( - keys, addrs, sizes, config) - for value in res: - if value < 0: - logger.error(f"Failed to put key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to put key {keys},error:{e}") - - def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - expect_res = sum(size) - key_str = key.to_string() - try: - res = self.store.batch_get_into_ascend(key_str, addr, size) - if res[0] != expect_res: - logger.error(f"Failed to get key: [{key_str}] .") - except Exception: - logger.error(f"Failed to get key: [{key_str}] .") - return res - - def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - key_str = key.to_string() - try: - ret = self.store.batch_put_from_ascend(key_str, addr, size) - if ret[0] != 0: - logger.error(f"Failed to put key {key_str}.") - except Exception: - logger.error(f"Failed to put key {key_str}.") - - return ret - - def close(self): - self.store.close() - logger.info("Closed the mooncake store connection") diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py deleted file mode 100644 index d2fd0040425..00000000000 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ /dev/null @@ -1,514 +0,0 @@ -import threading -from typing import Any, Optional - -import torch -import vllm.envs as envs -import zmq -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.forward_context import ForwardContext -from vllm.utils import logger -from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import Request -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder - -from 
vllm_ascend.distributed.mooncake.config_data import ( - LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) -from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import make_zmq_socket -else: - from vllm.utils.network_utils import make_zmq_socket - - -class MooncakeConnectorV1(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self.kv_role = vllm_config.kv_transfer_config.kv_role - - self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "use_layerwise", False) - - self.kv_caches: dict[str, torch.Tensor] = {} - - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size - - self._block_size = vllm_config.cache_config.block_size - - if self.pcp_size > 1: - self._block_size *= self.pcp_size - if self.dcp_size > 1: - self._block_size *= self.dcp_size - - self.sended_but_unfinished_reqs: set[str] = set() - - if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( - vllm_config, self.use_layerwise) - else: - self.connector_worker = MooncakeEngine( - vllm_config, - self.use_layerwise, - ) - - assert self.connector_worker is not None - if vllm_config.parallel_config.rank == 0: - self.lookup_server = MooncakeLookupServer( - self.connector_worker, vllm_config, self.use_layerwise) - - ############################################################ - # Scheduler Side Methods - ############################################################ - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) - - def request_finished( - self, - request: "Request", - block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: - assert self.connector_scheduler is not None - return self.connector_scheduler.request_finished(request, block_ids) - - ############################################################ - # Worker Side Methods - ############################################################ - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - assert self.connector_worker is not None - self.connector_worker.register_kv_caches(kv_caches) - - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - assert self.connector_worker is not None - assert isinstance(self._get_connector_metadata(), - MooncakeConnectorMetadata) - self.connector_worker.start_load_kv(self._get_connector_metadata()) - - def wait_for_layer_load(self, layer_name: str) -> None: - """MooncakeStoreConnector does not do layerwise saving.""" - if not self.use_layerwise: - return - self.connector_worker.wait_for_layer_load() - - def save_kv_layer(self, 
layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """MooncakeStoreConnector does not save explicitly.""" - if not self.use_layerwise: - return - - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - self.connector_worker.save_kv_layer(self._get_connector_metadata()) - - def wait_for_save(self): - """MooncakeStoreConnector does not save explicitly.""" - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - - if self.use_layerwise: - self.connector_worker.wait_layer_transfer_finish() - return - - self.connector_worker.wait_for_save(self._get_connector_metadata()) - - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - assert self.connector_worker is not None - meta = self._get_connector_metadata() - done_sending, done_recving = self.connector_worker.get_finished() - sended_and_finished: set[str] = set() - for item in list(self.sended_but_unfinished_reqs): - if item not in meta.unfinished_request_ids: - sended_and_finished.add(item) - self.sended_but_unfinished_reqs.remove(item) - for item in done_sending: - if item in meta.unfinished_request_ids: - self.sended_but_unfinished_reqs.add(item) - else: - sended_and_finished.add(item) - - return sended_and_finished, done_recving - - -def get_zmq_rpc_path_mooncake( - vllm_config: Optional["VllmConfig"] = None, ) -> str: - base_url = envs.VLLM_RPC_BASE_PATH - # Default to 0 if not configured - rpc_port = 0 - if vllm_config is not None: - rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( - "mooncake_rpc_port", 0) - logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) - return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" - - -class MooncakeStoreConnectorV1Scheduler: - - def __init__(self, vllm_config: "VllmConfig", use_layerwise): - self.client = MooncakeLookupClient(vllm_config) - self.use_layerwise = use_layerwise - self.kv_role = vllm_config.kv_transfer_config.kv_role - self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "consumer_is_to_load", False) - self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "load_async", False) - # request_id -> (vllm cached tokes, mooncake cached tokens) - self.load_specs: dict[str, LoadSpec] = {} - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size - - self._block_size = vllm_config.cache_config.block_size - - if self.pcp_size > 1: - self._block_size *= self.pcp_size - if self.dcp_size > 1: - self._block_size *= self.dcp_size - # request_id -> full_token_ids - self._request_trackers: dict[str, RequestTracker] = {} - # Whether to discard partial chunks - self._discard_partial_chunks = ( - vllm_config.kv_transfer_config.get_from_extra_config( - "discard_partial_chunks", True)) - self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} - self._unfinished_request_ids: set[str] = set() - - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> tuple[int, bool]: - """ - Check for external KV cache hit. - - Args: - request (Request): the request object. - num_computed_tokens (int): the number of locally - computed tokens for this request - - Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. 
- """ - if self.kv_role == "kv_consumer" and not self.consumer_is_to_load: - return 0, False - - if self._discard_partial_chunks: - token_block_end = len(request.prompt_token_ids - ) // self._block_size * self._block_size - token_ids = torch.tensor( - request.prompt_token_ids[:token_block_end]) - else: - token_ids = torch.tensor(request.prompt_token_ids) - - num_external_hit_tokens = self.client.lookup(token_ids) - - if num_external_hit_tokens == request.num_tokens: - num_external_hit_tokens -= 1 - - need_to_allocate = num_external_hit_tokens - num_computed_tokens - - logger.info( - "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", - request.request_id, - request.num_tokens, - num_external_hit_tokens, - need_to_allocate, - ) - - if need_to_allocate <= 0: - return 0, False - - self.load_specs[request.request_id] = LoadSpec( - vllm_cached_tokens=num_computed_tokens, - mooncake_cached_tokens=num_external_hit_tokens, - can_load=False, - ) - - return need_to_allocate, self.load_async - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - """ - Update KVConnector state after temporary buffer alloc. - - For SharedStorageConnector, update _request_needs_load - if the CacheManager this allocated blocks for us. - """ - local_block_ids = [] - if num_external_tokens > 0: - local_block_ids = blocks.get_block_ids()[0] - - self._unfinished_requests[request.request_id] = (request, - local_block_ids) - self._unfinished_request_ids.add(request.request_id) - if request.request_id not in self.load_specs: - # No KV tokens from external KV cache, return - return - - if num_external_tokens == 0: - # No need to load anything - self.load_specs[request.request_id].can_load = False - return - - assert ( - num_external_tokens > 0 and num_external_tokens - == self.load_specs[request.request_id].mooncake_cached_tokens - - self.load_specs[request.request_id].vllm_cached_tokens - ), (f"Mismatch in number of tokens: {num_external_tokens} vs " - f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " - f"{self.load_specs[request.request_id].vllm_cached_tokens}" - f" for request {request.request_id}") - - self.load_specs[request.request_id].can_load = True - - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: - """Attach the connector metadata to the request object. - - This function should NOT modify other fields in the scheduler_output - except the `kv_connector_metadata` field. - Also, calling this function will reset the state of the connector. - - Args: - scheduler_output (SchedulerOutput): the scheduler output object. 
- """ - - force_skip_save = self.kv_role == "kv_consumer" - - for finished_req_id in scheduler_output.finished_req_ids: - self._request_trackers.pop(finished_req_id, None) - self._unfinished_requests.pop(finished_req_id, None) - self._unfinished_request_ids.discard(finished_req_id) - - meta = MooncakeConnectorMetadata(self._unfinished_request_ids) - - for request in scheduler_output.scheduled_new_reqs: - # Right now, we only load KV for new requests - load_spec = self.load_specs.pop(request.req_id, None) - num_tokens_to_compute = ( - request.num_computed_tokens + - scheduler_output.num_scheduled_tokens[request.req_id]) - request_tracker = RequestTracker.from_new_request( - request, num_tokens_to_compute) - self._request_trackers[request.req_id] = request_tracker - last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else len( - request.prompt_token_ids)) - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=load_spec, - skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) - >= last_chunk_tokens_num, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - - cached_reqs = scheduler_output.scheduled_cached_reqs - if isinstance(cached_reqs, list) and not force_skip_save: - for i, req in enumerate(cached_reqs): - request_tracker = self._request_trackers[req.req_id] - request_tracker.update(req.new_token_ids, req.new_block_ids) - last_chunk_tokens_num = ((len(req.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(req.prompt_token_ids)) - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=None, - skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) - >= last_chunk_tokens_num, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - elif not force_skip_save: - for i, req_id in enumerate(cached_reqs.req_ids): - request_tracker = self._request_trackers[req_id] - num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] - req_tuple = self._unfinished_requests.get(req_id) - if req_tuple: - request = req_tuple[0] - num_current_tokens = len(request_tracker.token_ids) - new_token_ids = request.all_token_ids[ - num_current_tokens:num_current_tokens + num_new_tokens] - else: - raise ValueError( - f"Request {req_id} is not in _unfinished_requests, " - f"but it is scheduled to be cached") - new_block_ids = cached_reqs.new_block_ids[i] - if not new_block_ids: - continue - request_tracker.update(new_token_ids, new_block_ids) - # decode not save - if len(request_tracker.token_ids) > len( - request.prompt_token_ids): - continue - - last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(request.prompt_token_ids)) - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=None, - skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) - >= last_chunk_tokens_num, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - - request_ids = [ - req.req_id for req in scheduler_output.scheduled_new_reqs - ] - for request_id, (request, - block_ids) in self._unfinished_requests.items(): - if request_id not in request_ids and request_id not in 
cached_reqs.req_ids: - load_spec = self.load_specs.pop(request_id, None) - if not load_spec: - continue - num_tokens_to_compute = load_spec.mooncake_cached_tokens - if (num_tokens_to_compute % self._block_size - != 0) and (num_tokens_to_compute - == len(request.prompt_token_ids) - 1): - num_tokens_to_compute = num_tokens_to_compute + 1 - request_tracker = RequestTracker( - req_id=request_id, - token_ids=request.prompt_token_ids[:num_tokens_to_compute]. - copy(), - allocated_block_ids=block_ids, - num_saved_tokens=0, - ) - - self._request_trackers[request_id] = request_tracker - - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=load_spec, - skip_save=None, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - return meta - - def request_finished( - self, - request: "Request", - block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: - """ - Once a request is finished, determine whether request blocks - should be freed now or will be sent asynchronously and freed later. - """ - if self.kv_role == "kv_consumer": - return False, None - tracker = self._request_trackers.get(request.request_id) - if tracker is not None and tracker.num_saved_tokens <= 0: - return False, None - delay_free_blocks = len(block_ids) > 0 - if delay_free_blocks: - logger.info("Delaying free of %d blocks for request %s", - len(block_ids), request.request_id) - return delay_free_blocks, None - - -class MooncakeLookupClient: - - def __init__(self, vllm_config: "VllmConfig"): - self.encoder = MsgpackEncoder() - self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) - self.socket = make_zmq_socket( - self.ctx, - socket_path, - zmq.REQ, # type: ignore[attr-defined] - bind=False, - ) - - def lookup(self, token_ids: torch.Tensor) -> int: - request = self.encoder.encode(token_ids) - self.socket.send_multipart(request, copy=False) - resp = self.socket.recv() - result = int.from_bytes(resp, "big") - return result - - def close(self): - self.socket.close(linger=0) - - -class MooncakeLookupServer: - - def __init__( - self, - mooncake_engine: MooncakeEngine, - vllm_config: "VllmConfig", - use_layerwise: bool, - ): - self.decoder = MsgpackDecoder(torch.Tensor) - self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) - self.socket = make_zmq_socket( - self.ctx, - socket_path, - zmq.REP, # type: ignore[attr-defined] - bind=True, - ) - - self.mooncake_engine = mooncake_engine - self.running = True - - def process_request(): - while self.running: - frames = self.socket.recv_multipart(copy=False) - token_ids = self.decoder.decode(frames) - result = self.mooncake_engine.lookup_scheduler( - token_ids, use_layerwise) - response = result.to_bytes(4, "big") - self.socket.send(response) - - self.thread = threading.Thread(target=process_request, daemon=True) - self.thread.start() - - def close(self): - self.socket.close(linger=0) - # TODO: close the thread! 
diff --git a/vllm_ascend/distributed/mooncake/transfer_engine.py b/vllm_ascend/distributed/mooncake/transfer_engine.py deleted file mode 100644 index e515da677b2..00000000000 --- a/vllm_ascend/distributed/mooncake/transfer_engine.py +++ /dev/null @@ -1,28 +0,0 @@ -import threading -from typing import Optional - -from mooncake.engine import TransferEngine # type: ignore - -_global_te = None -_global_te_lock = threading.Lock() - - -def get_global_te(hostname: str, device_name: Optional[str]): - global _global_te - if _global_te is None: - with _global_te_lock: - # Double-Checked Locking - if _global_te is None: - if TransferEngine is None: - raise RuntimeError("mooncake is not available") - transfer_engine = TransferEngine() - device_name = device_name if device_name is not None else "" - ret_value = transfer_engine.initialize(hostname, - "P2PHANDSHAKE", - "ascend", device_name) - if ret_value != 0: - raise RuntimeError( - f"TransferEngine initialization failed with ret_value: {ret_value}" - ) - _global_te = transfer_engine - return _global_te diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 7951760d1de..754bba7b68b 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -25,22 +25,29 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, - get_tp_group) +from vllm.distributed.parallel_state import ( + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_tensor_model_parallel_rank, get_tp_group) from vllm.utils import logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te +from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import get_transfer_timeout_value -from vllm_ascend.utils import vllm_version_is +from vllm_ascend.utils import prefill_context_parallel_enable -if vllm_version_is("0.11.0"): - from vllm.utils import get_ip, make_zmq_path, make_zmq_socket -else: - from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +# isort: off +if prefill_context_parallel_enable(): + from vllm.distributed import (get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size + ) +# isort: on + +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -66,6 +73,8 @@ class ReqMeta: remote_host: str remote_port: int remote_engine_id: str + remote_pcp_size: int + remote_dcp_size: int class KVCacheTaskTracker: @@ -139,19 +148,21 @@ def _remove_delayed_requests(self, request_id: str): class KVCacheSendingThread(threading.Thread): - def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, - side_channel_host: str, side_channel_port: int, - metadata: MooncakeAgentMetadata, ready_event: threading.Event, - kv_caches: dict[str, Any]): + def __init__(self, tp_rank: int, prefill_tp_size: int, + local_engine_id: str, side_channel_host: str, + side_channel_port: int, 
metadata: MooncakeAgentMetadata, + ready_event: threading.Event, kv_caches: dict[str, Any], + pcp_rank: int): super().__init__(daemon=True, name="KVCacheSendingThread") self.tp_rank = tp_rank - self.decode_tp_size = decode_tp_size + self.prefill_tp_size = prefill_tp_size self.local_engine_id = local_engine_id self.side_channel_host = side_channel_host self.side_channel_port = side_channel_port self.metadata = metadata self.ready_event = ready_event self.kv_caches = kv_caches + self.pcp_rank = pcp_rank self.task_tracker = KVCacheTaskTracker() @@ -183,7 +194,8 @@ def run(self): # NOTE(rob): we need each rank to have a unique port. This hack to keeps # us moving. We will switch when moving to etcd or where we have a # single ZMQ socket in the scheduler. - handshake_port = self.side_channel_port + self.tp_rank + handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \ + + self.tp_rank path = make_zmq_path("tcp", self.side_channel_host, handshake_port) logger.info("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore @@ -282,7 +294,7 @@ def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, def add_request(self, request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, remote_host: str, remote_handshake_port: int, offset: int, - num_need_pulls: int): + num_need_pulls: int, all_task_done: bool): """Add a new request to the queue for processing.""" logger.debug(f"Adding request {request_id} to the queue.") self.request_queue.put({ @@ -293,7 +305,8 @@ def add_request(self, request_id: str, local_block_ids: list[int], "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, "offset": offset, - "num_need_pulls": num_need_pulls + "num_need_pulls": num_need_pulls, + "all_task_done": all_task_done }) def get_and_clear_finished_requests(self) -> set[str]: @@ -322,8 +335,7 @@ def _handle_request(self, req_meta: dict[str, Any]): request_id = req_meta["request_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] - offset = req_meta["offset"] - num_need_pulls = req_meta["num_need_pulls"] + all_task_done = req_meta["all_task_done"] try: logger.debug( @@ -340,7 +352,7 @@ def _handle_request(self, req_meta: dict[str, Any]): # remote host. 
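# --- illustrative sketch, not part of the patch hunks above -----------------
# A minimal, standalone picture of the handshake-port scheme once PCP is folded
# in: every (pcp_rank, tp_rank) pair listens on its own port. The base port and
# the sizes below are assumed example values, not taken from this patch.
def _example_handshake_port(side_channel_port: int, pcp_rank: int,
                            tp_rank: int, tp_size: int) -> int:
    # Mirrors the `side_channel_port + pcp_rank * tp_size + tp_rank` pattern
    # used above (on the prefill side, tp_size is the prefill TP size).
    return side_channel_port + pcp_rank * tp_size + tp_rank

_example_ports = {(pcp, tp): _example_handshake_port(30000, pcp, tp, tp_size=2)
                  for pcp in range(2) for tp in range(2)}
assert sorted(_example_ports.values()) == [30000, 30001, 30002, 30003]
# -----------------------------------------------------------------------------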
self._send_done_recv_signal(request_id, remote_host, remote_handshake_port) - if offset == num_need_pulls - 1: + if all_task_done: self.task_tracker.update_done_task_count(request_id) self.request_queue.task_done() @@ -616,12 +628,17 @@ def add_new_req( remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], + remote_pcp_size=kv_transfer_params["remote_pcp_size"], + remote_dcp_size=kv_transfer_params["remote_dcp_size"], ) class MooncakeConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id @@ -713,14 +730,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing Mooncake Scheduler %s", engine_id) self.side_channel_host = get_ip() + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \ + if prefill_context_parallel_enable() else 1 + self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ - vllm_config.parallel_config.data_parallel_size + vllm_config.parallel_config.data_parallel_size * \ + self.pcp_size # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) + vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) # Requests that need to start recv. # New requests are added by update_state_after_alloc in @@ -848,6 +869,8 @@ def request_finished( remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, + remote_pcp_size=self.pcp_size, + remote_dcp_size=self.dcp_size, last_token_id=request.output_token_ids[-1], ) @@ -875,7 +898,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.dp_size = vllm_config.parallel_config.data_parallel_size_local self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() - self.max_device_id = self.tp_size * self.dp_size + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + self.pcp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + + self.max_device_id = self.tp_size * self.dp_size * self.pcp_size self.kv_role = vllm_config.kv_transfer_config.kv_role self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads @@ -883,8 +914,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) - self.handshake_port = self.side_channel_port + self.tp_rank + vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) + self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank self.sockets: dict = {} # get tp device id @@ -893,20 +924,23 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): device_ids_str = envs_ascend.PHYSICAL_DEVICES if 
device_ids_str is None: device_ids = list( - range(self.dp_rank * self.tp_size, - (self.dp_rank + 1) * self.tp_size)) + range(self.dp_rank * self.tp_size * self.pcp_size, + (self.dp_rank + 1) * self.tp_size * self.pcp_size)) else: device_ids = list(map(int, device_ids_str.split(','))) - start_index = self.dp_rank * self.tp_size - end_index = start_index + self.tp_size + start_index = self.dp_rank * self.tp_size * self.pcp_size + end_index = start_index + self.tp_size * self.pcp_size if len(device_ids) < end_index: raise ValueError( f"Not enough physical devices available for DP rank {self.dp_rank}. " f"Expected at least {end_index} devices, but found {len(device_ids)} " "in PHYSICAL_DEVICES.") device_ids = device_ids[start_index:end_index] - assert len(device_ids) > self.tp_rank # type: ignore - self.device_id = device_ids[self.tp_rank] # type: ignore + assert len( + device_ids + ) > self.pcp_rank * self.tp_size + self.tp_rank # type: ignore + self.device_id = device_ids[self.pcp_rank * self.tp_size + + self.tp_rank] # type: ignore if vllm_config.kv_transfer_config.get_from_extra_config( 'use_ascend_direct', True): @@ -914,7 +948,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): else: hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" logger.info("Initializing Mooncake work %s", engine_id) - self.engine = get_global_te(hostname, device_name=None) + self.engine = global_te.get_transfer_engine(hostname, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. @@ -1024,6 +1058,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches kv_caches_base_addr = [] + ptrs = [] + lengths = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches if self.use_mla: @@ -1031,13 +1067,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) elif self.use_sparse: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 3] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) else: cache_list = [ cache_or_caches @@ -1046,8 +1084,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) - + ptrs.append(base_addr) + lengths.append(region_len) + global_te.register_buffer(ptrs, lengths) # After KV Caches registered, start the sending or receiving thread. 
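# --- illustrative sketch, not part of the patch hunks above -----------------
# How the device-id selection behaves once PCP enlarges the per-DP-rank device
# group: a DP rank owns tp_size * pcp_size devices, and each worker picks the
# entry at pcp_rank * tp_size + tp_rank inside that group. The sizes and the
# PHYSICAL_DEVICES-style string below are assumed example values.
from typing import Optional

def _example_pick_device_id(dp_rank: int, pcp_rank: int, tp_rank: int,
                            tp_size: int, pcp_size: int,
                            physical_devices: Optional[str] = None) -> int:
    group = tp_size * pcp_size
    if physical_devices is None:
        device_ids = list(range(dp_rank * group, (dp_rank + 1) * group))
    else:
        all_ids = list(map(int, physical_devices.split(',')))
        start = dp_rank * group
        if len(all_ids) < start + group:
            raise ValueError("Not enough physical devices for this DP rank.")
        device_ids = all_ids[start:start + group]
    return device_ids[pcp_rank * tp_size + tp_rank]

assert _example_pick_device_id(1, 1, 0, tp_size=2, pcp_size=2) == 6
assert _example_pick_device_id(0, 0, 1, tp_size=2, pcp_size=2,
                               physical_devices="8,9,10,11") == 9
# -----------------------------------------------------------------------------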
metadata = MooncakeAgentMetadata( engine_id=self.engine_id, @@ -1059,9 +1098,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event = threading.Event() if self.kv_role == 'kv_producer': self.kv_send_thread = KVCacheSendingThread( - self.tp_rank, self._decode_tp_size, self.engine_id, + self.tp_rank, self._prefill_tp_size, self.engine_id, self.side_channel_host, self.side_channel_port, metadata, - ready_event, self.kv_caches) + ready_event, self.kv_caches, self.pcp_rank) self.kv_send_thread.start() else: self.kv_recv_thread = KVCacheRecvingThread( @@ -1071,14 +1110,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_recv_thread.start() ready_event.wait() - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - ret_value = self.engine.register_memory(ptr, length) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed.") - def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( self.kv_send_thread. @@ -1094,6 +1125,92 @@ def get_finished(self) -> tuple[set[str], set[str]]: "requests: %d", len(done_sending), len(done_recving)) return done_sending, done_recving + def _get_kv_split_metadata( + self, + req_id: str, + meta: ReqMeta, + ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]: + """ + In the cp/dcp scenario, the kv_cache may be split, so we may need to pull blocks from multiple remote P nodes. + Use this function to calculate the remote port and the number of remote blocks to pull from each remote P node. + """ + if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1: + choosen_rank_list = self._get_remote_tp_rank(req_id) + remote_handshake_port_list = [[ + x + meta.remote_port for x in choosen_rank_list + ]] + local_block_ids_list, remote_block_ids_list = [ + meta.local_block_ids + ], [meta.remote_block_ids] + return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list + + if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size: + # remote & local cp/dcp are equal, do kv transfer point-to-point + remote_kv_num = 1 + remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \ + for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)] + remote_block_nums = [len(meta.remote_block_ids)] + else: + assert self.pcp_size == 1 + if self.use_mla: + assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \ + (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size) + else: + assert self.tp_size == self._prefill_tp_size and ( + self.dcp_size == 1 + or self.dcp_size == meta.remote_dcp_size) + # remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes + # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size) + # 2.
for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size + remote_dcp_size = meta.remote_dcp_size // self.dcp_size + remote_kv_num = meta.remote_pcp_size * remote_dcp_size + cp_dcp_offsets = [] + for cp_idx in range(meta.remote_pcp_size): + cp_offset = cp_idx * self._prefill_tp_size + cp_dcp_offsets += list( + range(cp_offset, cp_offset + remote_dcp_size)) + tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size + remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \ + for cp_dcp_offset in cp_dcp_offsets] + # recompute the cp/dcp block assignment here; alternatively it could be passed in the P node meta + local_block_num = len(meta.local_block_ids) + remote_block_nums = [ + local_block_num // (meta.remote_pcp_size * remote_dcp_size) + ] * meta.remote_pcp_size * remote_dcp_size + num_remain_blocks = local_block_num % (meta.remote_pcp_size * + remote_dcp_size) + for i in range(num_remain_blocks): + remote_block_nums[i] += 1 + # make sure the last block (which may be partially filled) of the P nodes maps to the last block of the D node + remote_ports = remote_ports[ + num_remain_blocks:] + remote_ports[:num_remain_blocks] + remote_block_nums = remote_block_nums[ + num_remain_blocks:] + remote_block_nums[:num_remain_blocks] + + remote_handshake_port_list = [] + for remote_kv_id in range(remote_kv_num): + remote_handshake_port_list.append([remote_ports[remote_kv_id]]) + + # local_block_ids_list and remote_block_ids_list are index-aligned with remote_handshake_port_list + # for example: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]], + # remote_handshake_port_list[[30000],[30001],[30004],[30005]] + # the D rank pulls remote block 1 from port 30004 and saves it in local block 5 + local_block_ids_list = [] + remote_block_ids_list = [] + local_block_offset = 0 + for remote_kv_id in range(len(remote_handshake_port_list)): + num_blocks_to_pull = remote_block_nums[remote_kv_id] + remote_block_ids_list.append( + meta.remote_block_ids[:num_blocks_to_pull]) + local_block_ids_list.append( + meta.local_block_ids[local_block_offset:local_block_offset + + num_blocks_to_pull]) + local_block_offset += num_blocks_to_pull + assert local_block_offset == len(meta.local_block_ids), \ + f"local_block_offset ({local_block_offset}) should equal the length of local_block_ids ({len(meta.local_block_ids)})" + + return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list + def start_load_kv(self, metadata: MooncakeConnectorMetadata): """Start loading KV blocks from remote engine.""" for req_id, meta in metadata.requests.items(): @@ -1103,21 +1220,28 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - choosen_rank_list = self._get_remote_tp_rank(req_id) - remote_handshake_port_list = [ - x + meta.remote_port for x in choosen_rank_list - ] - for i in range(self.num_need_pulls): - assert self.kv_recv_thread is not None - self.kv_recv_thread.add_request( - request_id=req_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - remote_engine_id=meta.remote_engine_id, - remote_host=meta.remote_host, - remote_handshake_port=remote_handshake_port_list[i], - offset=i, - num_need_pulls=self.num_need_pulls) + remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( + req_id, meta) + + for pcp_dcp_rank in range(len(remote_handshake_port_list)): + if len(local_block_ids_list[pcp_dcp_rank]) + len( + remote_block_ids_list[pcp_dcp_rank]) == 0: +
continue + for i in range(self.num_need_pulls): + assert self.kv_recv_thread is not None + self.kv_recv_thread.add_request( + request_id=req_id, + local_block_ids=local_block_ids_list[pcp_dcp_rank], + remote_block_ids=remote_block_ids_list[pcp_dcp_rank], + remote_engine_id=meta.remote_engine_id, + remote_host=meta.remote_host, + remote_handshake_port=remote_handshake_port_list[ + pcp_dcp_rank][i], + offset=i, + num_need_pulls=self.num_need_pulls, + all_task_done=(pcp_dcp_rank + == len(remote_handshake_port_list) - 1 + and i == self.num_need_pulls - 1)) if self.kv_send_thread is not None: for req_id, delay_start_time in metadata.requests_to_send.items(): diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 1c5c0a92608..215becc5477 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -28,19 +28,15 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, get_tp_group, get_world_group) from vllm.utils import logger +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.utils import (align_memory, get_transfer_timeout_value, kv_alltoall_and_rearrange) -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import get_ip, make_zmq_path, make_zmq_socket -else: - from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -364,7 +360,10 @@ def add_new_req(self, class MooncakeLayerwiseConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index fdadfa24b00..cd148da3f32 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -50,11 +50,11 @@ # value is None, which means the system default C compiler will be used. "C_COMPILER": lambda: os.getenv("C_COMPILER", None), - # The version of the Ascend chip. If not set, the default value is - # ASCEND910B1(Available for A2 and A3 series). It's used for package building. + # The version of the Ascend chip. It's used for package building. + # If not set, we will query chip info through `npu-smi`. # Please make sure that the version is correct. "SOC_VERSION": - lambda: os.getenv("SOC_VERSION", "ASCEND910B1"), + lambda: os.getenv("SOC_VERSION", None), # If set, vllm-ascend will print verbose logs during compilation "VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))), @@ -165,11 +165,6 @@ # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), - # Timeout (in seconds) for delayed KVCache block release. In the prefill - # node, if a request is marked for delayed KV block release and the blocks - # are not freed within this timeout, they will be forcibly released. 
- "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT": - lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), "VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), # Whether to enable transpose weight and cast format to FRACTAL_NZ. @@ -177,7 +172,10 @@ lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)), # Decide whether we should enable CP parallelism. "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))) + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))), + # Whether to anbale dynamic EPLB + "DYNAMIC_EPLB": + lambda: os.getenv("DYNAMIC_EPLB", "false").lower(), } # end-env-vars-definition diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py index 0a558682231..b43b85b6d46 100644 --- a/vllm_ascend/eplb/core/eplb_utils.py +++ b/vllm_ascend/eplb/core/eplb_utils.py @@ -22,31 +22,7 @@ import torch from vllm.logger import logger - -def determine_default_expert_map(global_expert_num, world_size, rank_id, - global_redundant_expert_num): - if world_size == 1: - local_ids = torch.arange(global_expert_num, dtype=torch.int32) - return (global_expert_num, local_ids) - - local_num_experts = global_expert_num // world_size - - expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32) - - if rank_id < world_size - 1: - start = rank_id * local_num_experts - end = (rank_id + 1) * local_num_experts - local_count = local_num_experts - else: - start = rank_id * local_num_experts - end = global_expert_num - local_count = global_expert_num - rank_id * local_num_experts - - if isinstance(local_count, int): - local_ids = torch.arange(local_count, dtype=torch.int32) - expert_map[start:end] = local_ids - - return (local_count, expert_map) +import vllm_ascend.envs as envs_ascend def generate_log2phy_map(expert_map): @@ -88,8 +64,7 @@ def generate_log2phy_map(expert_map): return log2phy_map -def determine_default_log2phy_map(global_expert_num, world_size, rank_id, - global_redundant_expert_num): +def determine_default_log2phy_map(global_expert_num, world_size, rank_id): if world_size == 1: local_ids = torch.arange(global_expert_num, dtype=torch.int32) expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1) @@ -140,9 +115,10 @@ def check_dynamic_eplb(dynamic_eplb): return if not isinstance(dynamic_eplb, bool): raise TypeError("The dynamic_eplb is not bool.") - if dynamic_eplb and os.getenv("DYNAMIC_EPLB", "false") != "true": + + if dynamic_eplb and envs_ascend.DYNAMIC_EPLB not in ("true", "1"): raise ValueError( - 'Can not enable dynamic_eplb when not export DYNAMIC_EPLB="true".' + 'Can not enable dynamic_eplb when DYNAMIC_EPLB is not set to "true" or "1".' 
) @staticmethod diff --git a/vllm_ascend/kv_offload/cpu_npu.py b/vllm_ascend/kv_offload/cpu_npu.py index c19ec1b0b2e..7fe5b878612 100644 --- a/vllm_ascend/kv_offload/cpu_npu.py +++ b/vllm_ascend/kv_offload/cpu_npu.py @@ -2,17 +2,11 @@ import torch from vllm.attention import AttentionBackend from vllm.logger import init_logger +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, TransferResult, TransferSpec) -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import is_pin_memory_available -else: - from vllm.utils.platform_utils import is_pin_memory_available - logger = init_logger(__name__) diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py index bf86501d72e..3dba7ee9d59 100644 --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -4,9 +4,9 @@ import torch -from vllm_ascend.utils import is_310p +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type -if is_310p(): +if get_ascend_device_type() == AscendDeviceType._310P: from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) @@ -349,64 +349,3 @@ def add_lora_logits(self, bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) y = y.view_as(y_org) - - -class PunicaWrapperNPU0110(PunicaWrapperNPU): - # NOTE: remove me when 0.11.0 id dropped - def add_lora_linear( # type: ignore[override] - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: - """ - Applicable to linear-related lora. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += ( - x[i].unsqueeze(0) - @ lora_a_stacked[indices[i], layer_idx, :, :] - @ lora_b_stacked[indices[i], layer_idx, :, :] - * scale - ).squeeze(0)+lora_bias_stacked[i] - - Args: - y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. - scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. 
- """ - - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default, consistent with the - # triton op - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) diff --git a/vllm_ascend/model_loader/netloader/netloader.py b/vllm_ascend/model_loader/netloader/netloader.py index d613d2a7802..2968ee366fe 100644 --- a/vllm_ascend/model_loader/netloader/netloader.py +++ b/vllm_ascend/model_loader/netloader/netloader.py @@ -29,18 +29,12 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading) - -from vllm_ascend.utils import vllm_version_is +from vllm.utils.torch_utils import set_default_torch_dtype from .interaction.elastic import ElasticServer from .load import elastic_load from .utils import find_free_port, is_valid_path_prefix -if vllm_version_is("0.11.0"): - from vllm.model_executor.model_loader.utils import set_default_torch_dtype -else: - from vllm.utils.torch_utils import set_default_torch_dtype - @register_model_loader("netloader") class ModelNetLoaderElastic(BaseModelLoader): @@ -207,10 +201,8 @@ def load_model(self, vllm_config: VllmConfig, if model is not None and ( (self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port is None)): - if vllm_version_is("0.11.0"): - from vllm.utils import get_ip - else: - from vllm.utils.network_utils import get_ip + + from vllm.utils.network_utils import get_ip driver_ip = get_ip() if driver_ip == '0.0.0.0': diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 956df2eb315..31eae8d7cbe 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,7 +1,5 @@ from vllm import ModelRegistry -import vllm_ascend.envs as envs_ascend - def register_model(): ModelRegistry.register_model( @@ -10,24 +8,11 @@ def register_model(): ModelRegistry.register_model( "Qwen3VLMoeForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration" - ) + "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration") ModelRegistry.register_model( "Qwen3VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration" - ) - - if envs_ascend.USE_OPTIMIZED_MODEL: - ModelRegistry.register_model( - "Qwen2_5_VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" - ) - else: - ModelRegistry.register_model( - "Qwen2_5_VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" - ) + "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration") # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. 
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index 4ea4a27b14b..33049ffe1b6 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -24,32 +24,16 @@ import torch from torch import nn from vllm.attention import AttentionMetadata +from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.layers.mla import MLAModules +from vllm.model_executor.layers.mla import (MLAModules, + MultiHeadLatentAttentionWrapper) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils.torch_utils import direct_register_custom_op from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.attention import Attention - from vllm.model_executor.layers.mla import \ - MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper - from vllm.utils import direct_register_custom_op -else: - from vllm.attention.layer import MLAAttention - from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper - from vllm.utils.torch_utils import direct_register_custom_op - -if vllm_version_is("0.11.0"): - from vllm.attention import Attention - from vllm.model_executor.layers.mla import \ - MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper -else: - from vllm.attention.layer import MLAAttention - from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper class IndexerWrapper(nn.Module): @@ -81,7 +65,6 @@ def forward(self): return -# TODO(whx): adapt v0.11.0 and DSA class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): def __init__( @@ -119,61 +102,30 @@ def __init__( ascend_indexer = IndexerWrapper(mla_modules.indexer) else: ascend_indexer = None - - if vllm_version_is("0.11.0"): - self.mla_attn = Attention( - num_heads=num_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=scale, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - indexer=ascend_indexer, - use_sparse=mla_modules.is_sparse, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - qk_head_dim=self.qk_head_dim, - rotary_emb=mla_modules.rotary_emb, - fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, - q_b_proj=mla_modules.q_b_proj, - q_a_layernorm=mla_modules.q_a_layernorm, - q_proj=mla_modules.q_proj, - kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, - kv_a_layernorm=mla_modules.kv_a_layernorm, - kv_b_proj=mla_modules.kv_b_proj, - o_proj=mla_modules.o_proj, - ) - else: - self.mla_attn = MLAAttention( - num_heads=num_heads, - scale=scale, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - kv_b_proj=mla_modules.kv_b_proj, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=mla_modules.is_sparse, - indexer=ascend_indexer, - # extra args - rotary_emb=mla_modules.rotary_emb, - fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, - q_b_proj=mla_modules.q_b_proj, - q_a_layernorm=mla_modules.q_a_layernorm, - q_proj=mla_modules.q_proj, - 
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, - kv_a_layernorm=mla_modules.kv_a_layernorm, - o_proj=mla_modules.o_proj, - ) + self.mla_attn = MLAAttention( + num_heads=num_heads, + scale=scale, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + kv_b_proj=mla_modules.kv_b_proj, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_sparse=mla_modules.is_sparse, + indexer=ascend_indexer, + # extra args + rotary_emb=mla_modules.rotary_emb, + fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, + q_b_proj=mla_modules.q_b_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + o_proj=mla_modules.o_proj, + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py deleted file mode 100644 index 6f07afdc61d..00000000000 --- a/vllm_ascend/models/qwen2_5_vl.py +++ /dev/null @@ -1,572 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Adapted from vllm/model_executor/models/qwen2_5_vl.py -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial -from typing import Callable, Iterable, Optional, Set, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_npu -from einops import rearrange -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) -from vllm.config import VllmConfig -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import get_act_and_mul_fn -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, - Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer, - Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration, - Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo) -from vllm.model_executor.models.utils import maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz, - vllm_version_is) - -if not vllm_version_is("0.11.0"): - from vllm.model_executor.models.vision import conv3d_to_linear_weight - -MIN_PAD_SIZE = 64 # min_size to pad weight -MAX_PAD_SIZE = 128 # max_size to pad weight - - -class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.embed_dim = embed_dim - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: - # [s, b, 3 * head * head_dim] - seq_len, bs, _ = qkv.shape - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] - q, k, v = qkv.chunk(3, dim=2) - - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - q, k, v = (x.view(*new_shape) for x in (q, k, v)) - return q, k, v - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() - for x in (q, k, v)) - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=cu_seqlens, - scale_value=self.origin_hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, - quant_config, prefix) - self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - return x - - -class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding): - - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__(dim, theta) - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.inv_freq = inv_freq - - -class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): - - def __init__( - self, - vision_config: Qwen2_5_VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - norm_layer = partial(RMSNorm, eps=norm_eps) - self.interleaved = interleaved - self.enable_pad = False - head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // - 2) - self.patch_embed = AscendQwen2_5_VisionPatchEmbed( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - - act_fn = get_act_and_mul_fn(vision_config.hidden_act) - self.blocks = nn.ModuleList([ - AscendQwen2_5_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - if 
self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.enable_pad = True - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 - self.half_pad_hidden_size_per_attention_head = ( - MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - if self.enable_pad: - cos = torch.nn.functional.pad( - cos, (0, self.half_pad_hidden_size_per_attention_head)) - sin = torch.nn.functional.pad( - sin, (0, self.half_pad_hidden_size_per_attention_head)) - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def pad_qkv_bias(self, bias): - first_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, :self.half_origin_hidden_size_per_attention_head] - second_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, self.half_origin_hidden_size_per_attention_head:] - first_half_padded = torch.nn.functional.pad( - first_half, (0, self.half_pad_hidden_size_per_attention_head)) - second_half_padded = torch.nn.functional.pad( - second_half, (0, self.half_pad_hidden_size_per_attention_head)) - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) - bias_final = bias_padded.reshape(-1) - return bias_final - - def pad_qkv_weight(self, data): - qkv_weight_first_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] - qkv_weight_second_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] - - qkv_weight_first_half_padded = torch.nn.functional.pad( - qkv_weight_first_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_second_half_padded = torch.nn.functional.pad( - qkv_weight_second_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_padded = torch.cat( - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], - dim=2) - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - - if is_enable_nz(): - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( - qkv_weight_final) - qkv_weight_final_copy = torch_npu.npu_format_cast( - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) - return qkv_weight_final_copy - - return qkv_weight_final - - def pad_proj_weight(self, data): - out_weight = torch.nn.functional.pad( - data.reshape(self.hidden_size, -1, - self.half_origin_hidden_size_per_attention_head), - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( - self.hidden_size, -1) - - if is_enable_nz(): - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) - out_weight_copy = torch_npu.npu_format_cast( - out_weight_copy, ACL_FORMAT_FRACTAL_ND) - return out_weight_copy - - 
return out_weight - - def pad_qkv_weight_scale_offset(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, 1) - data1 = reshaped_data[:, :, :self. - half_origin_hidden_size_per_attention_head, :] - data2 = reshaped_data[:, :, self. - half_origin_hidden_size_per_attention_head:, :] - data1_paded = torch.nn.functional.pad( - data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - data2_paded = torch.nn.functional.pad( - data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1, 1) - return res - - def pad_qkv_deq_scale_quant_bias(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head) - data1 = reshaped_data[:, :, :self. - half_origin_hidden_size_per_attention_head] - data2 = reshaped_data[:, :, - self.half_origin_hidden_size_per_attention_head:] - - data1_paded = torch.nn.functional.pad( - data1, (0, self.half_pad_hidden_size_per_attention_head)) - data2_paded = torch.nn.functional.pad( - data2, (0, self.half_pad_hidden_size_per_attention_head)) - - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1) - return res - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("mlp.gate_up_proj.", "mlp.gate_proj.", 0), - ("mlp.gate_up_proj.", "mlp.up_proj.", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if not vllm_version_is("0.11.0"): - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if ("attn.proj.weight_scale" in name or - "attn.proj.weight_offset" in name) and self.enable_pad: - continue - elif ("attn.proj.deq_scale" in name - or "attn.proj.quant_bias" in name) and self.enable_pad: - continue - elif ("attn.qkv.weight_scale" in name - or "attn.qkv.weight_offset" in name) and self.enable_pad: - param.data = self.pad_qkv_weight_scale_offset(param.data) - elif ("attn.qkv.deq_scale" in name - or "attn.qkv.quant_bias" in name) and self.enable_pad: - param.data = self.pad_qkv_deq_scale_quant_bias(param.data) - elif ("attn.proj.weight" in name) and self.enable_pad: - param.data = self.pad_proj_weight(param.data) - elif ("attn.qkv.weight" in name) and self.enable_pad: - param.data = self.pad_qkv_weight(param.data) - elif ("attn.qkv.bias" in name) and self.enable_pad: - param.data = self.pad_qkv_bias(param.data) - loaded_params.add(name) - return loaded_params - - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w 
// self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h = grid_h // self.spatial_merge_size - llm_grid_w = grid_w // self.spatial_merge_size - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - # windows attention - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=x.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) - seq_len, _ = x.size() - x = x.reshape(seq_len // self.spatial_merge_unit, - self.spatial_merge_unit, -1) - x = x[window_index, :, :] - x = x.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - # transformers - x = x.unsqueeze(1) - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - x = blk(x, cu_seqlens=cu_seqlens_now, 
cos=cos, sin=sin) - - # adapter - x = self.merger(x) - reverse_indices = torch.argsort(window_index) - x = x[reverse_indices, :] - return x - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen2_5_VLMultiModalProcessor, - info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class AscendQwen2_5_VLForConditionalGeneration( - Qwen2_5_VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: - - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - if vllm_version_is("0.11.0"): - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - else: - with set_ascend_forward_context(None, self.vllm_config): - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - - # Split concatenated embeddings for each image item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) - - def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: - - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - - if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - if vllm_version_is("0.11.0"): - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw) - else: - with set_ascend_forward_context(None, self.vllm_config): - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw) - - # Split concatenated embeddings for each video item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py deleted file mode 100644 index 6c3bbc8cfa6..00000000000 --- a/vllm_ascend/models/qwen2_5_vl_without_padding.py +++ /dev/null @@ -1,605 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_npu -from einops import rearrange -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) - -try: - from transformers.models.qwen3_vl.configuration_qwen3_vl import \ - Qwen3VLConfig - from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ - Qwen3VLMoeConfig -except ImportError: - pass -from vllm.config import VllmConfig -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, - get_act_and_mul_fn) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, - Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, - Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, - Qwen2_5_VLProcessingInfo) - -try: - from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) - from vllm.model_executor.models.qwen3_vl_moe import ( - Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) -except ImportError: - Qwen3_VisionBlock = object - Qwen3_VisionPatchEmbed = object - Qwen3_VisionTransformer = object - Qwen3VLDummyInputsBuilder = object - Qwen3VLForConditionalGeneration = object - Qwen3VLMultiModalProcessor = object - Qwen3VLProcessingInfo = object - Qwen3VLMoeForConditionalGeneration = object - Qwen3VLMoeProcessingInfo = object -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding - - -class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.embed_dim = embed_dim - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() - for x in (q, k, v)) - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1.dev20250226 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=cu_seqlens, - scale_value=self.hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock): - - def __init__(self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, - quant_config, prefix) - self.attn = AscendQwen2_5_VisionAttention_Without_Padding( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - return x - - -class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer - ): - - def __init__( - self, - vision_config: Qwen2_5_VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - norm_layer = partial(RMSNorm, eps=norm_eps) - self.interleaved = interleaved - head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // - 2) - self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - - act_fn = get_act_and_mul_fn(vision_config.hidden_act) - self.blocks = nn.ModuleList([ - AscendQwen2_5_VisionBlock_Without_Padding( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: 
- cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h = grid_h // self.spatial_merge_size - llm_grid_w = grid_w // self.spatial_merge_size - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - # windows attention - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=x.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) - 
seq_len, _ = x.size() - x = x.reshape(seq_len // self.spatial_merge_unit, - self.spatial_merge_unit, -1) - x = x[window_index, :, :] - x = x.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - # transformers - x = x.unsqueeze(1) - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) - - # adapter - x = self.merger(x) - reverse_indices = torch.argsort(window_index) - x = x[reverse_indices, :] - return x - - -class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - x = x + self.proj.bias - return x - - -class AscendQwen3_VisionBlock(Qwen3_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, - quant_config, prefix, use_data_parallel) - self.attn = AscendQwen2_5_VisionAttention_Without_Padding( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): - - def __init__( - self, - vision_config, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix, - use_data_parallel) - norm_layer = partial(nn.LayerNorm, eps=norm_eps) - self.patch_embed = AscendQwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - self.blocks = nn.ModuleList([ - AscendQwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def forward( - self, - x: torch.Tensor, - grid_thw: list[list[int]], - ) -> torch.Tensor: - hidden_states = 
x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) - cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cpu().to(torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - deepstack_feature_lists = [] - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin) - if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) - deepstack_feature_lists.append(deepstack_feature) - hidden_states = self.merger(hidden_states) - hidden_states = torch.cat( - [hidden_states] + deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] - return hidden_states - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen2_5_VLMultiModalProcessor, - info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( - Qwen2_5_VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: - - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - - # Split concatenated embeddings for each image item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) - - def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: - - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - - if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - - # Split concatenated embeddings for each video item. 
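-        # (Illustrative comment only; the values below are assumed, not taken
-        # from this patch.) Each grid_thw row [t, h, w] contributes
-        # t * h * w // merge_size // merge_size embeddings, e.g. a video with
-        # [2, 16, 16] and spatial_merge_size = 2 gives 2 * 16 * 16 // 2 // 2 = 128.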
- merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLMoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLMoeForConditionalGeneration( - Qwen3VLMoeForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index 7b1ce44a211..f24f9823648 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -38,13 +38,10 @@ Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.models.vision import conv3d_to_linear_weight from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, is_enable_nz, - vllm_version_is) - -if not vllm_version_is("0.11.0"): - from vllm.model_executor.models.vision import conv3d_to_linear_weight +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz MIN_PAD_SIZE = 64 # min_size to pad weight MAX_PAD_SIZE = 128 # max_size to pad weight @@ -308,9 +305,8 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_params: Set[str] = set() for name, loaded_weight in weights: - if not 
vllm_version_is("0.11.0"): - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 622efe23309..b1d7b5444a9 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -16,8 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fla.ops import RMSNormGated -from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule +from vllm.model_executor.layers.fla.ops import chunk from vllm.model_executor.layers.fla.ops.fused_recurrent import \ fused_recurrent_gated_delta_rule from vllm.model_executor.layers.fused_moe import FusedMoE @@ -25,6 +24,7 @@ # yapf: disable from vllm.model_executor.layers.layernorm import \ GemmaRMSNorm as Qwen3NextRMSNorm +from vllm.model_executor.layers.layernorm import RMSNormGated # yapf: enable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -35,8 +35,7 @@ mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops import causal_conv1d from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -50,8 +49,6 @@ from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm_ascend.utils import vllm_version_is - from vllm.model_executor.models.qwen3_next import ( # isort: skip Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, @@ -183,6 +180,83 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + """ + Forward pass with three parts: + 1. Input projection + 2. Core attention (custom op) + 3. Output projection + """ + num_tokens = hidden_states.size(0) + + # ============================================================ + # Part 1: Input Projection + # ============================================================ + + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + + # 1. 
Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # ============================================================ + # Part 2: Core Attention (Custom Op) + # ============================================================ + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops.vllm.gdn_attention_core( + mixed_qkv, + b, + a, + core_attn_out, + self.prefix, + ) + + # ============================================================ + # Part 3: Output Projection + # ============================================================ + z_shape_og = z.shape + # Reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + output[:num_tokens], _ = self.out_proj(core_attn_out) + def _forward( self, hidden_states: torch.Tensor, @@ -202,11 +276,8 @@ def _forward( spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks - if vllm_version_is("0.11.0"): - spec_token_masks = attn_metadata.spec_token_masks - else: - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -221,9 +292,6 @@ def _forward( # 1. 
Set up dimensions for reshapes later projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) - if vllm_version_is("0.11.0"): - if spec_token_masks is not None: - spec_token_masks = spec_token_masks[:num_actual_tokens] projected_states_qkvz, projected_states_ba = torch.split( projected_states, [ @@ -248,13 +316,9 @@ def _forward( mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: - if vllm_version_is("0.11.0"): - mixed_qkv_spec = mixed_qkv[spec_token_masks] - mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select( - 0, non_spec_token_indx) + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select( + 0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv @@ -264,7 +328,7 @@ def _forward( mixed_qkv_spec = mixed_qkv_spec.view( attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') - mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec = causal_conv1d.causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, @@ -281,7 +345,7 @@ def _forward( if attn_metadata.num_prefills > 0: # - "cache_indices" updates the conv_state cache in positions # pointed to by "mamba_cache_params.state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn( mixed_qkv_non_spec.transpose(0, 1), conv_weights, self.conv1d.bias, @@ -292,7 +356,7 @@ def _forward( query_start_loc=non_spec_query_start_loc, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, @@ -322,16 +386,10 @@ def _forward( g_non_spec = None beta_non_spec = None else: - if vllm_version_is("0.11.0"): - g_spec = g[:, spec_token_masks] - beta_spec = beta[:, spec_token_masks] - g_non_spec = g[:, ~spec_token_masks] - beta_non_spec = beta[:, ~spec_token_masks] - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None @@ -359,6 +417,227 @@ def _forward( else: core_attn_out_spec, last_recurrent_state = None, None + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] 
= 0 + + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk.chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True) + + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. + num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + core_attn_out.index_copy_(1, non_spec_token_indx, + core_attn_out_non_spec) + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """ + Core attention computation (called by custom op). + """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select( + 0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = mixed_qkv_spec.view( + attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') + mixed_qkv_spec = causal_conv1d.causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0] + [:attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + validate_data=False, + ) + mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + # validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2. Recurrent attention + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + # 3.2: process the remaining part if attn_metadata.num_prefills > 0: initial_state = ssm_state[ @@ -366,7 +645,7 @@ def _forward( initial_state[~has_initial_state, ...] 
= 0 batch_size = initial_state.shape[0] - core_attn_out = [] + temp_core_attn_out = [] last_recurrent_state = [] for b_idx in range(batch_size): @@ -382,7 +661,7 @@ def _forward( ( cur_core_attn_out_non_spec, cur_last_recurrent_state, - ) = chunk_gated_delta_rule( + ) = chunk.chunk_gated_delta_rule( query=cur_q, key=cur_k, value=cur_v, @@ -393,18 +672,18 @@ def _forward( use_qk_l2norm_in_kernel=True, ) - core_attn_out.append(cur_core_attn_out_non_spec) + temp_core_attn_out.append(cur_core_attn_out_non_spec) last_recurrent_state.append(cur_last_recurrent_state) - tar_dtype = core_attn_out[0].dtype - tar_device = core_attn_out[0].device - tar_shape = list(core_attn_out[0].shape) + tar_dtype = temp_core_attn_out[0].dtype + tar_device = temp_core_attn_out[0].device + tar_shape = list(temp_core_attn_out[0].shape) tar_shape[1] = non_spec_query_start_loc[-1] core_attn_out_non_spec = torch.empty(tar_shape, dtype=tar_dtype, device=tar_device) for b_idx in range(batch_size): - cur_core_attn_out = core_attn_out[b_idx] + cur_core_attn_out = temp_core_attn_out[b_idx] start, end = non_spec_query_start_loc[ b_idx], non_spec_query_start_loc[b_idx + 1] core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out @@ -431,36 +710,22 @@ def _forward( else: core_attn_out_non_spec, last_recurrent_state = None, None - # Merge core attention output - if (spec_sequence_masks is not None - and core_attn_out_non_spec is not None): - core_attn_out = torch.empty( + # 3. Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) - if vllm_version_is("0.11.0"): - core_attn_out[:, spec_token_masks] = core_attn_out_spec - core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec - else: - core_attn_out.index_copy_(1, spec_token_indx, - core_attn_out_spec) - core_attn_out.index_copy_(1, non_spec_token_indx, - core_attn_out_non_spec) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, + core_attn_out_non_spec) + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) elif spec_sequence_masks is not None: - core_attn_out = core_attn_out_spec + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) else: - core_attn_out = core_attn_out_non_spec - - z_shape_og = z.shape - # reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, '... h d -> ... 
(h d)') - - output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze( + 0) class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index fb1abe66606..4889d2320a7 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -33,10 +33,10 @@ class AscendSiluAndMul(SiluAndMul): def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: out = torch_npu.npu_swiglu(x) diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py deleted file mode 100644 index 7ddc9cecca3..00000000000 --- a/vllm_ascend/ops/casual_conv1d.py +++ /dev/null @@ -1,539 +0,0 @@ -# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2024, Tri Dao. -# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py -# mypy: ignore-errors - -from typing import Optional, Union - -import torch -import torch.nn.functional as F -import triton -import triton.language as tl - -PAD_SLOT_ID = -1 - - -def causal_conv1d_ref( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - - if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out[..., :(width - 1)].copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return (out, None) if not return_final_states else (out, final_states_out) - - -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID, -): - """ - x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen - sequences are 
concatenated from left to right for varlen - weight: (dim, width) - bias: (dim,) - query_start_loc: (batch + 1) int32 - The cumulative sequence lengths of the sequences in - the batch, used to index into sequence. prepended by 0. - for example: query_start_loc = torch.Tensor([0,10,16,17]), - x.shape=(dim,17) - cache_indices: (batch) int32 - indicates the corresponding state index, - like so: conv_state = conv_states[cache_indices[batch_id]] - has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial - state for the calculations - conv_states: (...,dim,width - 1) itype - updated inplace if provided - activation: either None or "silu" or "swish" - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(-1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - - out_ref = [] - out_ref_b = [] - seqlens = query_start_loc[1:] - query_start_loc[:-1] - seqlens = seqlens.tolist() - splits = torch.split(x, seqlens, dim=-1) - - for i in range(len(seqlens)): - x_s = splits[i] - if cache_indices[i] == PAD_SLOT_ID: - continue - out_ref_b.append( - causal_conv1d_ref( - x_s, - weight, - bias, - activation=activation, - return_final_states=True, - final_states_out=conv_states[cache_indices[i]].unsqueeze(0), - initial_states=conv_states[cache_indices[i]] - if has_initial_state[i] else None)) - out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) - out_ref_tensor = torch.cat(out_ref, dim=0) - return out_ref_tensor - - -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer - conv_state_indices_ptr, - num_accepted_tokens_ptr, - intermediate_conv_window_ptr, - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_inter_seq: tl.constexpr, - stride_inter_step: tl.constexpr, - stride_inter_dim: tl.constexpr, - stride_inter_win: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, - SAVE_INTERMEDIATE: tl.constexpr, -): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return - - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - 
conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) - else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + - idx_seq) - 1 - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # The conv_state updates works in a sliding window manner, - # at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. 
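-    # Hedged illustration with assumed sizes (not the kernel's real tensors):
-    # for seqlen = 1 and state_len = 4, a per-channel state [h1, h2, h3, h4]
-    # is read back from offset +1 and stored as [h2, h3, h4, x_new] once the
-    # new token is written at the tail below.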
- conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + 1) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = (conv_state_base + - (idx_tokens * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + 
tl.exp(-acc)) - # mask_1d = (idx_token < seqlen) & ( - # idx_feats < dim - # ) # token-index # feature-index - maskL = idx_feats < dim - maskR = tl.full(maskL.shape, False, tl.int1) - mask_1d = tl.where(idx_token < seqlen, maskL, maskR) - - o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + - idx_token * stride_o_token + (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = (intermediate_conv_window_ptr + - conv_state_batch_coord * stride_inter_seq + - idx_token * stride_inter_step + - idx_feats * stride_inter_dim) - if KERNEL_WIDTH >= 2: - tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) - if KERNEL_WIDTH >= 3: - tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) - if KERNEL_WIDTH >= 4: - tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) - - -def causal_conv1d_update_npu( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - intermediate_conv_window: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=False, -): - """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] - conv_state: (..., dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. 
- pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) - """ - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( - ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) - state_len = width - 1 + (seqlen - 1) # effective state_len needed - np2_statelen = triton.next_power_of_2(state_len) - - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - # prepare intermediate buffer strides if provided - if intermediate_conv_window is not None: - stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( - intermediate_conv_window.stride(0), - intermediate_conv_window.stride(1), - intermediate_conv_window.stride(2), - intermediate_conv_window.stride(3), - ) - else: - stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, - intermediate_conv_window - if intermediate_conv_window is not None else x, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_inter_seq, - stride_inter_step, - stride_inter_dim, - stride_inter_win, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - 
NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=128, - SAVE_INTERMEDIATE=intermediate_conv_window is not None, - ) - if unsqueeze: - out = out.squeeze(-1) - return out diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py index 604986b4103..7e8a9aefd28 100644 --- a/vllm_ascend/ops/expert_load_balancer.py +++ b/vllm_ascend/ops/expert_load_balancer.py @@ -8,12 +8,14 @@ class ExpertLoadBalancer(object): - def __init__(self, expert_map_path, global_expert_num): + def __init__(self, expert_map_path, num_experts): self.expert_map_path = expert_map_path - self.global_expert_num = global_expert_num + self.num_experts = num_experts self.tensor_data = [] self.expert_map_tensor, self.layers_num, self.ranks_num = ( self._expert_file_to_tensor()) + self.global_expert_num = num_experts + self.get_global_redundant_expert_num( + ) self.expert_placement_map = self.generate_expert_placement_map() def _expert_file_to_tensor(self): @@ -95,7 +97,7 @@ def get_rank_log2phy_map(self, layer_id, rank_id): def get_global_redundant_expert_num(self): global_redundant_expert_num = ( len(self.expert_map_tensor[0][0]) * self.ranks_num - - self.global_expert_num) + self.num_experts) return global_redundant_expert_num def check_expert_map_tensor(self): diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py deleted file mode 100644 index 79039002d1f..00000000000 --- a/vllm_ascend/ops/fla.py +++ /dev/null @@ -1,299 +0,0 @@ -# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py -# Copyright (c) 2024, Tri Dao. -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. -# mypy: ignore-errors - -import torch -import torch.nn.functional as F -from vllm.triton_utils import tl, triton - -MAX_CORES = 65535 - - -@triton.heuristics({ - "HAS_BIAS": lambda args: args["B"] is not None, - "HAS_Z": lambda args: args["Z"] is not None, -}) -@triton.jit -def layer_norm_fwd_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Z, # pointer to the other branch - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_z_row, - M, # number of rows in X_base - N, # number of columns in X_base - eps, # epsilon to avoid division by zero - BLOCK_N: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_Z: tl.constexpr, - NORM_BEFORE_GATE: tl.constexpr, - IS_RMS_NORM: tl.constexpr, - N_CORES: tl.constexpr, -): - # Map the program id to the row of X_base and Y_base it should compute. 
- row = tl.program_id(0) - group = tl.program_id(1) - - BLOCK_ROWS = M if M < N_CORES else N_CORES - n_iters = M // BLOCK_ROWS - remain = M % BLOCK_ROWS - if row < remain: - n_iters = n_iters + 1 - - for i in tl.range(n_iters): - X_base = X + (i * BLOCK_ROWS * - stride_x_row) + row * stride_x_row + group * N - Y_base = Y + (i * BLOCK_ROWS * - stride_y_row) + row * stride_y_row + group * N - if HAS_Z: - Z_base = Z + (i * BLOCK_ROWS * - stride_z_row) + row * stride_z_row + group * N - if not IS_RMS_NORM: - Mean_base = Mean + (i * BLOCK_ROWS) + group * M - Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M - W_base = W + group * N - if HAS_BIAS: - B_base = B + group * N - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32) - if HAS_Z and not NORM_BEFORE_GATE: - z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32) - x *= z * tl.sigmoid(z) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean_base + row, mean) - xbar = tl.where(cols < N, x - mean, 0.) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd_base + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W_base + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B_base + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - if HAS_Z and NORM_BEFORE_GATE: - z = tl.load(Z_base + cols, mask=mask).to(tl.float32) - y *= z * tl.sigmoid(z) - # Write output - tl.store(Y_base + cols, y, mask=mask) - - -def _layer_norm_fwd( - x, - weight, - bias, - eps, - z=None, - out=None, - group_size=None, - norm_before_gate=True, - is_rms_norm=False, -): - M, N = x.shape - if group_size is None: - group_size = N - assert N % group_size == 0 - ngroups = N // group_size - assert x.stride(-1) == 1 - if z is not None: - assert z.stride(-1) == 1 - assert z.shape == (M, N) - assert weight.shape == (N, ) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N, ) - # allocate output - if out is not None: - assert out.shape == x.shape - else: - out = torch.empty_like(x) - assert out.stride(-1) == 1 - mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) - if not is_rms_norm else None) - rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) - if group_size > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_N // 256, 1), 8) - grid = (M if M < MAX_CORES else MAX_CORES, ngroups) - with torch.npu.device(x.device.index): - layer_norm_fwd_kernel[grid]( - x, - out, - weight, - bias, - z, - mean, - rstd, - x.stride(0), - out.stride(0), - z.stride(0) if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - N_CORES=MAX_CORES, - num_warps=num_warps, - ) - return out, mean, rstd - - -class LayerNormFn(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False, - ): - """If z is not None, we 
do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if z is not None: - assert z.shape == x_shape_og - z = z.reshape(-1, z.shape[-1]) - if z.stride(-1) != 1: - z = z.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - y, mean, rstd = _layer_norm_fwd( - x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=is_rms_norm, - ) - return y.reshape(x_shape_og) - - -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = F.normalize(query, p=2, dim=-1) - key = F.normalize(key, p=2, dim=-1) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - num_heads % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_heads = num_heads + pad_size - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -( - (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - - last_recurrent_state = (torch.zeros(batch_size, sequence_length, - k_head_dim, v_head_dim).to(value) if - initial_state is None else initial_state.to(value)) - - core_attn_out = torch.zeros_like(value) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=1) - - # for each chunk - for i in range(0, tot_heads // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * - decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() + - (k_i * - (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2) @ v_new) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = 
core_attn_out.reshape(core_attn_out.shape[0], - core_attn_out.shape[1], -1, - core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :num_heads] - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index e511d6b554f..eb3fc848c8e 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -20,8 +20,6 @@ import torch_npu from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_config import get_ascend_config - def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor, if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( hidden_states, "gate_up") - topk_weights, topk_ids = _select_experts_with_fusion_ops( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - renormalize=renormalize, - e_score_correction_bias=e_score_correction_bias, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - global_num_experts=global_num_experts) - - if topk_weights is None: + if custom_routing_function is None: + topk_weights, topk_ids = _select_experts_with_fusion_ops( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + global_num_experts=global_num_experts) + else: topk_weights, topk_ids = _native_select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops( e_score_correction_bias: Optional[torch.Tensor], topk_group: Optional[int], num_expert_group: Optional[int], - custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor=1.0, global_num_experts: int = -1): - topk_weights, topk_ids = None, None - # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern - global_redundant_expert_num = get_ascend_config().init_redundancy_expert - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax": - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( - x=router_logits, finished=None, k=top_k) - topk_ids = topk_ids.to(torch.int32) + if scoring_func == "softmax": + norm_type = 0 + topk_group = 1 + num_expert_group = 1 + else: + norm_type = 
1 + if e_score_correction_bias is not None and \ + e_score_correction_bias.dtype != router_logits.dtype: + e_score_correction_bias = e_score_correction_bias.to( + router_logits.dtype) + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, + bias=e_score_correction_bias, + k_group=topk_group, + group_count=num_expert_group, + group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=norm_type, # 0: softmax; 1: sigmoid + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + if scoring_func == "softmax": topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 113cd47e891..b9667abbccb 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -28,12 +28,13 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map, get_compressed_expert_map) +from vllm.model_executor.layers.fused_moe.shared_fused_moe import \ + SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, - determine_default_log2phy_map) +from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method @@ -42,20 +43,10 @@ AscendW4A8DynamicFusedMoEMethod from vllm_ascend.quantization.w8a8_dynamic import \ AscendW8A8DynamicFusedMoEMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p, - is_enable_nz, npu_stream_switch, - shared_expert_dp_enabled, - shared_experts_calculation_stream, - vllm_version_is) - -if vllm_version_is("0.11.0"): - from vllm.config import CompilationLevel - - from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE # type: ignore # isort:skip -else: - from vllm.config import CompilationMode - from vllm.model_executor.layers.fused_moe.shared_fused_moe import \ - SharedFusedMoE +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, + enable_sp, get_ascend_device_type, is_enable_nz, + npu_stream_switch, shared_expert_dp_enabled, + shared_experts_calculation_stream) class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): @@ -63,28 +54,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): def __init__(self, moe: FusedMoEConfig = None): super().__init__(moe=moe) - - # NOTE: Currently, this self.use_aclgraph is only used in - # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in - # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue. - # Once torch.randint_like is supported or removed, this flag can be removed. 
- vllm_config = get_current_vllm_config() - ascend_config = get_ascend_config() self.dynamic_eplb = get_ascend_config().dynamic_eplb - if ascend_config.torchair_graph_config.enabled: - self.use_aclgraph = False - else: - if vllm_version_is("0.11.0"): - self.use_aclgraph = ( - vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and not vllm_config.model_config.enforce_eager) - else: - self.use_aclgraph = ( - vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE - and not vllm_config.model_config.enforce_eager) - self.transpose = True def process_weights_after_loading(self, layer): @@ -109,7 +79,8 @@ def process_weights_after_loading(self, layer): w2_data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - if not is_310p() and is_enable_nz(): + if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz( + ): layer.w13_weight.data = torch_npu.npu_format_cast( layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w2_weight.data = torch_npu.npu_format_cast( @@ -153,7 +124,7 @@ def apply(self, # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. - if enable_force_load_balance and not self.use_aclgraph: + if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) moe_comm_method = get_forward_context().moe_comm_method @@ -183,10 +154,8 @@ def __init__(self, *args, **kwargs): AscendFusedMoE.moe_counter += 1 self.moe_instance_id = AscendFusedMoE.moe_counter - self.global_num_experts = num_experts self.expert_map = None self.log2phy = None - self.global_redundant_expert_num = 0 if self.quant_config is None: self.quant_method = AscendUnquantizedFusedMoEMethod( @@ -210,15 +179,20 @@ def __init__(self, *args, **kwargs): vllm_config = get_current_vllm_config() self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( dtype=vllm_config.model_config.dtype) + + # init moe. + self.local_num_experts, self.expert_map, _ = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) # static eplb initializing with expert_map_path if self.expert_map_path and os.path.exists( self.expert_map_path) and os.access(self.expert_map_path, os.R_OK): self.expert_load_balancer = ExpertLoadBalancer( - self.expert_map_path, self.global_num_experts) + self.expert_map_path, num_experts) self.expert_load_balancer.check_expert_map_tensor() self.global_redundant_expert_num = ( self.expert_load_balancer.get_global_redundant_expert_num()) + self.global_num_experts = num_experts + self.global_redundant_expert_num try: self.local_num_experts, self.expert_map = ( self.expert_load_balancer.get_rank_placement_map( @@ -228,45 +202,21 @@ def __init__(self, *args, **kwargs): except Exception as e: logger.warning( f"Init expert map of mtp/eagle when using sample.{e}") - self.local_num_experts, self.expert_map = determine_default_expert_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num) self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num).npu() - if self.expert_map is not None and isinstance( - self.expert_map, torch.Tensor): - logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" - " number of experts: %s/%s. 
Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, - get_compressed_expert_map(self.expert_map)) + self.global_num_experts, self.ep_size, self.ep_rank).npu() else: - # init moe. - if vllm_version_is("0.11.0"): - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) - else: - self.local_num_experts, self.expert_map, _ = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) # dynamic eplb initializing with not expert_map_path if self.dynamic_eplb: - self.global_redundant_expert_num = ascend_config.init_redundancy_expert - self.local_num_experts, self.expert_map = determine_default_expert_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num) self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num).npu() - if self.expert_map is not None and isinstance( - self.expert_map, torch.Tensor): - logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" - " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, - get_compressed_expert_map(self.expert_map)) + self.global_num_experts, self.ep_size, self.ep_rank).npu() + if self.expert_map is not None and isinstance(self.expert_map, + torch.Tensor): + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) local_num_experts = (torch.sum( self.expert_map != -1) if self.expert_map is not None else self.global_num_experts) @@ -274,6 +224,12 @@ def __init__(self, *args, **kwargs): self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64).npu() + eplb_enable = self.dynamic_eplb or (self.expert_map_path is not None) + if eplb_enable and (not hasattr(self.quant_method, "quant_method") or + not isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod)): + raise ValueError("Eplb supports only w8a8_dynamic quantization.") + self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts self.moe_config.original_num_experts = num_experts @@ -489,6 +445,13 @@ def gate(self) -> Optional[torch.nn.Module]: def is_internal_router(self) -> bool: return False + @property + def use_dp_chunking(self) -> bool: + """This func routes to the chunked forward path using the FlashInfer Cutlass kernel + only when data parallelism (DP) is enabled. 
vllm-ascend does not use this path, so it always returns False.
+        """
+        return False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index c89eb1df3ab..c48ce1a49be 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -27,7 +27,7 @@
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
-    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType)
+    PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
@@ -44,8 +44,6 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
-    _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config)
 
 
 class MoECommMethod(ABC):
@@ -245,32 +243,3 @@ def _get_token_dispatcher(self):
 
     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAll2All(self.moe_config)
-
-
-class NaiveMulticastCommImpl(MoECommMethod):
-    """This implementation is the same as NativeAllGatherCommImpl,
-    but uses NPU-specific ops for better performance.
-
-    This implementation should be compatible with all scenarios, and
-    thus it is the default implementation for MoE communication methods.
-    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
-    and `torch_npu.npu_moe_token_unpermute` for post-processing
-    to handle the token-to-expert mapping and communication efficiently.
-
-    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
-    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
-    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
-    for pre-processing and post-processing, respectively.
-    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
-    use `torch_npu.npu_moe_token_unpermute` instead.
-    This is a workaround and should be removed after the issue is fixed.
- """ - - def _get_token_dispatcher(self): - return TokenDispatcherWithAllGather( - top_k=self.moe_config.experts_per_token, - num_experts=self.moe_config.num_experts, - num_local_experts=self.moe_config.num_local_experts) - - def _get_prepare_finalize(self): - return PrepareAndFinalizeWithNaiveMulticast(self.moe_config) diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 0e2b81fb64c..07ba732f199 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -22,7 +22,8 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_forward_context import MoECommType -from vllm_ascend.utils import dispose_tensor, is_310p +from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, + get_ascend_device_type) def cumsum_group_list(group_list: torch.Tensor, @@ -210,7 +211,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, )[0] - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( torch.float16) else: diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index 46640006291..48350ea80c9 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -45,7 +45,7 @@ class PrepareAndFinalize(ABC): """ Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization in distributed environments. Subclasses implement specific communication strategies - (e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing, + (e.g., AllGather, All2All, MC2) to handle tensor padding, slicing, broadcasting, and reduction across TP/DP/EP groups. Attributes: @@ -454,115 +454,3 @@ def _finalize_with_dp_group(self, hidden_states: torch.Tensor, hidden_states = tensor_model_parallel_all_reduce(hidden_states) return hidden_states - - -class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize): - """ - MoE communication strategy using Naive Multicast (point-to-point broadcast). - Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others. - Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries. - """ - - def _naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - """ - Naive multicast implementation: - 1. Create global buffer sized by total tokens across DP. - 2. Current rank copies its slice into its designated buffer region. - 3. Each rank broadcasts its slice to all others via P2P. 
- - Args: - x (torch.Tensor): Local tensor [local_tokens, hidden_size] - cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank - - Returns: - torch.Tensor: Global tensor [total_tokens, hidden_size] - """ - assert len(x.shape) == 2, "Input must be 2D [tokens, features]" - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - - # Copy local slice into buffer - start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.moe_config.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank] - buffer[start:end, :].copy_(x) - - # Broadcast each slice to all ranks - for idx in range(self.moe_config.dp_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer - - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: - """ - Preparation steps: - 1. Fetch cumulative token boundaries from forward context. - 2. Multicast hidden_states and router_logits to form global tensors. - - Returns: - Tuple of (global_hidden_states, global_router_logits, None, None) - """ - self.enable_shared_expert_dp = enable_shared_expert_dp - - if self.moe_config.dp_size > 1: - self.cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_sp(1) - hidden_states = self._naive_multicast(hidden_states, - self.cu_tokens_across_dp_cpu) - router_logits = self._naive_multicast(router_logits, - self.cu_tokens_across_dp_cpu) - - if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: - hidden_states = get_pcp_group().all_gather( - hidden_states, - dim=0, - ) - router_logits = get_pcp_group().all_gather( - router_logits, - dim=0, - ) - - return hidden_states, router_logits, None, None - - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: - """ - Finalization steps: - 1. If DP > 1 and not shared expert: - - All-reduce across DP - - Slice to current rank's token range using cu_tokens_across_dp_cpu - 2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce. 
- - Returns: - Tensor with shape [local_num_tokens, hidden_size] - """ - if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: - start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[ - self.moe_config.dp_rank - 1] - end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank] - hidden_states = get_dp_group().all_reduce( - hidden_states) # Sum across DP - hidden_states = hidden_states[start:end, :] - - if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: - hidden_states = get_pcp_group().reduce_scatter(hidden_states, - dim=0) - - if reduce_results and (self.moe_config.tp_size > 1 - or self.moe_config.ep_size > 1): - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - return hidden_states diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index 1ef06533810..57f26046072 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -30,7 +30,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.comm_utils import ( async_all_to_all, gather_from_sequence_parallel_region) -from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version, +from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled) @@ -98,11 +98,11 @@ def __init__(self, **kwargs): self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") self.need_extra_args = ( - get_ascend_soc_version() == AscendSocVersion.A3) + get_ascend_device_type() == AscendDeviceType._910_93) # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine self.a3_need_extra_args = \ - get_ascend_soc_version() == AscendSocVersion.A3 + get_ascend_device_type() == AscendDeviceType._910_93 # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. 
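# [Editor's illustration, not part of the patch] The A2 note above describes an
# environment-variable tuning, so a minimal sketch of applying it is given here.
# Assumption: HCCL reads these variables when the communicator is initialized, so
# they should be exported before the NPU process group is set up (e.g. in the
# launch script). Whether the tuning helps depends on the deployment; the values
# are the ones named in the comment above.
import os

os.environ.setdefault("HCCL_INTRA_PCIE_ENABLE", "1")
os.environ.setdefault("HCCL_INTRA_ROCE_ENABLE", "0")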
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 6b89f4a5c71..8c395b54fd4 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -32,9 +32,10 @@ def _addrmsnorm_forward_oot( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - if layer is not None and not is_310p(): + if layer is not None and get_ascend_device_type( + ) != AscendDeviceType._310P: layer_cls_name = layer.__class__.__name__ try: weight_prefetch_method = get_forward_context( @@ -67,7 +68,7 @@ def _addrmsnorm_forward_oot( ) else: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: orig_dtype = residual.dtype x = x + residual.to(x.dtype) residual = x.to(orig_dtype) @@ -195,9 +196,9 @@ def forward_oot( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type if residual is not None: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: orig_dtype = residual.dtype x = x + residual.to(x.dtype) residual = x.to(orig_dtype) diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index eab312d5cf8..844cdcbde72 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -45,7 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - if (is_enable_nz() and layer.weight.data.dtype + if "conv1d" not in layer.prefix and ( + is_enable_nz() and layer.weight.data.dtype in [torch.float16, torch.bfloat16]): layer.weight.data = torch_npu.npu_format_cast( layer.weight.data, ACL_FORMAT_FRACTAL_NZ) diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 6a3057d9581..bb16bc006a4 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -7,17 +7,12 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context +from vllm.utils.torch_utils import direct_register_custom_op import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch -from vllm_ascend.utils import (npu_stream_switch, prefetch_stream, - vllm_version_is) - -if vllm_version_is("0.11.0"): - from vllm.utils import direct_register_custom_op -else: - from vllm.utils.torch_utils import direct_register_custom_op +from vllm_ascend.utils import npu_stream_switch, prefetch_stream def _maybe_all_gather_and_maybe_unpad_impl( diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 098945576e7..91a6f09fa1a 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -24,9 +24,11 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) +from vllm.platforms import CpuArchEnum from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import enable_custom_op, is_310p +from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, + get_ascend_device_type) def _custom_rotary_embedding_enabled(query, neox_style, head_size): @@ -48,8 +50,9 @@ def _rope_forward_oot( if 
self.cos_sin_cache.dtype != query.dtype: self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) # adopt custom kernel path for rotary_embedding - if _custom_rotary_embedding_enabled(query, is_neox_style, - self.head_size) and not is_310p(): + if _custom_rotary_embedding_enabled( + query, is_neox_style, self.head_size) and get_ascend_device_type( + ) != AscendDeviceType._310P: query, key = torch.ops._C_ascend.rotary_embedding( positions, query, @@ -405,7 +408,10 @@ def forward_oot( query: torch.Tensor, key: torch.Tensor, ): - if self.mrope_section != [16, 24, 24]: + # TODO: This judgment will be removed once the mrope precision issue is fixed + if self.mrope_section != [ + 16, 24, 24 + ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86: return super().forward_oot(positions, query, key) import torch_npu @@ -428,4 +434,4 @@ def forward_oot( mrope_section=mrope_section, rotary_mode='half') - return query, key \ No newline at end of file + return query, key diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/sigmoid_gating.py deleted file mode 100644 index 39e653a5913..00000000000 --- a/vllm_ascend/ops/sigmoid_gating.py +++ /dev/null @@ -1,300 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -# mypy: ignore-errors - -import os - -from vllm.triton_utils import tl, tldevice, triton - -if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': - div = tldevice.fast_dividef - exp = tldevice.fast_expf - log = tldevice.fast_logf - log2 = tldevice.fast_log2f -else: - - @triton.jit - def div_normal(x, y): - return x / y - - div = div_normal - exp = tl.exp - log = tl.log - log2 = tl.log2 - - -@triton.heuristics({ - 'USE_INITIAL_STATE': - lambda args: args['h0'] is not None, - 'IS_VARLEN': - lambda args: args['cu_seqlens'] is not None, - "IS_CONTINUOUS_BATCHING": - lambda args: args['ssm_state_indices'] is not None, - "IS_SPEC_DECODING": - lambda args: args['num_accepted_tokens'] is not None, -}) -@triton.jit(do_not_specialize=['N', 'T']) -def fused_recurrent_gated_delta_rule_fwd_kernel( - q, - k, - v, - g, - beta, - o, - h0, - ht, - cu_seqlens, - ssm_state_indices, - num_accepted_tokens, - scale, - N: tl.constexpr, # num of sequences - T: tl.constexpr, # num of tokens - B: tl.constexpr, - H: tl.constexpr, - HV: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - stride_init_state_token: tl.constexpr, - stride_final_state_token: tl.constexpr, - stride_indices_seq: tl.constexpr, - stride_indices_tok: tl.constexpr, - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace - IS_BETA_HEADWISE: tl. 
- constexpr, # whether beta is headwise vector or scalar, - USE_QK_L2NORM_IN_KERNEL: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - IS_KDA: tl.constexpr, -): - i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_n, i_hv = i_nh // HV, i_nh % HV - i_h = i_hv // (HV // H) - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) - all = T - T = eos - bos - else: - bos, eos = i_n * T, i_n * T + T - all = B * T - - if T == 0: - # no tokens to process for this sequence - return - - o_k = i_k * BK + tl.arange(0, BK) - o_v = i_v * BV + tl.arange(0, BV) - - mask_k = o_k < K - mask_v = o_v < V - mask_h = mask_k[:, None] & mask_v[None, :] - - b_h = tl.zeros([BK, BV], dtype=tl.float32) - if USE_INITIAL_STATE: - if IS_CONTINUOUS_BATCHING: - if IS_SPEC_DECODING: - i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 - else: - i_t = 0 - p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_init_state_token - else: - p_h0 = h0 + bos * HV * K * V - p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) - - for i_t in range(0, T): - p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t - p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t - p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t - - if IS_BETA_HEADWISE: - p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t - else: - p_beta = beta + bos * HV + i_hv + HV * i_t - - if not IS_KDA: - p_g = g + bos * HV + i_hv + HV * i_t - else: - p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k - - p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t - - b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) - b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) - b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_g = tl.load(p_g).to(tl.float32) - - if USE_QK_L2NORM_IN_KERNEL: - b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) - b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) - b_q = b_q * scale - # [BK, BV] - # b_h *= tl.exp(b_g) - if not IS_KDA: - b_g = tl.load(p_g).to(tl.float32) - b_h *= exp(b_g) - else: - b_gk = tl.load(p_gk).to(tl.float32) - b_h *= exp(b_gk[:, None]) - # [BV] - b_v -= tl.sum(b_h * b_k[:, None], 0) - if IS_BETA_HEADWISE: - b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) - else: - b_beta = tl.load(p_beta).to(tl.float32) - b_v *= b_beta - # [BK, BV] - b_h += b_k[:, None] * b_v[None, :] - # [BV] - b_o = tl.sum(b_h * b_q[:, None], 0) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) - - # keep the states for multi-query tokens - if INPLACE_FINAL_STATE: - p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_final_state_token - else: - p_ht = ht + (bos + i_t) * stride_final_state_token - p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - - -@triton.heuristics({ - 'USE_INITIAL_STATE': - lambda args: args['h0'] is not None, - 'IS_VARLEN': - lambda args: args['cu_seqlens'] is not None, - "IS_CONTINUOUS_BATCHING": - lambda args: args['ssm_state_indices'] is not None, - "IS_SPEC_DECODING": - lambda args: args['num_accepted_tokens'] is not None, -}) -@triton.jit(do_not_specialize=['N', 'T']) -def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0( - q, - k, - v, - g, - beta, - o, - h0, - ht, - cu_seqlens, - ssm_state_indices, - 
num_accepted_tokens, - scale, - N: tl.constexpr, # num of sequences - T: tl.constexpr, # num of tokens - B: tl.constexpr, - H: tl.constexpr, - HV: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - stride_init_state_token: tl.constexpr, - stride_final_state_token: tl.constexpr, - stride_indices_seq: tl.constexpr, - stride_indices_tok: tl.constexpr, - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace - IS_BETA_HEADWISE: tl. - constexpr, # whether beta is headwise vector or scalar, - USE_QK_L2NORM_IN_KERNEL: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, -): - i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_n, i_hv = i_nh // HV, i_nh % HV - i_h = i_hv // (HV // H) - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) - all = T - T = eos - bos - else: - bos, eos = i_n * T, i_n * T + T - all = B * T - - if T == 0: - # no tokens to process for this sequence - return - - o_k = i_k * BK + tl.arange(0, BK) - o_v = i_v * BV + tl.arange(0, BV) - - mask_k = o_k < K - mask_v = o_v < V - mask_h = mask_k[:, None] & mask_v[None, :] - - b_h = tl.zeros([BK, BV], dtype=tl.float32) - if USE_INITIAL_STATE: - if IS_CONTINUOUS_BATCHING: - if IS_SPEC_DECODING: - i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 - else: - i_t = 0 - p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_init_state_token - else: - p_h0 = h0 + bos * HV * K * V - p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) - - for i_t in range(0, T): - p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t - p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t - p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t - if IS_BETA_HEADWISE: - p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t - else: - p_beta = beta + bos * HV + i_hv + HV * i_t - p_g = g + bos * HV + i_hv + HV * i_t - p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t - - b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) - b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) - b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_g = tl.load(p_g).to(tl.float32) - - if USE_QK_L2NORM_IN_KERNEL: - b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) - b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) - b_q = b_q * scale - # [BK, BV] - # b_h *= tl.exp(b_g) - b_h *= exp(b_g) - # [BV] - b_v -= tl.sum(b_h * b_k[:, None], 0) - if IS_BETA_HEADWISE: - b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) - else: - b_beta = tl.load(p_beta).to(tl.float32) - b_v *= b_beta - # [BK, BV] - b_h += b_k[:, None] * b_v[None, :] - # [BV] - b_o = tl.sum(b_h * b_q[:, None], 0) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) - - # keep the states for multi-query tokens - if INPLACE_FINAL_STATE: - p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_final_state_token - else: - p_ht = ht + (bos + i_t) * stride_final_state_token - p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 575d3acfa9a..1b346de6744 100644 --- a/vllm_ascend/patch/__init__.py +++ 
b/vllm_ascend/patch/__init__.py @@ -104,29 +104,7 @@ # Future Plan: # Remove this patch when vllm merged them. # -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs` -# Why: -# We need to patch gather_logprobs to make sure call batched_count_greater_than -# with backend=current_platform.simple_compile_backend -# How: -# Patch gather_logprobs call new batched_count_greater_than -# Related PR (if no, explain why): -# - https://github.com/vllm-project/vllm/pull/21591 -# Future Plan: -# Revert it when vLLM merge #21591 and release new version -# ** File: worker/patch_logits.py ** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm._custom_ops.apply_repetition_penalties` -# Why: -# apply_repetition_penalties in vLLM use tensor.is_cuda to check if tensor is on cuda. But the value is always True -# on ascend, thus we need to patch apply_repetition_penalties. -# How: -# Remove the related cuda check in apply_repetition_penalties. -# Related PR (if no, explain why): -# - this is a bug by Ascend only. It can' be fixed in vLLM. -# Future Plan: -# Fix this bug in torch-npu, bump torch-npu version and remove this patch. +# ** File: worker/patch_roberta.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward` # Why: diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index b4ef6332bab..8e0a71ab667 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -18,9 +18,10 @@ import vllm_ascend.patch.platform.patch_config # noqa import vllm_ascend.patch.platform.patch_distributed # noqa +import vllm_ascend.patch.platform.patch_dynamo_vllm_backend # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa -if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv( +if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv( "EXPERT_MAP_RECORD", "false") == "true": import vllm_ascend.patch.platform.patch_multiproc_executor # noqa diff --git a/vllm_ascend/patch/platform/patch_distributed.py b/vllm_ascend/patch/platform/patch_distributed.py index 67d4797f9b0..467cc0450be 100644 --- a/vllm_ascend/patch/platform/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_distributed.py @@ -21,7 +21,7 @@ import vllm.envs as envs_vllm from vllm.config import ParallelConfig -from vllm_ascend.utils import is_310p +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type def parallel_config_get_dp_port(self) -> int: @@ -111,5 +111,5 @@ def all_reduce( torch.distributed.distributed_c10d.all_reduce) -if is_310p(): +if get_ascend_device_type() == AscendDeviceType._310P: communication_adaptation_310p() diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py index 1c35106e7bf..18939b0fe0d 100644 --- a/vllm_ascend/patch/platform/patch_mamba_config.py +++ b/vllm_ascend/patch/platform/patch_mamba_config.py @@ -3,23 +3,10 @@ from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.config import MambaModelConfig - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from 
vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -else: - from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE - @classmethod def verify_and_update_config(cls, vllm_config) -> None: diff --git a/vllm_ascend/patch/platform/patch_multiproc_executor.py b/vllm_ascend/patch/platform/patch_multiproc_executor.py index ac821e0ef0c..351de417cd7 100644 --- a/vllm_ascend/patch/platform/patch_multiproc_executor.py +++ b/vllm_ascend/patch/platform/patch_multiproc_executor.py @@ -1,31 +1,24 @@ import threading import weakref -from concurrent.futures import ThreadPoolExecutor +from collections import deque +from collections.abc import Callable from multiprocessing.synchronize import Lock as LockType -from typing import Optional import vllm.v1.executor.multiproc_executor from vllm import envs from vllm.config import VllmConfig -from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) +from vllm.utils.network_utils import (get_distributed_init_method, + get_loopback_ip, get_open_port) +from vllm.utils.system_utils import get_mp_context from vllm.v1.executor.abstract import FailureCallback from vllm.v1.executor.multiproc_executor import ( - MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc, + FutureWrapper, MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc, set_multiprocessing_worker_envs) -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import (get_distributed_init_method, get_loopback_ip, - get_mp_context, get_open_port) -else: - from vllm.utils.network_utils import (get_distributed_init_method, - get_loopback_ip, get_open_port) - from vllm.utils.system_utils import get_mp_context - class AscendMultiprocExecutor(MultiprocExecutor): - supports_pp: bool = True def _init_executor(self) -> None: # Call self.shutdown at exit to clean up @@ -33,10 +26,14 @@ def _init_executor(self) -> None: self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False self.shutdown_event = threading.Event() - self.failure_callback: Optional[FailureCallback] = None - self.io_thread_pool: Optional[ThreadPoolExecutor] = None + self.failure_callback: FailureCallback | None = None self.world_size = self.parallel_config.world_size + assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( + f"global world_size ({self.parallel_config.world_size}) must be " + f"divisible by nnodes_within_dp " + f"({self.parallel_config.nnodes_within_dp}). ") + self.local_world_size = self.parallel_config.local_world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size assert self.world_size == tensor_parallel_size * pp_parallel_size, ( @@ -52,27 +49,36 @@ def _init_executor(self) -> None: # get_loopback_ip() for communication. 
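The sizing checks added to _init_executor above are plain integer arithmetic: the global world size must equal TP x PP, it must split evenly across the nodes that share a data-parallel rank, and each node then owns a contiguous slice of global ranks. A minimal standalone sketch of that arithmetic, using hypothetical parameter names that mirror the config fields in this hunk (not the executor's actual objects):

def plan_worker_ranks(world_size: int, tp_size: int, pp_size: int,
                      nnodes_within_dp: int, node_rank_within_dp: int) -> list[int]:
    """Return the global worker ranks owned by one node of a DP replica."""
    # Same invariants the patched _init_executor asserts.
    assert world_size == tp_size * pp_size, "world_size must equal TP x PP"
    assert world_size % nnodes_within_dp == 0, (
        "world_size must be divisible by nnodes_within_dp")
    local_world_size = world_size // nnodes_within_dp
    # Each node owns a contiguous block of global ranks.
    start = local_world_size * node_rank_within_dp
    return [start + local_rank for local_rank in range(local_world_size)]

# Example: TP=4, PP=2 spread over 2 nodes -> node 1 owns ranks 4..7.
assert plan_worker_ranks(8, 4, 2, 2, 1) == [4, 5, 6, 7]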
distributed_init_method = get_distributed_init_method( get_loopback_ip(), get_open_port()) - + self.rpc_broadcast_mq: MessageQueue | None = None + scheduler_output_handle: Handle | None = None # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs - max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue(self.world_size, - self.world_size, - max_chunk_bytes=max_chunk_bytes) - scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - + if self.parallel_config.node_rank_within_dp == 0: + # For leader node within each dp rank, + # each dp will have its own leader multiproc executor. + max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + self.rpc_broadcast_mq = MessageQueue( + self.world_size, + self.local_world_size, + max_chunk_bytes=max_chunk_bytes, + connect_ip=self.parallel_config.master_addr, + ) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers context = get_mp_context() shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: - for rank in range(self.world_size): + global_start_rank = (self.local_world_size * + self.parallel_config.node_rank_within_dp) + for local_rank in range(self.local_world_size): + global_rank = global_start_rank + local_rank unready_workers.append( AscendWorkerProc.make_worker_process( vllm_config=self.vllm_config, - local_rank=rank, - rank=rank, + local_rank=local_rank, + rank=global_rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, @@ -80,15 +86,38 @@ def _init_executor(self) -> None: # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. - self.workers = WorkerProc.wait_for_ready(unready_workers) + + # Wait for all local workers to be ready. + self.workers = AscendWorkerProc.wait_for_ready(unready_workers) + + # Start background thread to monitor worker health if not in headless mode. + if self.monitor_workers: + self.start_worker_monitor() + + self.response_mqs = [] + # Only leader node have remote response mqs + if self.parallel_config.node_rank_within_dp == 0: + for rank in range(self.world_size): + if rank < self.local_world_size: + local_message_queue = self.workers[ + rank].worker_response_mq + assert local_message_queue is not None + self.response_mqs.append(local_message_queue) + else: + remote_message_queue = self.workers[ + 0].peer_worker_response_mqs[rank] + assert remote_message_queue is not None + self.response_mqs.append(remote_message_queue) # Ensure message queues are ready. Will deadlock if re-ordered # Must be kept consistent with the WorkerProc. - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() - self.start_worker_monitor() + # Wait for all input mqs to be ready. + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.wait_until_ready() + # Wait for all remote response mqs to be ready. + for response_mq in self.response_mqs: + response_mq.wait_until_ready() success = True finally: if not success: @@ -100,17 +129,9 @@ def _init_executor(self) -> None: self._ensure_worker_termination( [uw.proc for uw in unready_workers]) - # For pipeline parallel, we use a thread pool for asynchronous - # execute_model. 
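The response-queue wiring in this hunk follows one rule: ranks hosted on the leader node are read through their own worker response queues, while ranks hosted on other nodes are reached through peer queues exported via local worker 0. A small illustration of just that selection logic, with plain strings standing in for the real message-queue handles, so this is a sketch of the idea rather than the executor's code:

def pick_response_queues(world_size: int, local_world_size: int,
                         local_queues: list[str],
                         peer_queues: dict[int, str]) -> list[str]:
    """Order response queues by global rank, as the leader node does."""
    queues: list[str] = []
    for rank in range(world_size):
        if rank < local_world_size:
            # Worker lives on the leader node: use its own response queue.
            queues.append(local_queues[rank])
        else:
            # Worker lives on another node: reach it through a peer queue
            # exported by local worker 0.
            queues.append(peer_queues[rank])
    return queues

# Example: 4 global ranks, 2 of them local to the leader node.
assert pick_response_queues(4, 2, ["mq0", "mq1"], {2: "peer2", 3: "peer3"}) == \
    ["mq0", "mq1", "peer2", "peer3"]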
- if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue - # _async_aggregate_workers_output also assumes a single IO thread - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") + self.futures_queue = deque[tuple[FutureWrapper, Callable]]() self.output_rank = self._get_output_rank() - self.has_connector = self.vllm_config.kv_transfer_config is not None class AscendWorkerProc(WorkerProc): diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 846c4832644..faa57b6140f 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -23,14 +23,9 @@ # isort: off import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.worker.patch_distributed # noqa -import vllm_ascend.patch.worker.patch_logits # noqa import vllm_ascend.patch.worker.patch_roberta # noqa import vllm_ascend.patch.worker.patch_weight_loader # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_minicpm # noqa - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa - import vllm_ascend.patch.worker.patch_deepseek_v3_2 # noqa +import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa +import vllm_ascend.patch.worker.patch_rope # noqa diff --git a/vllm_ascend/patch/worker/patch_deepseek_mtp.py b/vllm_ascend/patch/worker/patch_deepseek_mtp.py deleted file mode 100644 index 5f918b2d4e9..00000000000 --- a/vllm_ascend/patch/worker/patch_deepseek_mtp.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -from transformers import PretrainedConfig -from vllm.config import VllmConfig -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.deepseek_mtp import \ - DeepSeekMultiTokenPredictorLayer -from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer -from vllm.model_executor.models.utils import maybe_prefix - - -class SharedHead(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - quant_config: QuantizationConfig = None, - ) -> None: - super().__init__() - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head"), - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.norm(hidden_states) - - -def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None: - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - - self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - # We don't need topk_indices_buffer in Ascend - topk_indices_buffer = None - self.shared_head = SharedHead(config=config, - prefix=prefix, - quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer) - - -DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init diff --git a/vllm_ascend/patch/worker/patch_deepseek_v3_2.py 
b/vllm_ascend/patch/worker/patch_deepseek_v3_2.py deleted file mode 100644 index cdafcb6706a..00000000000 --- a/vllm_ascend/patch/worker/patch_deepseek_v3_2.py +++ /dev/null @@ -1,108 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from itertools import islice -from typing import Optional, Union - -import torch -import vllm.model_executor.models.deepseek_v2 -from torch import nn -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding -from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer -from vllm.model_executor.models.utils import ( - PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers) -from vllm.sequence import IntermediateTensors - - -@support_torch_compile -class DeepseekV2Model(nn.Module): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - - self.vocab_size = config.vocab_size - self.is_v32 = hasattr(config, "index_topk") - topk_indices_buffer = None - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) - - if not 
get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -vllm.model_executor.models.deepseek_v2.DeepseekV2Model = DeepseekV2Model diff --git a/vllm_ascend/patch/worker/patch_logits.py b/vllm_ascend/patch/worker/patch_logits.py deleted file mode 100644 index 84a92f916fe..00000000000 --- a/vllm_ascend/patch/worker/patch_logits.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -import vllm -from vllm._custom_ops import apply_repetition_penalties_torch - - -def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor) -> None: - """Apply repetition penalties to logits in-place. - - Args: - logits: The logits tensor of shape [num_seqs, vocab_size]. - prompt_mask: A boolean tensor indicating which tokens appear in the prompt. - output_mask: A boolean tensor indicating which tokens appear in the output. - repetition_penalties: The repetition penalties of shape (num_seqs, ). - """ - apply_repetition_penalties_torch(logits, prompt_mask, output_mask, - repetition_penalties) - - -# NPU device type tensors have attributes is_cuda=True and is_npu=True, according to its implementation in -# https://github.com/Ascend/pytorch/blob/863b9071cbdf47023c12c246e3efa9c6e2285fc6/torch_npu/npu/_stream_check.py#L74 -# This causes that vLLM's apply_repetition_penalties function will run into the branch of "if logits.is_cuda" and -# call the custom op implemented in CUDA, which is not compatible with NPU. -# Reference: https://github.com/vllm-project/vllm/blob/f66673a39d9f364194c249f28098cad8a5584ccb/vllm/_custom_ops.py#L314 -vllm._custom_ops.apply_repetition_penalties = apply_repetition_penalties diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index 0383da9ec91..2f5af43be48 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -1,21 +1,14 @@ -import vllm.model_executor.layers.fla.ops.chunk -import vllm.model_executor.layers.fla.ops.fused_recurrent -import vllm.model_executor.layers.fla.ops.layernorm_guard import vllm.model_executor.layers.mamba.ops.causal_conv1d -from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, - causal_conv1d_update_npu) -from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule -from vllm_ascend.ops.sigmoid_gating import ( - fused_recurrent_gated_delta_rule_fwd_kernel, - fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0) -from vllm_ascend.utils import vllm_version_is +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule +from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn +from vllm_ascend.ops.triton.fla.sigmoid_gating import \ + fused_recurrent_gated_delta_rule_fwd_kernel +from vllm_ascend.ops.triton.mamba.casual_conv1d import ( + causal_conv1d_fn, causal_conv1d_update_npu) vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn -if vllm_version_is('0.11.0'): - vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0 -else: - vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel 
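These worker patches all rely on the same mechanism: import the target module and rebind one of its attributes to the Ascend implementation before anything resolves that attribute. A toy, self-contained illustration of the pattern; the module and function names below are invented for the example:

import types

# Stand-in for a third-party module whose kernel we want to swap out.
fake_ops = types.ModuleType("fake_ops")
fake_ops.fused_kernel = lambda x: x + 1          # "upstream" implementation

def npu_fused_kernel(x):
    # Replacement tuned for a different backend.
    return x + 100

# The patch: rebind the module attribute. Callers that later look up
# fake_ops.fused_kernel(...) get the replacement.
fake_ops.fused_kernel = npu_fused_kernel
assert fake_ops.fused_kernel(1) == 101

Code that bound the original function earlier (from module import fused_kernel) keeps the old reference, which is why these patch modules are imported as early as possible.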
+vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn -vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = chunk_gated_delta_rule diff --git a/vllm_ascend/patch/worker/patch_weight_loader.py b/vllm_ascend/patch/worker/patch_weight_loader.py index cbbace8bd46..e0fcde04c31 100644 --- a/vllm_ascend/patch/worker/patch_weight_loader.py +++ b/vllm_ascend/patch/worker/patch_weight_loader.py @@ -3,13 +3,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.utils import set_weight_attrs - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import GiB_bytes -else: - from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_constants import GiB_bytes logger = init_logger(__name__) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index faed5aea148..7cc84fc6ae3 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -24,16 +24,40 @@ from vllm.platforms import Platform, PlatformEnum # todo: please remove it when solve cuda hard code in vllm -os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "True" +os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, init_ascend_config) from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p, - prefill_context_parallel_enable, - update_aclgraph_sizes, - update_cudagraph_capture_sizes, vllm_version_is) + +# isort: off +from vllm_ascend.utils import ( + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, + enable_sp, get_ascend_device_type, is_vl_model, + prefill_context_parallel_enable, update_aclgraph_sizes, + update_cudagraph_capture_sizes, update_default_aclgraph_sizes) + +# set custom ops path +CUR_DIR = os.path.dirname(os.path.realpath(__file__)) +CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "vllm_ascend", "_cann_ops_custom", + "vendors", "customize") +CUSTOM_LIB_PATH = os.path.join(CUSTOM_OPP_PATH, "op_api", "lib") + +if os.path.exists(CUSTOM_OPP_PATH): + current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "") + if current_cust_opp_path: + os.environ[ + "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" + else: + os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH + +if os.path.exists(CUSTOM_LIB_PATH): + current_lib_path = os.environ.get("LD_LIBRARY_PATH", "") + if current_lib_path: + os.environ["LD_LIBRARY_PATH"] = f"{CUSTOM_LIB_PATH}:{current_lib_path}" + else: + os.environ["LD_LIBRARY_PATH"] = CUSTOM_LIB_PATH if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -54,7 +78,9 @@ class NPUPlatform(Platform): device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" dispatch_key: str = "PrivateUse1" - supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD] + supported_quantization: list[str] = [ + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD + ] def is_sleep_mode_available(self) -> bool: return True @@ -77,6 +103,8 @@ def pre_register_and_update(cls, if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: 
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) + from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ + AscendCompressedTensorsConfig # noqa: F401 from vllm_ascend.quantization.quant_config import \ AscendQuantConfig # noqa: F401 @@ -119,10 +147,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # initialize ascend config from vllm additional_config ascend_config = init_ascend_config(vllm_config) - if vllm_version_is("0.11.0"): - from vllm.config import CompilationLevel - else: - from vllm.config import CompilationMode # noqa: E402 + from vllm.config import CompilationMode # noqa: E402 compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config @@ -148,29 +173,19 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config.compilation import CUDAGraphMode if enforce_eager: logger.info("Compilation disabled, using eager mode by default") - if vllm_version_is("0.11.0"): - compilation_config.level = CompilationLevel.NO_COMPILATION - else: - compilation_config.mode = CompilationMode.NONE + compilation_config.mode = CompilationMode.NONE + if compilation_config.splitting_ops is None: + compilation_config.splitting_ops = [] compilation_config.cudagraph_num_of_warmups = 1 - if vllm_version_is("0.11.0"): - if compilation_config.level not in [ - CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE - ]: - logger.warning( - "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE", - compilation_config.level) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - else: - if compilation_config.mode not in [ - CompilationMode.NONE, CompilationMode.VLLM_COMPILE - ]: - logger.warning( - "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", - compilation_config.mode) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if compilation_config.mode not in [ + CompilationMode.NONE, CompilationMode.VLLM_COMPILE + ]: + logger.warning( + "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", + compilation_config.mode) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. if ascend_config.torchair_graph_config.enabled: @@ -193,6 +208,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # set cudaprah sizes before extending `compilation_config.splitting_ops` vllm_config._set_cudagraph_sizes() + # There are cases where default cudagraph_capture_sizes are not friendly + # to ascend ops && hardwares. We update these sizes here to improve + # default performance. + update_default_aclgraph_sizes(vllm_config) # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism # is supported by vllm-ascend. 
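The sequence-parallel graph-size update below rests on the observation that, with sequence parallelism, a captured batch size is only usable when it splits evenly across tensor-parallel ranks. A hedged sketch of that filtering idea; it mirrors the shape of the check, not the exact vllm-ascend implementation:

def filter_sp_capture_sizes(capture_sizes: list[int], tp_size: int) -> list[int]:
    """Keep only capture sizes that divide evenly across TP ranks."""
    return [size for size in capture_sizes if size % tp_size == 0]

original_sizes = [1, 2, 4, 8, 16, 24, 32]
sp_aclgraph_sizes = filter_sp_capture_sizes(original_sizes, tp_size=4)
assert sp_aclgraph_sizes == [4, 8, 16, 24, 32]
# Only when something was dropped does the capture-size list need rewriting.
assert len(sp_aclgraph_sizes) != len(original_sizes)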
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \ @@ -206,96 +225,49 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: f"{vllm_config.parallel_config.tensor_parallel_size}") if len(sp_aclgraph_sizes) != len(original_sizes): compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes - if vllm_version_is("0.11.0"): - compilation_config.init_with_cudagraph_sizes( - sp_aclgraph_sizes) - else: - update_cudagraph_capture_sizes(vllm_config, - sp_aclgraph_sizes) + update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes) # TODO: Full graph is fully supported later, and the default value will be set to full graph. if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - if vllm_version_is("0.11.0"): - if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: - compilation_config.level = CompilationLevel.NO_COMPILATION - elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: - logger.info( - "PIECEWISE compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") - assert compilation_config.level == CompilationLevel.PIECEWISE, \ - "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE" - compilation_config.set_splitting_ops_for_v1() - compilation_config.use_inductor = False - compilation_config.splitting_ops.extend([ - "vllm.unified_ascend_attention_with_output", - "vllm.mla_forward" - ]) - update_aclgraph_sizes(vllm_config) - elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ - compilation_config.cudagraph_mode == CUDAGraphMode.FULL: - logger.info( - "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") - compilation_config.use_inductor = False - warning_message = """\033[91m - ********************************************************************************** - * WARNING: You have enabled the *full graph* feature. - * This is an early experimental stage and may involve various unknown issues. - * A known problem is that capturing too many batch sizes can lead to OOM - * (Out of Memory) errors or inference hangs. If you encounter such issues, - * consider reducing `gpu_memory_utilization` or manually specifying a smaller - * batch size for graph capture. - * For more details, please refer to: - * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs - **********************************************************************************\033[0m - """ - logger.warning(warning_message) - else: - logger.info( - "%s cudagraph_mode is not support on NPU. falling back to NONE", - compilation_config.cudagraph_mode) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - compilation_config.level = CompilationLevel.NO_COMPILATION + if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + compilation_config.mode = CompilationMode.NONE + elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.info( + "PIECEWISE compilation enabled on NPU. 
use_inductor not supported - " + "using only ACL Graph mode") + assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \ + "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" + compilation_config.set_splitting_ops_for_v1() + compilation_config.use_inductor = False + compilation_config.splitting_ops.extend(["vllm::mla_forward"]) + update_aclgraph_sizes(vllm_config) + elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ + compilation_config.cudagraph_mode == CUDAGraphMode.FULL: + logger.info( + "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " + "using only ACL Graph mode") + compilation_config.use_inductor = False + warning_message = """\033[91m + ********************************************************************************** + * WARNING: You have enabled the *full graph* feature. + * This is an early experimental stage and may involve various unknown issues. + * A known problem is that capturing too many batch sizes can lead to OOM + * (Out of Memory) errors or inference hangs. If you encounter such issues, + * consider reducing `gpu_memory_utilization` or manually specifying a smaller + * batch size for graph capture. + * For more details, please refer to: + * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs + **********************************************************************************\033[0m + """ + logger.warning(warning_message) else: - if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: - compilation_config.mode = CompilationMode.NONE - elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: - logger.info( - "PIECEWISE compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") - assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \ - "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" - compilation_config.set_splitting_ops_for_v1() - compilation_config.use_inductor = False - compilation_config.splitting_ops.extend(["vllm::mla_forward"]) - update_aclgraph_sizes(vllm_config) - elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ - compilation_config.cudagraph_mode == CUDAGraphMode.FULL: - logger.info( - "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode") - compilation_config.use_inductor = False - warning_message = """\033[91m - ********************************************************************************** - * WARNING: You have enabled the *full graph* feature. - * This is an early experimental stage and may involve various unknown issues. - * A known problem is that capturing too many batch sizes can lead to OOM - * (Out of Memory) errors or inference hangs. If you encounter such issues, - * consider reducing `gpu_memory_utilization` or manually specifying a smaller - * batch size for graph capture. - * For more details, please refer to: - * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs - **********************************************************************************\033[0m - """ - logger.warning(warning_message) - else: - logger.info( - "%s cudagraph_mode is not support on NPU. 
falling back to NONE", - compilation_config.cudagraph_mode) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - compilation_config.mode = CompilationMode.NONE + logger.info( + "%s cudagraph_mode is not support on NPU. falling back to NONE", + compilation_config.cudagraph_mode) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.mode = CompilationMode.NONE # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 # Then, we will have to discuss the error handling strategy and user experience @@ -310,7 +282,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. - os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv" + parallel_config.all2all_backend = "flashinfer_all2allv" if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp: parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker" else: @@ -336,7 +308,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size = origin_block_size # Activate custom ops for v1, except on 310P - if not is_310p(): + if get_ascend_device_type() != AscendDeviceType._310P: compilation_config.custom_ops = ["all"] # If ascend_scheduler_config is enabled, @@ -372,6 +344,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "needs to be equal if use cp or dcp > 1 in P/D disaggregate scenario." ) + if is_vl_model(vllm_config): + if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \ + bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))): + raise ValueError( + "Currently, VL models doesn't support " + "FLASHCOMM in vllm-ascend. We will fix this in the future. 
" + "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.") + @classmethod def import_kernels(cls) -> None: # Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES @@ -391,14 +371,11 @@ def get_attn_backend_cls( dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink=False, use_sparse=False, + attn_type: str | None = None, ): - if not use_v1: - raise ValueError("vLLM Ascend does not support V0 engine.") - ascend_config = get_ascend_config() if use_mla and ascend_config.enable_shared_expert_dp: @@ -427,10 +404,7 @@ def get_attn_backend_cls( @classmethod def get_punica_wrapper(cls) -> str: - if vllm_version_is("0.11.0"): - return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU0110" - else: - return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU" + return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU" @classmethod def get_current_memory_usage(cls, diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index c0760c800ed..72c04e50b70 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -94,8 +94,10 @@ def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - if torch.npu.is_available(): - return ASCEND_QUANTIZATION_METHOD + if hf_quant_cfg is not None: + quant_method = hf_quant_cfg.get("quant_method", None) + if quant_method is None and torch.npu.is_available(): + return ASCEND_QUANTIZATION_METHOD return None def get_quant_method(self, layer: torch.nn.Module, @@ -113,7 +115,7 @@ def get_quant_method(self, layer: torch.nn.Module, self.packed_modules_mapping): return AscendUnquantizedLinearMethod() return AscendLinearMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) elif isinstance(layer, Attention) and \ 'fa_quant_type' in self.quant_description.keys() and \ self.quant_description['fa_quant_type'] is not None: @@ -126,13 +128,13 @@ def get_quant_method(self, layer: torch.nn.Module, self.packed_modules_mapping): return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendFusedMoEMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) elif isinstance(layer, VocabParallelEmbedding): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedEmbeddingMethod() return AscendEmbeddingMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) return None def is_layer_skipped_ascend( @@ -222,6 +224,8 @@ def get_scaled_act_names(self) -> List[str]: ], "gate_up_proj": ["gate_proj", "up_proj"], "in_proj": ["in_proj_qkvz", "in_proj_ba"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] }, "qwen2_5_vl": { "qkv_proj": [ @@ -257,11 +261,16 @@ class AscendLinearMethod(LinearMethodBase): quant_config: The Ascend quantization config. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]) -> None: + def __init__(self, + quant_config: AscendQuantConfig, + prefix: str, + packed_modules_mapping: Dict[str, Any] | None, + layer: torch.nn.Module = None) -> None: self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "linear", - packed_modules_mapping) + prefix, + "linear", + packed_modules_mapping, + layer=layer) def create_weights( self, @@ -399,11 +408,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): quant_config: The Ascend quantization config. 
""" - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]): + def __init__(self, + quant_config: AscendQuantConfig, + prefix: str, + packed_modules_mapping: Dict[str, Any], + layer: torch.nn.Module = None): self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "moe", - packed_modules_mapping) + prefix, + "moe", + packed_modules_mapping, + layer=layer) def create_weights( self, @@ -483,7 +497,10 @@ class AscendEmbeddingMethod(AscendLinearMethod): """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]) -> None: + packed_modules_mapping: Dict[str, Any], + layer: torch.nn.Module) -> None: self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "linear", - packed_modules_mapping) + prefix, + "linear", + packed_modules_mapping, + layer=layer) diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 6d914c0dade..eaaaee86702 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -1,7 +1,10 @@ from typing import Any, Dict, Optional, Type +import torch from vllm.logger import logger +from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD + from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) @@ -60,8 +63,28 @@ def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, def get_quant_method(quant_description: Dict[str, Any], prefix: str, layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - logger.info_once("Using the vLLM Ascend Quantization now!") + packed_modules_mapping: Optional[Dict[str, Any]] = None, + layer: torch.nn.Module = None): + if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: + return get_quant_method_llmcompressor(layer) + + return get_quant_method_modelslim(quant_description, prefix, layer_type, + packed_modules_mapping) + + +def get_quant_method_llmcompressor(layer: torch.nn.Module): + logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") + if layer.scheme is None: + raise ValueError("A scheme must be defined for each layer") + return layer.scheme + + +def get_quant_method_modelslim( + quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + logger.info_once("Using the vLLM Ascend modelslim Quantization now!") if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 77f0f4b23cb..c7f1dfabb86 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -342,7 +342,7 @@ def apply( scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, - enable_force_load_balance: bool = True, + enable_force_load_balance: bool = False, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, @@ -371,7 +371,8 @@ def apply( # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + topk_ids = torch.randint_like( + topk_ids, 0, global_num_experts - global_redundant_expert_num) topk_weights = topk_weights.to(x.dtype) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index dcd692acfb6..8a7bbfe7263 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -25,7 +25,9 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, + COMPRESSED_TENSORS_METHOD, AscendDeviceType, + get_ascend_device_type, is_enable_nz) def quant_per_tensor(in_tensor: torch.Tensor, @@ -45,7 +47,8 @@ class AscendW8A8LinearMethod: def __init__(self) -> None: # aclnn quant matmul requires to transpose matrix B, set to true by default. - self.transpose_weight = not is_310p() + self.transpose_weight = get_ascend_device_type( + ) != AscendDeviceType._310P @staticmethod def get_weight( @@ -147,7 +150,11 @@ def apply( ) quant_bias = layer.quant_bias if tp_rank == 0 else None - if is_310p(): + if getattr(layer, "ascend_quant_method", + "") == COMPRESSED_TENSORS_METHOD: + quant_bias = bias + + if get_ascend_device_type() == AscendDeviceType._310P: # On 300I Duo platform, we need transpose again if # using nz. This transpose can be skipped in torchair. output = torch_npu.npu_quant_matmul( @@ -185,6 +192,11 @@ def process_weights_after_loading(self, layer): layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + if getattr(layer, "ascend_quant_method", + "") == COMPRESSED_TENSORS_METHOD: + deq_scale = layer.input_scale.data * layer.weight_scale.data + layer.deq_scale = torch.nn.Parameter(deq_scale, + requires_grad=False) class AscendW8A8FusedMoEMethod: @@ -299,7 +311,7 @@ def apply( e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: return fused_experts_310p(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, @@ -328,7 +340,7 @@ def apply( expert_map=expert_map) def process_weights_after_loading(self, layer): - if not is_310p(): + if get_ascend_device_type() != AscendDeviceType._310P: layer.w13_weight.data = layer.w13_weight.data.transpose( 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( @@ -345,7 +357,7 @@ def process_weights_after_loading(self, layer): expanding_factor_w13 = layer.w13_weight.data.shape[1] expanding_factor_w2 = layer.w2_weight.data.shape[1] - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: layer.w13_input_scale.data = torch.nn.Parameter( layer.w13_input_scale.data.max()) layer.w2_input_scale.data = torch.nn.Parameter( @@ -365,7 +377,8 @@ def process_weights_after_loading(self, layer): # converting ACL_FORMAT_FRACTAL_NZ. # npu_quant_grouped_matmul_dequant in eager mode does not accept # ACL_FORMAT_FRACTAL_NZ. 
- if not is_310p() and is_enable_nz(): + if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz( + ): layer.w13_weight.data = torch_npu.npu_format_cast( layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous() layer.w2_weight.data = torch_npu.npu_format_cast( diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 8bef2567979..6b7d6b0875c 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -19,20 +19,14 @@ import torch import torch_npu -from vllm.config import get_current_vllm_config +from vllm.config import CompilationMode, get_current_vllm_config from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz, - vllm_version_is) - -if vllm_version_is("0.11.0"): - from vllm.config import CompilationLevel -else: - from vllm.config import CompilationMode +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz class AscendW8A8DynamicLinearMethod: @@ -129,18 +123,10 @@ def __init__(self): vllm_config = get_current_vllm_config() ascend_config = get_ascend_config() - if vllm_version_is("0.11.0"): - self.use_aclgraph = ( - vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and not vllm_config.model_config.enforce_eager - and not ascend_config.torchair_graph_config.enabled) - else: - self.use_aclgraph = ( - vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE - and not vllm_config.model_config.enforce_eager - and not ascend_config.torchair_graph_config.enabled) + self.use_aclgraph = ( + vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + and not vllm_config.model_config.enforce_eager + and not ascend_config.torchair_graph_config.enabled) self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path self.in_dtype = vllm_config.model_config.dtype @@ -213,7 +199,7 @@ def apply( scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, - enable_force_load_balance: bool = True, + enable_force_load_balance: bool = False, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, @@ -242,7 +228,8 @@ def apply( # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
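The two randint_like changes in these quantization hunks share one idea: when load balance is forced during profile runs, random expert ids should only be drawn from the real experts, excluding the redundant (EPLB) replicas appended at the end of the table. A standalone torch sketch of the adjusted range, with example values for the expert counts:

import torch

global_num_experts = 64
global_redundant_expert_num = 8   # extra EPLB replicas at the end of the table

# topk_ids as produced by expert selection: [num_tokens, top_k]
topk_ids = torch.zeros((4, 2), dtype=torch.int64)

# Force load balance: overwrite the routing decision with uniform random
# expert ids drawn only over the 56 "real" experts.
topk_ids = torch.randint_like(
    topk_ids, 0, global_num_experts - global_redundant_expert_num)
assert int(topk_ids.max()) < global_num_experts - global_redundant_expert_num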
if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + topk_ids = torch.randint_like( + topk_ids, 0, global_num_experts - global_redundant_expert_num) topk_weights = topk_weights.to(self.in_dtype) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 0bf8b6bf67c..a17f534045e 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -6,16 +6,10 @@ import vllm.v1.sample.rejection_sampler as rs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (RejectionSampler, + apply_sampling_constraints, generate_uniform_probs) from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.v1.sample.rejection_sampler import compute_probs -else: - from vllm.v1.sample.rejection_sampler import apply_sampling_constraints - PLACEHOLDER_TOKEN_ID = -1 GREEDY_TEMPERATURE = -1 # Maximum number of speculative draft tokens allowed per request in a single @@ -89,19 +83,12 @@ def forward( # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the # `compute_probs` function. - if vllm_version_is("0.11.0"): - target_probs = compute_probs( - target_logits, - metadata.cu_num_draft_tokens, - sampling_metadata, - ) - else: - target_logits = apply_sampling_constraints( - target_logits, - metadata.cu_num_draft_tokens, - sampling_metadata, - ) - target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + target_logits = apply_sampling_constraints( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) output_token_ids = rejection_sample( metadata.draft_token_ids, diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 37abdd4965a..6c9f37c64b4 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -3,7 +3,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample from vllm.v1.sample.sampler import Sampler -from vllm_ascend.utils import is_310p +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type DEFAULT_LOGPROBS_MODE = "raw_logprobs" @@ -25,7 +25,8 @@ def _apply_top_k_top_p( p: torch.Tensor, ) -> torch.Tensor: # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P - if not is_310p() and p is not None and k is not None and 1 <= int( + if get_ascend_device_type( + ) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int( k.max()) <= 1024: # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) return torch_npu.npu_top_k_top_p(logits, p, k) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 74e2917806b..75f01ee9bdb 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -5,13 +5,15 @@ import torch import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config +from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import 
get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -22,14 +24,6 @@ AscendMetadata) from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.config import CompilationLevel - from vllm.utils import is_pin_memory_available -else: - from vllm.config import CompilationMode - from vllm.utils.platform_utils import is_pin_memory_available PADDING_SLOT_ID = -1 @@ -52,16 +46,9 @@ def __init__(self, self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size( ) - if vllm_version_is("0.11.0"): - self.use_cuda_graph = ( - self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and not self.vllm_config.model_config.enforce_eager) - else: - self.use_cuda_graph = ( - self.vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE - and not self.vllm_config.model_config.enforce_eager) + self.use_cuda_graph = (self.vllm_config.compilation_config.mode + == CompilationMode.VLLM_COMPILE and + not self.vllm_config.model_config.enforce_eager) self.cudagraph_batch_sizes = list( reversed( @@ -137,8 +124,7 @@ def dummy_run(self, num_tokens_across_dp: Optional[torch.Tensor] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None): - moe_comm_type = self.runner._select_moe_comm_method( - num_tokens, with_prefill) + moe_comm_type = self.runner._select_moe_comm_method(num_tokens) with set_ascend_forward_context(None, self.vllm_config, moe_comm_type=moe_comm_type, @@ -150,7 +136,7 @@ def dummy_run(self, ) def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], + valid_sampled_token_ids: list[np.ndarray], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, @@ -163,7 +149,7 @@ def generate_token_ids(self, attn_metadata = self._get_eagle_atten_dict(scheduler_output) next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. 
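With sampled token ids now arriving as one numpy array per request, emptiness is tested via shape[0] and the accepted-length arithmetic works on array lengths (the rejected-token formula appears a little further down in this hunk). A small numpy sketch of both computations on made-up data:

import numpy as np

# One array of accepted token ids per request; an empty array means the
# request produced no new token in this step.
valid_sampled_token_ids = [np.array([11, 12, 13]), np.array([], dtype=np.int64)]
num_draft_tokens = [2, 2]          # draft tokens proposed per request

next_token_ids = []
for token_ids in valid_sampled_token_ids:
    if token_ids.shape[0] > 0:
        # Common case: feed the last accepted token to the drafter.
        next_token_ids.append(int(token_ids[-1]))
    else:
        # Placeholder; the real code re-reads the request state instead.
        next_token_ids.append(-1)

# n draft tokens allow up to n + 1 acceptances; the shortfall was rejected.
num_rejected_tokens = [
    n + 1 - valid_sampled_token_ids[i].shape[0] if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert next_token_ids == [13, -1]
assert num_rejected_tokens == [0, 3]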
next_token_id = token_ids[-1] else: @@ -175,7 +161,7 @@ def generate_token_ids(self, scheduler_output.num_scheduled_tokens[req_id]) next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) + next_token_ids.append(next_token_id.item()) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) @@ -195,7 +181,7 @@ def generate_token_ids(self, else: num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + n + 1 - valid_sampled_token_ids[i].shape[0] if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] num_rejected_tokens = torch.tensor( @@ -473,11 +459,7 @@ def _propose( else: num_input_tokens = num_tokens - with_prefill = attn_metadata.attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - moe_comm_type = self.runner._select_moe_comm_method( - num_input_tokens, with_prefill) + moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions.to(device) @@ -517,8 +499,7 @@ def _propose( else: input_batch_size = batch_size - moe_comm_type = self.runner._select_moe_comm_method( - input_batch_size, False) + moe_comm_type = self.runner._select_moe_comm_method(input_batch_size) attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index 3f0a36b13cd..5fdb494515f 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -1,6 +1,7 @@ import enum from typing import Optional +import numpy as np import torch from vllm.config import CUDAGraphMode, VllmConfig from vllm.v1.core.sched.output import SchedulerOutput @@ -40,7 +41,7 @@ def dummy_run(self, raise NotImplementedError def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], + valid_sampled_token_ids: list[np.ndarray], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 627411fed4a..73b65aedfd9 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,5 +1,5 @@ import importlib -from typing import Optional +from typing import Optional, Union import numpy as np import torch @@ -7,7 +7,7 @@ import torch.nn.functional as F from vllm.config import (CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import BatchDescriptor +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader @@ -15,14 +15,7 @@ process_weights_after_loading from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput @@ -32,21 +25,20 @@ from 
vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + set_mtp_graph_params, + update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, - prefill_context_parallel_enable, - vllm_version_is) + prefill_context_parallel_enable) if prefill_context_parallel_enable(): from vllm.distributed import get_pcp_group -if vllm_version_is("0.11.0"): - from vllm.model_executor.model_loader.utils import set_default_torch_dtype - from vllm.utils import is_pin_memory_available -else: - from vllm.utils.platform_utils import is_pin_memory_available - from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) @@ -75,6 +67,9 @@ def _load_model(architecture): class MtpProposer(Proposer): + # TODO: Find out why ModelRunner does not this explicit typing? + model: Union[nn.Module, ACLGraphWrapper] + def __init__( self, vllm_config: VllmConfig, @@ -203,6 +198,15 @@ def load_model(self, model) -> None: process_weights_after_loading(self.model, draft_model_config, target_device) + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ): + self.update_stream: torch.npu.Stream = torch.npu.Stream() + set_mtp_graph_params( + self.vllm_config.compilation_config.cudagraph_capture_sizes) + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + @torch.inference_mode() def dummy_run(self, num_tokens: int, @@ -219,15 +223,57 @@ def dummy_run(self, with_prefill, ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) - moe_comm_type = self.runner._select_moe_comm_method( - num_tokens, with_prefill) + moe_comm_type = self.runner._select_moe_comm_method(num_tokens) + + if skip_attn: + attn_metadata = None + elif aclgraph_runtime_mode == CUDAGraphMode.FULL: + if len(self.runner.attn_groups) > 0: + num_computed_tokens_cpu = ( + self.runner.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner. + query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + seq_lens=self.runner.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=self.num_speculative_tokens + 1, + num_computed_tokens_cpu=num_computed_tokens_cpu, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor()[:num_reqs], + slot_mapping=self.runner.input_batch.block_table[0]. 
+ slot_mapping, + positions=self.runner.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + cos=self.runner.cos, + sin=self.runner.sin, + ) - attn_metadata = None + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build_for_graph_capture( + common_attn_metadata, AscendAttentionState.SpecDecoding, + self.runner.get_model()) + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp + else: + attn_metadata = None + else: + attn_metadata = None input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] previous_hidden_states = self.hidden_states[:num_tokens] - for _ in range(self.num_speculative_tokens): + for i in range(self.num_speculative_tokens): + if i > 0: + aclgraph_runtime_mode = CUDAGraphMode.NONE with set_ascend_forward_context( attn_metadata, self.vllm_config, @@ -239,15 +285,25 @@ def dummy_run(self, in_profile_run=self.runner.in_profile_run, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + is_mtp_model=True): self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + update_mla_attn_params( + self.update_stream, forward_context, + positions.shape[0], + self.vllm_config.speculative_config) if with_prefill: break def generate_token_ids(self, - sampled_token_ids: list[list[int]], + sampled_token_ids: Union[torch.Tensor, + list[np.ndarray]], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, @@ -324,6 +380,8 @@ def generate_token_ids(self, common_attn_metadata.query_start_loc = \ query_start_loc_pcp_full[:num_reqs + 1] if self.speculative_config.disable_padded_drafter_batch: + assert isinstance(sampled_token_ids, list) + # NOTE: Currently, MTP-fullgraph is incompatible with pcp token_indices_to_sample = None common_attn_metadata, token_indices =\ self._prepare_inputs( @@ -358,6 +416,8 @@ def generate_token_ids(self, long_seq_metadata=long_seq_metadata, num_prefill_reqs=num_prefill_reqs, num_decode_reqs=num_decode_reqs, + scheduler_output=scheduler_output, + num_scheduled_tokens=num_scheduled_tokens, ) return draft_token_ids @@ -379,7 +439,7 @@ def _get_attn_metadata(self, attn_metadata): def _prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ @@ -460,6 +520,13 @@ def _prepare_inputs( token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) + common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( + common_attn_metadata.slot_mapping[token_indices]) + common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) + + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward + # so we do not need to fix them. But if they are used in the future, + # we should fix them.
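When disable_padded_drafter_batch is set, the slot mapping above is compacted in place: the slots for the kept token positions are gathered to the front and the tail is filled with -1 so padded positions skip KV-cache writes. A minimal sketch of that pattern with made-up sizes (plain torch, none of the vLLM types):

    import torch

    # Illustrative sizes only: 6 scheduled token slots, of which 4 survive
    # after dropping rejected draft positions.
    slot_mapping = torch.tensor([10, 11, 12, 13, 14, 15])
    token_indices = torch.tensor([0, 1, 3, 5])  # positions that are kept

    # Gather the kept slots to the front (advanced indexing copies first,
    # so the in-place write is safe), then invalidate the padded tail.
    slot_mapping[:token_indices.shape[0]] = slot_mapping[token_indices]
    slot_mapping[token_indices.shape[0]:] = -1

    print(slot_mapping.tolist())  # [10, 11, 13, 15, -1, -1]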
spec_common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), @@ -472,7 +539,7 @@ def _prepare_inputs( num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], + slot_mapping=common_attn_metadata.slot_mapping, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, positions=common_attn_metadata.positions[token_indices], attn_mask=self.runner.attn_mask, @@ -502,6 +569,8 @@ def _propose( long_seq_metadata=None, num_prefill_reqs=0, num_decode_reqs=0, + scheduler_output: SchedulerOutput = None, + num_scheduled_tokens: int = 0, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -585,14 +654,11 @@ def _propose( assert self.runner is not None - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_mtp = builder.build(0, common_attn_metadata, - self.runner.get_model()) - attn_metadata = {} - for layer_name in self.attn_layer_name: - attn_metadata[layer_name] = attn_metadata_mtp - - if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: # Acl graph mode, add padding to the batch size num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: @@ -607,21 +673,40 @@ def _propose( with_prefill) = self.runner._sync_metadata_across_dp( num_input_tokens, self.runner.with_prefill) - moe_comm_type = self.runner._select_moe_comm_method( - num_input_tokens, with_prefill) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=False) + moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) + + if scheduler_output: + max_query_len = common_attn_metadata.max_query_len + uniform_decode = (max_query_len in list( + range(1, self.num_speculative_tokens + + 2))) and (scheduler_output.total_num_scheduled_tokens + == self.runner.input_batch.num_reqs * + (self.num_speculative_tokens + 1)) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + else: + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) - if aclgraph_runtime_mode not in [ - CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE - ]: - # Fallback to piecewise graph, when acl full graph is enabled - logger.debug( - "Currently the eagle proposer only supports cudagraph_mode " - f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} " - "to CUDAGraphMode.PIECEWISE") - aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE + + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and aclgraph_runtime_mode == CUDAGraphMode.FULL: + graph_pad_size = num_input_tokens + else: + # Currently, runner.graph_pad_size will always be -1. + graph_pad_size = self.runner.graph_pad_size + + # If full graph is used and disable_padded_drafter_batch=True, we need to + # update the graph_pad_size in common_attn_metadata to tell the + # builder to pad some elements.
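The uniform_decode flag computed above for the full-graph batch descriptor reduces to: every request's query length is at most num_speculative_tokens + 1, and the batch is exactly num_reqs * (num_speculative_tokens + 1) tokens. A hypothetical helper restating that condition (not code from this patch):

    def is_uniform_decode(max_query_len: int, total_scheduled_tokens: int,
                          num_reqs: int, num_spec_tokens: int) -> bool:
        # A batch is "uniform decode" when every request schedules at most
        # num_spec_tokens + 1 tokens and the batch is exactly full, i.e.
        # every request schedules the same num_spec_tokens + 1 tokens.
        per_req = num_spec_tokens + 1
        return (1 <= max_query_len <= per_req
                and total_scheduled_tokens == num_reqs * per_req)

    # 4 decode requests with one speculative token each (2 tokens per request):
    assert is_uniform_decode(2, 8, 4, 1)
    # A 17-token prefill mixed into the batch breaks uniformity:
    assert not is_uniform_decode(17, 23, 4, 1)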
+ common_attn_metadata.graph_pad_size = graph_pad_size + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp for step in range(self.num_speculative_tokens): with set_ascend_forward_context( @@ -635,7 +720,8 @@ def _propose( aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, in_profile_run=self.runner.in_profile_run, - num_actual_tokens=num_tokens): + num_actual_tokens=num_tokens, + is_mtp_model=True): with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata @@ -644,6 +730,13 @@ def _propose( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens]) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + if self.vllm_config.model_config.use_mla: + update_mla_attn_params( + self.update_stream, forward_context, + num_input_tokens, + self.vllm_config.speculative_config) num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): @@ -699,12 +792,21 @@ def _propose( input_ids = draft_token_ids_list[-1].int() positions += 1 - attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ - 1:batch_size + 1].tolist() - attn_metadata_i.decode.cos = builder.cos_cache[ - positions].unsqueeze(1).unsqueeze(2) - attn_metadata_i.decode.sin = builder.sin_cache[ - positions].unsqueeze(1).unsqueeze(2) + # When disable_padded_drafter_batch=False, these params probably do not need to be updated. + if self.speculative_config.disable_padded_drafter_batch or \ + aclgraph_runtime_mode != CUDAGraphMode.FULL: + attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + 1:batch_size + 1].tolist() + if aclgraph_runtime_mode == CUDAGraphMode.FULL: + attn_metadata_i.decode.actual_seq_lengths_q = \ + builder.pad_actual_seq_len_q_mtp_disable_pad( + graph_pad_size - batch_size, + batch_size, + attn_metadata_i.decode.actual_seq_lengths_q) + attn_metadata_i.decode.cos = builder.cos_cache[ + positions].unsqueeze(1).unsqueeze(2) + attn_metadata_i.decode.sin = builder.sin_cache[ + positions].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length.
Since it is complex # to remove such requests from the batch, we keep them in the batch @@ -735,6 +837,10 @@ def _propose( self.positions[:batch_size] = clamped_positions self.hidden_states[:hidden_states.shape[0]] = hidden_states attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + if self.speculative_config.disable_padded_drafter_batch: + self.positions[batch_size:num_input_tokens] = 0 + self.input_ids[batch_size:num_input_tokens] = 0 + self.hidden_states[batch_size:num_input_tokens].fill_(0) if attn_metadata_i.prefill is not None: attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens @@ -751,6 +857,12 @@ def _propose( attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( ) + decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list + if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ + self.speculative_config.disable_padded_drafter_batch: + attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [ + 0 + ] * (graph_pad_size - len(decode_seq_lens_list)) attn_metadata_i.decode.input_positions = self.positions[: num_input_tokens] attn_metadata_i.decode.max_seq_lens += 1 @@ -785,7 +897,7 @@ def _prepare_input_kernel(self, out_ptr: torch.Tensor, def prepare_next_token_ids_cpu( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, num_scheduled_tokens: dict[str, int], @@ -800,7 +912,7 @@ def prepare_next_token_ids_cpu( req_ids = gpu_input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. next_token_id = token_ids[-1] else: @@ -811,7 +923,7 @@ def prepare_next_token_ids_cpu( seq_len = req_state.num_computed_tokens + num_scheduled_tokens[ req_id] next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) + next_token_ids.append(next_token_id.item()) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.input_ids.device) @@ -915,6 +1027,9 @@ def prepare_inputs_padded( total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward + # so we do not need to fix them. But if they are used in the future, + # we should fix them.
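Because the sampled token ids now arrive as a list of numpy arrays, an empty row means the request produced no accepted token this step, and the code falls back to the token the request already holds at its current sequence length. A self-contained illustration of that control flow, with a stand-in for get_token_id (the real CachedRequestState API is assumed, not shown here):

    import numpy as np

    # Stand-in for req_state.get_token_id(seq_len); illustrative only.
    committed_tokens = [101, 102, 103, 104]

    def get_token_id(seq_len: int) -> int:
        return committed_tokens[seq_len - 1]

    sampled_rows = [np.array([7, 8]), np.array([], dtype=np.int64)]
    next_token_ids = []
    for row in sampled_rows:
        if row.shape[0] > 0:
            # Common case: take the last sampled token of this step.
            next_token_ids.append(int(row[-1]))
        else:
            # No sampled token for this request: read the token it already
            # has at its current sequence length (here seq_len=3).
            next_token_ids.append(get_token_id(3))

    print(next_token_ids)  # [8, 103]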
spec_common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, query_start_loc_cpu=query_start_loc_cpu, diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py index 932a127cf01..065d290fa44 100644 --- a/vllm_ascend/spec_decode/ngram_proposer.py +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -1,3 +1,4 @@ +import numpy as np import torch from vllm.config import CUDAGraphMode from vllm.v1.spec_decode.ngram_proposer import \ @@ -30,7 +31,7 @@ def dummy_run(self, pass def generate_token_ids(self, - valid_sampled_token_ids, + valid_sampled_token_ids: list[np.ndarray], sampling_metadata=None, scheduler_output=None, spec_decode_metadata=None, @@ -41,7 +42,7 @@ def generate_token_ids(self, aux_hidden_states=None) -> list[list[int]]: valid_ngram_requests = [] for i, sampled_ids in enumerate(valid_sampled_token_ids): - num_sampled_ids = len(sampled_ids) + num_sampled_ids = sampled_ids.shape[0] if not num_sampled_ids: continue diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py index a61abbdcdbe..b7128c40105 100644 --- a/vllm_ascend/torchair/models/qwen2.py +++ b/vllm_ascend/torchair/models/qwen2.py @@ -248,7 +248,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -319,8 +319,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py index 3ea3a56f061..e6a5ad543e6 100644 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -23,7 +23,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, CompilationMode, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) @@ -55,12 +55,6 @@ from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, init_metadata_for_sp) from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.config import CompilationLevel -else: - from vllm.config import CompilationMode class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -299,16 +293,10 @@ def __init__( layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) - if vllm_version_is("0.11.0"): - self.use_aclgraph = (vllm_config is not None - and vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not vllm_config.model_config.enforce_eager) - else: - self.use_aclgraph = (vllm_config is not None - and vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE 
and - not vllm_config.model_config.enforce_eager) + self.use_aclgraph = (vllm_config is not None + and vllm_config.compilation_config.mode + == CompilationMode.VLLM_COMPILE + and not vllm_config.model_config.enforce_eager) if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): @@ -438,7 +426,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index f67a0ff09c0..c153a86c1e1 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -32,6 +32,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import AttentionMetadata +from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -74,12 +75,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ TorchairAscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor, oproj_tp_enable, vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.attention import Attention -else: - from vllm.attention.layer import MLAAttention +from vllm_ascend.utils import dispose_tensor, oproj_tp_enable class Indexer(nn.Module): @@ -616,67 +612,31 @@ def __init__( # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) # i.e. 
# kv_lora_rank + qk_rope_head_dim == head_size - if vllm_version_is("0.11.0"): - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - use_sparse=False, - indexer=None, - # SFA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj - if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj - if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - decoder_layer=decoder_layer, - ) - else: - self.mla_attn = MLAAttention( - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=False, - indexer=None, - # MLA Args - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj - if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj - if self.q_lora_rank is None else self.q_b_proj, - q_b_proj=self.q_b_proj - if self.q_lora_rank is not None else None, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) + self.mla_attn = MLAAttention( + num_heads=self.num_local_heads, + scale=self.scaling, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_sparse=False, + indexer=None, + # MLA Args + rotary_emb=self.rotary_emb, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) def forward( self, @@ -882,66 +842,30 @@ def __init__( index_topk=self.index_topk, prefix=f"{prefix}.indexer", ) - - if vllm_version_is("0.11.0"): - self.sfa_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - use_sparse=True, - indexer=self.indexer, - # SFA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj - if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - 
if self.q_lora_rank is not None else None, - q_proj=self.q_proj - if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - decoder_layer=decoder_layer, - ) - else: - self.sfa_attn = MLAAttention( - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=True, - indexer=self.indexer, - # MLA Args - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj - if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj - if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) + self.sfa_attn = MLAAttention( + num_heads=self.num_local_heads, + scale=self.scaling, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_sparse=True, + indexer=self.indexer, + # MLA Args + rotary_emb=self.rotary_emb, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) def forward( self, @@ -1235,7 +1159,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1251,7 +1175,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py index 7a0c9c0696b..d81941ff56b 100644 --- a/vllm_ascend/torchair/models/torchair_pangu_moe.py +++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py @@ -57,7 +57,8 @@ from vllm.v1.sample.sampler import Sampler from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, + get_ascend_device_type) _ROUTER_SCALE = None @@ -448,7 +449,8 @@ def __init__( # on 300I Duo platform, we find that num_voted_experts set to 5 achieves # good performance without sacrifice too much accuracy. for other platform, # this is set to 8 to use original pangu grouped topk. 
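The platform-dependent expert-voting choice described in the comment above becomes a plain enum comparison once is_310p() is replaced by the device-type query; a hedged sketch using a stand-in enum rather than the real AscendDeviceType:

    import enum

    class DeviceType(enum.Enum):  # stand-in for AscendDeviceType
        _310P = enum.auto()
        _910B = enum.auto()

    def pick_num_voted_experts(device: DeviceType) -> int:
        # 300I Duo (310P) trades a little accuracy for speed with 5 voted
        # experts; other platforms keep the original grouped top-k of 8.
        return 5 if device is DeviceType._310P else 8

    assert pick_num_voted_experts(DeviceType._310P) == 5
    assert pick_num_voted_experts(DeviceType._910B) == 8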
- num_voted_experts = 5 if is_310p() else 8 + num_voted_experts = 5 if get_ascend_device_type( + ) == AscendDeviceType._310P else 8 self.experts = FusedMoE( num_experts=config.num_experts, @@ -808,7 +810,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -824,7 +826,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -916,8 +918,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, @@ -1109,7 +1111,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) - if is_310p() and "head" in name: + if get_ascend_device_type( + ) == AscendDeviceType._310P and "head" in name: # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented # by linear, we manually cast the format here. diff --git a/vllm_ascend/torchair/ops/torchair_activation.py b/vllm_ascend/torchair/ops/torchair_activation.py index 0721ea0a7f2..0089b663253 100644 --- a/vllm_ascend/torchair/ops/torchair_activation.py +++ b/vllm_ascend/torchair/ops/torchair_activation.py @@ -28,9 +28,9 @@ def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: out = torch_npu.npu_swiglu(x) diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 8cb2ff9e62f..4408b310c70 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -43,8 +43,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, - determine_default_log2phy_map) +from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding @@ -52,10 +51,9 @@ get_rm_router_logits_state, npu_stream_switch, npu_wait_tensor, super_kernel) -from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, - get_ascend_soc_version, is_310p, - is_hierarchical_communication_enabled, - vllm_version_is) +from vllm_ascend.utils import 
(AscendDeviceType, dispose_tensor, + get_ascend_device_type, + is_hierarchical_communication_enabled) def torchair_fused_experts_with_mc2( @@ -77,11 +75,11 @@ def torchair_fused_experts_with_mc2( ep_world_size = moe_parallel_config.ep_size # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93 or is_torchair) # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93 # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. @@ -469,7 +467,7 @@ def torchair_fused_experts_moge( group_list=group_list, )[0] - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( torch.float16) else: @@ -1042,7 +1040,7 @@ def __init__( self.expert_map_path) and os.access(self.expert_map_path, os.R_OK): self.expert_load_balancer = ExpertLoadBalancer( - self.expert_map_path, self.global_num_experts) + self.expert_map_path, num_experts) self.expert_load_balancer.check_expert_map_tensor() self.global_redundant_expert_num = ( self.expert_load_balancer.get_global_redundant_expert_num()) @@ -1052,15 +1050,14 @@ def __init__( self.moe_instance_id, self.ep_rank)) self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( self.moe_instance_id, self.ep_rank).npu() + self.global_num_experts = num_experts + self.global_redundant_expert_num except Exception as e: logger.warning( f"Init expert map of mtp/eagle when using sample.{e}") - self.local_num_experts, self.expert_map = determine_default_expert_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num) + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num).npu() + self.global_num_experts, self.ep_size, self.ep_rank).npu() if self.expert_map is not None and isinstance( self.expert_map, torch.Tensor): logger.info_once( @@ -1071,21 +1068,12 @@ def __init__( get_compressed_expert_map(self.expert_map)) else: # init moe. 
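For the expert-parallel initialization above, the expert map is conceptually a table from global expert id to the local slot on this EP rank, or -1 when the expert lives on another rank. The sketch below is a simplified illustration of that idea with a contiguous split, not the determine_expert_map implementation:

    import torch

    def toy_expert_map(ep_size: int, ep_rank: int, num_experts: int) -> torch.Tensor:
        # global expert id -> local index on this rank, or -1 if remote.
        per_rank = num_experts // ep_size
        emap = torch.full((num_experts, ), -1, dtype=torch.int32)
        start = ep_rank * per_rank
        emap[start:start + per_rank] = torch.arange(per_rank, dtype=torch.int32)
        return emap

    # 8 experts over 4 EP ranks; rank 1 owns global experts 2 and 3.
    print(toy_expert_map(4, 1, 8).tolist())  # [-1, -1, 0, 1, -1, -1, -1, -1]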
- if vllm_version_is("0.11.0"): - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) - else: - self.local_num_experts, self.expert_map, _ = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) + self.local_num_experts, self.expert_map, _ = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) # dynamic eplb initializing with not expert_map_path if self.dynamic_eplb: - self.global_redundant_expert_num = ascend_config.init_redundancy_expert - self.local_num_experts, self.expert_map = determine_default_expert_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num) self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank, - self.global_redundant_expert_num).npu() + self.global_num_experts, self.ep_size, self.ep_rank).npu() if self.expert_map is not None and isinstance( self.expert_map, torch.Tensor): logger.info_once( diff --git a/vllm_ascend/torchair/ops/torchair_layernorm.py b/vllm_ascend/torchair/ops/torchair_layernorm.py index 583a376b801..3a3146b8dd5 100644 --- a/vllm_ascend/torchair/ops/torchair_layernorm.py +++ b/vllm_ascend/torchair/ops/torchair_layernorm.py @@ -57,9 +57,9 @@ def torchair_rmsnorm_forward_oot( import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type if residual is not None: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: orig_dtype = residual.dtype x = x + residual.to(x.dtype) residual = x.to(orig_dtype) diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py index e64bd6f64b4..9fdb231b687 100644 --- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py +++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py @@ -25,7 +25,8 @@ DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import enable_custom_op, is_310p +from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, + get_ascend_device_type) def custom_rotary_embedding_enabled(query, neox_style, head_size): @@ -60,8 +61,9 @@ def rope_forward_oot( if is_neox_style_override is not None: neox_style = is_neox_style_override # adopt custom kernel path for rotary_embedding - if custom_rotary_embedding_enabled(query, neox_style, - self.head_size) and not is_310p(): + if custom_rotary_embedding_enabled( + query, neox_style, self.head_size) and get_ascend_device_type( + ) != AscendDeviceType._310P: query, key = torch.ops._C_ascend.rotary_embedding( positions, query, diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index bc0a8d35783..8909bb790ef 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -28,8 +28,8 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor, super_kernel) -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, - dispose_tensor, get_ascend_soc_version, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, + dispose_tensor, get_ascend_device_type, is_enable_nz, is_hierarchical_communication_enabled) @@ -234,11 +234,11 @@ def torchair_fused_experts_with_mc2( 
ep_world_size = ep_group.world_size # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93 or is_torchair) # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93 # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. @@ -990,7 +990,9 @@ def apply( # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + topk_ids = torch.randint_like( + topk_ids, 0, + global_num_experts - global_redundant_expert_num) topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.AllGatherEP: diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index a524a3bb4e0..16fcb385c8d 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -26,13 +26,7 @@ AttentionType) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv +from vllm.utils.math_utils import cdiv from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, AscendAttentionMetadataBuilder, @@ -40,8 +34,8 @@ AscendMetadata) from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, - nd_to_nz_2d) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, + aligned_16, get_ascend_device_type, nd_to_nz_2d) class AscendAttentionTorchairBackend(AscendAttentionBackend): @@ -55,10 +49,6 @@ def get_name() -> str: def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]: return AscendAttentionTorchairBackendImpl - @staticmethod - def get_metadata_cls() -> Type["AscendTorchairMetadata"]: - return AscendTorchairMetadata - @staticmethod def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]: return AscendAttentionTorchairMetadataBuilder @@ -195,7 +185,8 @@ def build( attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state - if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: + if get_ascend_device_type( + ) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache: mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) @@ -391,7 +382,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: # align q k v output tensors query = aligned_16(query) key = aligned_16(key) diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 116b124e410..74359efe4d0 100644 --- 
a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -6,20 +6,13 @@ import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv, round_down -else: - from vllm.utils.math_utils import cdiv, round_down +from vllm.utils.math_utils import cdiv, round_down import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -43,10 +36,6 @@ class AscendMLATorchairBackend(AttentionBackend): def get_name() -> str: return "ASCEND_MLA_TORCHAIR" - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AscendMLATorchairMetadata - @staticmethod def get_builder_cls(): return AscendMLATorchairMetadataBuilder @@ -501,6 +490,11 @@ def build( num_reqs_pad_size = ( graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs) + # Handle the case where some requests reach the max-tokens limit in this forward pass, + # so the number of newly scheduled tokens is less than decode_token_per_req (1 + spec_token_num). + # For details, see PR: https://github.com/vllm-project/vllm/pull/27922 + num_reqs_pad_size = max(0, num_reqs_pad_size) + padded_seq_lens = seq_lens.tolist( ) + [pad_value] * num_reqs_pad_size else: diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 792972f0a6a..d7c55c6e7df 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -42,8 +42,7 @@ register_torchair_model, torchair_ops_patch, torchair_quant_method_register, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - is_310p, get_ascend_soc_version, - AscendSocVersion) + AscendDeviceType, get_ascend_device_type) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -125,13 +124,13 @@ def _init_mc2_tokens_capacity(self): max_num_tokens, tp_size) self.mc2_tokens_capacity = max_graph_batch_size - if get_ascend_soc_version( - ) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512: + if get_ascend_device_type( + ) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512: logger.error( f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}" ) - if get_ascend_soc_version( - ) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256: + if get_ascend_device_type( + ) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256: logger.error( f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}" ) @@ -207,7 +206,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds): if with_prefill or self.enable_shared_expert_dp: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) hidden_states = super()._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, @@ -230,7 +229,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill, assert 
isinstance(kv, tuple), "kv_cache must be a tuple" torch._dynamo.mark_static(kv[0]) torch._dynamo.mark_static(kv[1]) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) compiled_model = self._get_torchair_lazy_compiled_model(num_tokens) @@ -371,7 +370,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, "attn_metadata": attn_metadata } if not with_prefill: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) compiled_model = self._get_torchair_lazy_compiled_model( padded_num_tokens_across_dp) @@ -384,7 +383,7 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, ) else: assert self.model is not None - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) hidden_states = self.model( @@ -414,7 +413,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): patch_for_hcom() - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: # on 300I Duo platform, we need to patch broadcast. however, this patch will be # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. from vllm_ascend.patch.platform.patch_distributed import \ @@ -428,7 +427,8 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): self.ascend_config.torchair_graph_config.enable_frozen_parameter # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to # disable it on 300I Duo platform now. - config.experimental_config.tiling_schedule_optimize = not is_310p() + config.experimental_config.tiling_schedule_optimize = get_ascend_device_type( + ) != AscendDeviceType._310P config.experimental_config.enable_view_optimize = \ self.ascend_config.torchair_graph_config.enable_view_optimize torch.npu.set_compile_mode(jit_compile=False) @@ -531,8 +531,8 @@ def update_torchair_graph_batch_sizes(self): # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size` # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same # on all EP ranks - if get_ascend_soc_version( - ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel: + if get_ascend_device_type( + ) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel: self._align_graph_size_divisible_by_tp_size() def _align_graph_size_divisible_by_tp_size(self): diff --git a/vllm_ascend/torchair/torchair_mtp_proposer.py b/vllm_ascend/torchair/torchair_mtp_proposer.py index c26b8dd4013..476ff479966 100644 --- a/vllm_ascend/torchair/torchair_mtp_proposer.py +++ b/vllm_ascend/torchair/torchair_mtp_proposer.py @@ -1,5 +1,6 @@ import types +import numpy as np import torch import torch.nn as nn import torchair @@ -11,6 +12,7 @@ from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import \ process_weights_after_loading +from vllm.utils.torch_utils import set_default_torch_dtype from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -23,13 +25,7 @@ TorchairDeepSeekMTP from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata) -from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, - 
vllm_version_is) - -if vllm_version_is("0.11.0"): - from vllm.model_executor.model_loader.utils import set_default_torch_dtype -else: - from vllm.utils.torch_utils import set_default_torch_dtype +from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable PADDING_SLOT_ID = -1 @@ -86,8 +82,7 @@ def dummy_run(self, num_tokens_across_dp=None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None) -> None: - moe_comm_type = self.runner._select_moe_comm_method( - num_tokens, with_prefill) + moe_comm_type = self.runner._select_moe_comm_method(num_tokens) if not with_prefill: skip_attn = False @@ -152,7 +147,7 @@ def dummy_run(self, break def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], + valid_sampled_token_ids: list[np.ndarray], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, @@ -165,7 +160,7 @@ def generate_token_ids(self, attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. next_token_id = token_ids[-1] else: @@ -176,7 +171,7 @@ def generate_token_ids(self, seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) + next_token_ids.append(next_token_id.item()) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) @@ -192,7 +187,7 @@ def generate_token_ids(self, # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + n + 1 - valid_sampled_token_ids[i].shape[0] if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] num_rejected_tokens = torch.tensor( @@ -347,8 +342,7 @@ def _propose_torchair( num_tokens_across_dp = self.runner.num_tokens_across_dp with_prefill = self.runner.with_prefill - moe_comm_type = self.runner._select_moe_comm_method( - num_input_tokens, with_prefill) + moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py index 12b8d07a35d..fdaab404b8c 100644 --- a/vllm_ascend/torchair/torchair_sfa.py +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -6,21 +6,13 @@ import torch.nn as nn import torch.nn.functional as F import torch_npu -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv, round_down -else: - from vllm.utils.math_utils import cdiv, round_down +from vllm.utils.math_utils import cdiv, round_down import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -43,10 +35,6 @@ class AscendSFATorchairBackend(AttentionBackend): def get_name() 
-> str: return "ASCEND_SFA_TORCHAIR" - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AscendSFATorchairMetadata - @staticmethod def get_builder_cls(): return AscendSFATorchairMetadataBuilder diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 5cccdaf0688..0a74bcbfdcf 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -41,6 +41,7 @@ VllmConfig = None ASCEND_QUANTIZATION_METHOD = "ascend" +COMPRESSED_TENSORS_METHOD = "compressed-tensors" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] REGISTERED_ASCEND_OPS = {} @@ -48,7 +49,6 @@ ACL_FORMAT_FRACTAL_NZ = 29 _CUSTOM_OP_ENABLED = None -_IS_310P = None _SLEEP_MODE_ENABLED = None _CURRENT_STREAM = None _PREFETCH_STREAM = None @@ -57,9 +57,9 @@ _DEFAULT_BUFFER_SIZE = 200 _MIN_DP_BUFFER_SIZE = 50 _IS_MOE_MODEL = None +_IS_VL_MODEL = None _ENABLE_SP = None _HAS_LAYER_IDX = None -_ENABLE_NZ = None _SUBSCRIBED_COMPUTE_STREAMS = set() _GRAPH_PRINT_STREAM = None _GRAPH_PRINT_STREAM_LOCK = Lock() @@ -121,22 +121,8 @@ def _unregister_print_streams_on_exit(): atexit.register(_unregister_print_streams_on_exit) -def is_310p(): - global _IS_310P - if _IS_310P is None: - from vllm_ascend import _build_info # type: ignore - _IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p") - return _IS_310P - - -def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool: - global _ENABLE_NZ - if _ENABLE_NZ is None: - if not vllm_config: - raise ValueError( - "vllm_config must be provided when _ENABLE_NZ is None") - _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next" - return _ENABLE_NZ +def is_enable_nz(): + return envs_ascend.VLLM_ASCEND_ENABLE_NZ def sleep_mode_enabled(): @@ -413,6 +399,61 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig, vllm_config.compilation_config.post_init_cudagraph_sizes() +def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool: + """ + Check whether it is vLLM default capture sizes. + """ + + max_cudagraph_capture_size = \ + vllm_config.compilation_config.max_cudagraph_capture_size + cudagraph_capture_sizes = [ + i for i in [1, 2, 4] if i <= max_cudagraph_capture_size + ] + if max_cudagraph_capture_size >= 8: + # Step size 8 for small batch sizes, up to 256(not included) + cudagraph_capture_sizes += list( + range(8, min(max_cudagraph_capture_size + 1, 256), 8)) + if max_cudagraph_capture_size >= 256: + # Step size 16 for larger batch sizes + cudagraph_capture_sizes += list( + range(256, max_cudagraph_capture_size + 1, 16)) + # in newer version, vLLM use ascending order of cudagraph_capture_sizes. + target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes) + if target_cudagraph_capture_sizes == \ + vllm_config.compilation_config.cudagraph_capture_sizes: + return True + + return False + + +def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: + """ + Update ACL graph default capture sizes, so that new sizes + are more friendly to ascend ops && hardware. + """ + + if vllm_config.model_config is None or \ + vllm_config.model_config.enforce_eager or \ + not _is_default_capture_sizes(vllm_config): + return + + # modify the default capture_sizes for Qwen3-MoE models on dp settings. + # this is mainly because performance of _npu_paged_attention might degrades + # on special shapes. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. 
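The default-size check above reconstructs vLLM's capture-size schedule: 1/2/4, then every 8 up to (but not including) 256, then every 16 up to the maximum, sorted ascending. A standalone restatement of that schedule, useful as a quick sanity check:

    def default_capture_sizes(max_size: int) -> list[int]:
        sizes = [i for i in (1, 2, 4) if i <= max_size]
        if max_size >= 8:
            # step 8 for small batch sizes, up to 256 (not included)
            sizes += list(range(8, min(max_size + 1, 256), 8))
        if max_size >= 256:
            # step 16 for larger batch sizes
            sizes += list(range(256, max_size + 1, 16))
        return sorted(sizes)

    print(default_capture_sizes(40))  # [1, 2, 4, 8, 16, 24, 32, 40]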
+ if vllm_config.model_config and vllm_config.model_config.hf_config.model_type == "qwen3_moe" \ + and vllm_config.parallel_config.tensor_parallel_size == 1 \ + and vllm_config.parallel_config.data_parallel_size > 1 : + + max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size + new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [ + i for i in range(24, max_capture_size + 1, 8) + ] + update_cudagraph_capture_sizes(vllm_config, + new_cudagraph_capture_sizes) + + def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: """Update ACL graph capture sizes based on hardware limitations""" # NOTE: Currently, we can only capture 1800 graphs at most, @@ -504,10 +545,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: indices[0], indices[-1] = 0, len(original_sizes) - 1 sampled_sizes = [original_sizes[i] for i in indices] - if vllm_version_is("0.11.0"): - compilation_config.init_with_cudagraph_sizes(sampled_sizes) - else: - update_cudagraph_capture_sizes(vllm_config, sampled_sizes) + update_cudagraph_capture_sizes(vllm_config, sampled_sizes) logger.info( "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes", @@ -538,10 +576,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs: enlarged_sizes = [(num_speculative_tokens + 1) * size for size in original_sizes] - if vllm_version_is("0.11.0"): - compilation_config.init_with_cudagraph_sizes(enlarged_sizes) - else: - update_cudagraph_capture_sizes(vllm_config, enlarged_sizes) + update_cudagraph_capture_sizes(vllm_config, enlarged_sizes) logger.info( "Adjusted ACL graphs: %s → %s for speculative decoding", original_sizes, enlarged_sizes) @@ -650,11 +685,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): "GemmaRMSNorm": AscendGemmaRMSNorm, "FusedMoE": AscendFusedMoE, "SharedFusedMoE": AscendSharedFusedMoE, + "MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention, } - mla_to_register = "MultiHeadLatentAttention" if vllm_version_is( - "0.11.0") else "MultiHeadLatentAttentionWrapper" - if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla: - REGISTERED_ASCEND_OPS[mla_to_register] = AscendMultiHeadLatentAttention for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) @@ -663,32 +695,47 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): _ASCEND_CUSTOMOP_IS_REIGISTERED = True -# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN. -# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does. 
-class AscendSocVersion(Enum): - A2 = 0 - A3 = 1 - UNDEFINED = 2 +class AscendDeviceType(Enum): + _910B = 0 # A2 + _910_93 = 1 # A3 + _310P = 2 + _910_95 = 3 # A5 + +_ascend_device_type = None -_ascend_soc_version = None +def _init_ascend_device_type(): + global _ascend_device_type + from vllm_ascend import _build_info # type: ignore + _ascend_device_type = AscendDeviceType[_build_info.__device_type__] + + +def check_ascend_device_type(): + global _ascend_device_type + if _ascend_device_type is None: + _init_ascend_device_type() -def init_ascend_soc_version(): soc_version = torch_npu.npu.get_soc_version() - global _ascend_soc_version if 220 <= soc_version <= 225: - _ascend_soc_version = AscendSocVersion.A2 + cur_device_type = AscendDeviceType._910B elif 250 <= soc_version <= 255: - _ascend_soc_version = AscendSocVersion.A3 + cur_device_type = AscendDeviceType._910_93 + elif 200 <= soc_version <= 205: + cur_device_type = AscendDeviceType._310P + elif soc_version == 260: + cur_device_type = AscendDeviceType._910_95 else: - _ascend_soc_version = AscendSocVersion.UNDEFINED + raise RuntimeError(f"Can not support soc_version: {soc_version}.") + + assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package." -def get_ascend_soc_version(): - global _ascend_soc_version - assert _ascend_soc_version is not None - return _ascend_soc_version +def get_ascend_device_type(): + global _ascend_device_type + if _ascend_device_type is None: + _init_ascend_device_type() + return _ascend_device_type def lmhead_tp_enable() -> bool: @@ -767,6 +814,15 @@ def _is_contain_expert(config: Any): return False +def is_vl_model(vllm_config: VllmConfig): + """Checks if the model is a VL model by config""" + global _IS_VL_MODEL + if _IS_VL_MODEL is None and vllm_config.model_config: + model_configs = vllm_config.model_config.hf_config.to_dict() + _IS_VL_MODEL = "VL" in model_configs["architectures"][0] + return _IS_VL_MODEL + + def weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. 
diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index da0cb543267..3317a2379ed 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -3,13 +3,7 @@ import numpy as np import torch from vllm.distributed import get_dcp_group - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv +from vllm.utils.math_utils import cdiv from vllm_ascend.utils import prefill_context_parallel_enable @@ -27,13 +21,29 @@ def __init__(self, pin_memory: bool, device: torch.device, kernel_sizes: Union[list[int], None] = None, - cp_kv_cache_interleave_size: int = 1): + cp_kv_cache_interleave_size: int = 1, + num_speculative_tokens: int = 0): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device self.physical_block_size = block_size + + try: + self.pcp_world_size = get_pcp_group( + ).world_size if prefill_context_parallel_enable() else 1 + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_world_size > 1 else 0 + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.pcp_world_size = 1 + self.pcp_rank = 0 + # If kernel_sizes is None or [0], use physical block size (no splitting) if kernel_sizes is None or kernel_sizes == [0]: self.block_size = block_size @@ -69,13 +79,16 @@ def __init__(self, else: logical_table_size = max_num_blocks_per_req + duplicate_size = 1 + if self.pcp_world_size > 1: + duplicate_size += num_speculative_tokens self.block_table = torch.zeros( - (max_num_reqs, logical_table_size), + (max_num_reqs * duplicate_size, logical_table_size), device=self.device, dtype=torch.int32, ) self.block_table_cpu = torch.zeros( - (max_num_reqs, logical_table_size), + (max_num_reqs * duplicate_size, logical_table_size), device="cpu", dtype=torch.int32, pin_memory=pin_memory, @@ -83,20 +96,6 @@ def __init__(self, self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - try: - self.pcp_world_size = get_pcp_group( - ).world_size if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_world_size > 1 else 0 - self.dcp_world_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group - except AssertionError: - # DCP might not be initialized in testing - self.dcp_world_size = 1 - self.dcp_rank = 0 - self.pcp_world_size = 1 - self.pcp_rank = 0 - self.slot_mapping_cpu = torch.zeros( self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, @@ -306,7 +305,7 @@ def __init__(self, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, pin_memory, device, kernel_size_list, - cp_kv_cache_interleave_size) + cp_kv_cache_interleave_size, num_speculative_tokens) for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) ] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3c9fc126efc..2e7c4ea299b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -41,10 +41,11 @@ from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend from 
vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layer import Attention +from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config +from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -58,8 +59,6 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model -# yapf conflicts with isort for this block -# yapf: disable from vllm.model_executor.models.interfaces import (SupportsMultiModal, supports_mrope, supports_transcription) @@ -73,29 +72,23 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import length_from_prompt_token_ids_or_embeds - -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.utils import cdiv -else: - from vllm.utils.math_utils import cdiv - +from vllm.utils.import_utils import LazyLoader from vllm.utils.jsontree import json_map_leaves +from vllm.utils.math_utils import cdiv +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -# yapf conflicts with isort for this block -# yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) -# yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors, ModelRunnerOutput, PoolerOutput) @@ -119,6 +112,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, AscendPrefillContextParallelMetadata) +# yapf conflicts with isort for this block # yapf: disable from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, @@ -144,11 +138,10 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - AscendSocVersion, ProfileExecuteDuration, - enable_sp, get_ascend_soc_version, is_310p, - is_enable_nz, is_moe_model, lmhead_tp_enable, - prefill_context_parallel_enable, - vllm_version_is) + AscendDeviceType, ProfileExecuteDuration, + enable_sp, get_ascend_device_type, is_enable_nz, + is_moe_model, lmhead_tp_enable, + prefill_context_parallel_enable) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch if prefill_context_parallel_enable(): @@ -157,30 +150,9 @@ get_prefill_context_model_parallel_rank, get_prefill_context_model_parallel_world_size) -if 
vllm_version_is("0.11.0"): - from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - get_dtype_size) -else: - from vllm.utils.mem_utils import DeviceMemoryProfiler - from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size - -# yapf: enable - -if vllm_version_is("0.11.0"): - from vllm.attention.layer import Attention - from vllm.config import CompilationLevel - from vllm.utils import LazyLoader, is_pin_memory_available - - from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention -else: - from vllm.attention.layer import MLAAttention - from vllm.config import CompilationMode - from vllm.utils.import_utils import LazyLoader - from vllm.utils.platform_utils import is_pin_memory_available - if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -189,7 +161,7 @@ # if true, allow tensor initialization and casting with internal format (e.g., NZ) torch.npu.config.allow_internal_format = True -if is_310p(): +if get_ascend_device_type() == AscendDeviceType._310P: torch_npu.npu.set_compile_mode(jit_compile=False) ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ else: @@ -271,15 +243,32 @@ def get_output(self) -> ModelRunnerOutput: # Release the device tensor once the copy has completed del self._sampled_token_ids - valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in self._sampled_token_ids_cpu.numpy() + ] for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = np.array([]) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids return output +class ExecuteModelState(NamedTuple): + """Ephemeral cached state transferred between execute_model() and + sample_tokens(), after execute_model() returns None.""" + + scheduler_output: "SchedulerOutput" + logits: torch.Tensor + spec_decode_metadata: SpecDecodeMetadata | None + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + kv_connector_output: KVConnectorOutput | None + attn_metadata: dict[str, Any] + positions: torch.Tensor + + class NPUModelRunner(LoRAModelRunnerMixin): def __init__(self, vllm_config: VllmConfig, device: torch.device): @@ -294,21 +283,23 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config self.block_size = vllm_config.cache_config.block_size - self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, - self.block_size) - self.max_num_tokens = self.scheduler_config.max_num_batched_tokens - decode_max_num_seqs = getattr(self.scheduler_config, - 'decode_max_num_seqs', 0) - self.max_num_reqs = max(self.scheduler_config.max_num_seqs, - decode_max_num_seqs) self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group self.pcp_size = get_prefill_context_model_parallel_world_size( ) if prefill_context_parallel_enable() else 1 self.pcp_rank = get_prefill_context_model_parallel_rank( ) if self.pcp_size > 1 else 0 - self.dcp_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group + 
decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.max_num_reqs = max(self.scheduler_config.max_num_seqs, + decode_max_num_seqs) + if self.pcp_size > 1: + self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs + self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, + self.block_size) + self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.device = device if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) @@ -337,6 +328,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.intermediate_tensors: Optional[IntermediateTensors] = None self.runner_only_attn_layers: set[str] = set() + # Ascend-specific configurations self.ascend_config = get_ascend_config() if self.ascend_config.ascend_scheduler_config.enabled: self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled @@ -344,6 +336,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.chunked_prefill_enabled = True self.weight_prefetch_method = WeightPrefetchMethod( self.ascend_config.weight_prefetch_config) + # Dump / PrecisionDebugger configuration now comes from AscendConfig + dump_cfg = self.ascend_config.dump_config + self.dump_enable = dump_cfg.enable_dump + self.debugger = None + if self.dump_enable: + if self.model_config.enforce_eager: + from msprobe.pytorch import PrecisionDebugger + self.debugger = PrecisionDebugger(dump_cfg.config_path) + else: + raise RuntimeError( + "Dumping/debugging only works in eager mode.") if self.cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype @@ -521,8 +524,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): if self.speculative_config else 0) self.use_aclgraph = self._use_aclgraph() - self.aclgraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + + # self.aclgraph_batch_sizes sorts in ascending order. + if (self.compilation_config.cudagraph_capture_sizes and + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + self.aclgraph_batch_sizes = sorted( + self.compilation_config.cudagraph_capture_sizes) self.uniform_decode_query_len = 1 if not self.speculative_config else \ 1 + self.speculative_config.num_speculative_tokens @@ -592,6 +599,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), kernel_block_sizes=[[self.vllm_config.cache_config.block_size]], cp_kv_cache_interleave_size=self.parallel_config. cp_kv_cache_interleave_size @@ -611,6 +621,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # TODO: EVS Support (Video tokens pruning) (see vllm#22980) self.is_multimodal_pruning_enabled = False + # Ephemeral state transferred between execute_model() and sample_tokens(). + self.execute_model_state: ExecuteModelState | None = None + + self.transfer_event = torch.npu.Event() + def _set_up_drafter(self): # Set up speculative decoding. 
self.spec_attn_mask = None @@ -628,11 +643,7 @@ def _set_up_drafter(self): diagonal=1).to(self.device) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() - if vllm_version_is("0.11.0"): - self.rejection_sampler = AscendRejectionSampler() - else: - self.rejection_sampler = AscendRejectionSampler( - self.sampler) + self.rejection_sampler = AscendRejectionSampler(self.sampler) self.actual_seq_lengths_q = list( range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req)) @@ -655,7 +666,7 @@ def _init_mc2_tokens_capacity(self): # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes, # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512). if self.compilation_config.cudagraph_capture_sizes: - max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0] + max_num_tokens = self.compilation_config.max_cudagraph_capture_size else: # NOTE: To save memory, we cap the max number of tokens to 512. max_num_tokens = min( @@ -704,10 +715,7 @@ def _update_states_after_model_execute( self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _use_aclgraph(self) -> bool: - if vllm_version_is("0.11.0"): - return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager - else: - return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager + return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. 
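The ExecuteModelState tuple introduced above underpins the new two-phase flow in this patch: execute_model() stashes the logits and related tensors and returns None, and sample_tokens() (added further down, together with the worker-side wrapper) consumes the cached state once the grammar output is available. A reduced sketch of just that handoff, with the forward pass and sampling elided; all names other than ExecuteModelState, execute_model and sample_tokens are placeholders:

from typing import Any, NamedTuple, Optional


class ExecuteModelState(NamedTuple):
    scheduler_output: Any
    logits: Any


class RunnerHandoffSketch:
    def __init__(self) -> None:
        self.execute_model_state: Optional[ExecuteModelState] = None

    def execute_model(self, scheduler_output: Any) -> None:
        if self.execute_model_state is not None:
            raise RuntimeError("State error: sample_tokens() must be called "
                               "after execute_model() returns None.")
        logits = object()  # stands in for the forward pass + compute_logits
        self.execute_model_state = ExecuteModelState(scheduler_output, logits)
        return None

    def sample_tokens(self, grammar_output: Any) -> Any:
        if self.execute_model_state is None:
            # Nothing to do (e.g. PP non-final rank); output isn't used.
            return None
        scheduler_output, logits = self.execute_model_state
        self.execute_model_state = None
        # Structured-output bitmasks would be applied to `logits` here,
        # using scheduler_output and grammar_output, before sampling.
        return logits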
@@ -879,51 +887,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _init_mrope_positions(self, req_state: CachedRequestState): - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - assert req_state.mm_features is not None - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - if vllm_version_is("0.11.0"): - req_state.mrope_positions, req_state.mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - else: - if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ - self.model.get_mrope_input_positions( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + assert supports_mrope(self.model), "MROPE is not supported" + req_state.mrope_positions, req_state.mrope_position_delta = \ + self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + req_state.mm_features, + ) def _sync_metadata_across_dp( self, num_tokens: int, @@ -1000,10 +969,12 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor: + # pcp situation. if self.pcp_size > 1: return None if self.attn_mask_builder is None: raise ValueError("Attn mask builder is None") + # dcp situation. if self.dcp_size > 1: return self.attn_mask_builder.get_splitfuse_attn_mask() # Pooling situation. @@ -1011,12 +982,7 @@ def _make_attention_mask(self, seq_lens, position, return self.attn_mask_builder.get_pooling_mask(self.device) # Chunk Prefill situation. elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse: - if self.dcp_size > 1: - max_seq_len = max(seq_lens.max().item(), 0) - return self.attn_mask_builder.get_attn_mask( - max_seq_len, self.dtype, self.device) - else: - return self.attn_mask_builder.get_splitfuse_attn_mask() + return self.attn_mask_builder.get_splitfuse_attn_mask() # Prefill without cache situation. elif attn_state == AscendAttentionState.PrefillNoCache: @@ -1025,13 +991,16 @@ def _make_attention_mask(self, seq_lens, position, max_seq_len, self.dtype, self.device) # Prefill with cache hit. elif attn_state == AscendAttentionState.PrefillCacheHit: - return self.attn_mask_builder.get_attn_mask( - 2048, self.dtype, self.device) + return self.attn_mask_builder.get_splitfuse_attn_mask().to( + torch.bool) # Decode-only situation. 
else: return None def _make_fia_attention_mask(self) -> torch.Tensor: + # pcp situation. + if self.pcp_size > 1: + return None if self.attn_mask_builder is None: raise ValueError("Attn mask builder is None") return self.attn_mask_builder.get_splitfuse_attn_mask() @@ -1095,21 +1064,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( scheduler_output) encoder_outputs = [] - - if vllm_version_is("0.11.0"): - mm_inputs = group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ) - else: - model = cast(SupportsMultiModal, self.model) - mm_inputs = group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - ) + model = cast(SupportsMultiModal, self.model) + mm_inputs = group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) for modality, num_items, mm_kwargs_group in mm_inputs: # Run the encoder. # `curr_group_outputs` is either of the following: @@ -1118,8 +1079,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. - curr_group_outputs = self.model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1168,56 +1128,6 @@ def _batch_mm_kwargs_from_scheduler( return mm_kwargs, mm_hashes_pos - def _gather_mm_embeddings_0110( - self, - scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - - def _iter_mm_features(req_state: CachedRequestState): - assert req_state.mm_features is not None - for mm_feature in req_state.mm_features: - pos_info = mm_feature.mm_position - yield mm_feature.identifier, pos_info, getattr( - pos_info, "is_embed", None) - - mm_embeds: list[torch.Tensor] = [] - - for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - - for mm_hash, pos_info, is_embed in _iter_mm_features(req_state): - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length - - if start_pos >= num_computed_tokens + num_scheduled_tokens: - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - continue - - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) - assert start_idx < end_idx - - encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None, \ - f"Encoder cache miss for {mm_hash}." - - if is_embed is not None: - is_embed = is_embed[start_idx:end_idx] - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) - return mm_embeds - def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", @@ -1717,22 +1627,14 @@ def _prepare_inputs( # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. 
input_ids = self.input_ids[:total_num_scheduled_tokens] - if vllm_version_is("0.11.0"): - mm_embeds = self._gather_mm_embeddings_0110(scheduler_output) - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) - else: - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output) - inputs_embeds = self.model.get_input_embeddings( - input_ids, - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) + inputs_embeds = self.model.embed_input_ids( + input_ids, + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) # TODO(woosuk): Avoid the copy. Optimize. self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_( @@ -1758,7 +1660,7 @@ def _prepare_inputs( # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( + tokens_to_embeds = self.model.embed_input_ids( input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds @@ -1918,6 +1820,31 @@ def _prepare_inputs( prefill_context_parallel_metadata=long_seq_metadata, ) + if self.speculative_config and self.pcp_size > 1: + # For pcp + spec decode, we flatten block_table + # to avoid irregular spec_attn_mask shape, e.g., + # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, + # ori block_table: # [d0, d1, p0, p1, p2] + # (num_reqs_d + num_reqs_p, max_num_blocks), + # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] + # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), + ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ + self.query_start_loc_pcp_full_cpu[:num_reqs] + num_prefill_reqs = (ori_query_lens + > self.decode_threshold).sum().item() + num_decode_reqs = num_reqs - num_prefill_reqs + num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold + blk_table_tensor[ + num_decode_reqs_flatten:num_decode_reqs_flatten + + num_prefill_reqs].copy_( + blk_table_tensor[num_decode_reqs:num_decode_reqs + + num_prefill_reqs].clone()) + blk_table_tensor[:num_decode_reqs_flatten].copy_( + blk_table_tensor[:num_decode_reqs].repeat_interleave( + self.decode_threshold, dim=0)) + common_attn_metadata.block_table_tensor = \ + blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] + if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: self.spec_decode_common_attn_metadata = common_attn_metadata @@ -2113,9 +2040,8 @@ def _calc_spec_decode_metadata( # TODO: Optimize the CPU -> NPU copy. 
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( self.device, non_blocking=True) - if not vllm_version_is("0.11.0"): - cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( - self.device, non_blocking=True) + cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( + self.device, non_blocking=True) logits_indices = torch.from_numpy(logits_indices).to(self.device, non_blocking=True) target_logits_indices = torch.from_numpy(target_logits_indices).to( @@ -2129,33 +2055,24 @@ def _calc_spec_decode_metadata( draft_token_ids = draft_token_ids[target_logits_indices + 1] if self.pcp_size > 1: logits_indices = logits_indices_pcp - if vllm_version_is("0.11.0"): - metadata = SpecDecodeMetadata( - draft_token_ids=draft_token_ids, - num_draft_tokens=num_draft_tokens.tolist(), - cu_num_draft_tokens=cu_num_draft_tokens, - target_logits_indices=target_logits_indices, - bonus_logits_indices=bonus_logits_indices, - logits_indices=logits_indices, - ) - else: - metadata = SpecDecodeMetadata( - draft_token_ids=draft_token_ids, - num_draft_tokens=num_draft_tokens.tolist(), - cu_num_draft_tokens=cu_num_draft_tokens, - cu_num_sampled_tokens=cu_num_sampled_tokens, - target_logits_indices=target_logits_indices, - bonus_logits_indices=bonus_logits_indices, - logits_indices=logits_indices, - ) + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + cu_num_sampled_tokens=cu_num_sampled_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) return metadata def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", + grammar_output: "GrammarOutput", logits: torch.Tensor, ) -> torch.Tensor: - grammar_bitmask = scheduler_output.grammar_bitmask + grammar_bitmask = grammar_output.grammar_bitmask # We receive the structured output bitmask from the scheduler, # compacted to contain bitmasks only for structured output requests. 
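The PCP + spec-decode branch in _prepare_inputs above flattens the block table so that every decode request occupies decode_threshold rows while prefill requests keep one row each. A small standalone illustration of the reshuffle described in that comment (shapes and decode_threshold = 1 + num_speculative_tokens are chosen for the demo; the real code performs the copies in place on the oversized block_table tensor):

import torch

num_decode_reqs, num_prefill_reqs = 2, 3
decode_threshold = 2  # 1 + num_speculative_tokens (here 1)
max_num_blocks = 4

# Rows correspond to [d0, d1, p0, p1, p2].
blk_table = torch.arange(
    (num_decode_reqs + num_prefill_reqs) * max_num_blocks).reshape(
        -1, max_num_blocks)

num_decode_flat = num_decode_reqs * decode_threshold
flat = torch.empty((num_decode_flat + num_prefill_reqs, max_num_blocks),
                   dtype=blk_table.dtype)
# Prefill rows keep a single slot each ...
flat[num_decode_flat:] = blk_table[num_decode_reqs:num_decode_reqs +
                                   num_prefill_reqs]
# ... while each decode row is repeated decode_threshold times.
flat[:num_decode_flat] = blk_table[:num_decode_reqs].repeat_interleave(
    decode_threshold, dim=0)
# flat rows are now [d0, d0, d1, d1, p0, p1, p2], matching the comment above.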
@@ -2174,7 +2091,7 @@ def apply_grammar_bitmask( logit_index = batch_index + cumulative_offset cumulative_offset += len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: + if req_id in grammar_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] @@ -2184,33 +2101,16 @@ def apply_grammar_bitmask( shape=(logits.shape[0], grammar_bitmask.shape[1])) cumulative_index = 0 - if vllm_version_is("0.11.0"): - seq = sorted( - scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) - for req_id, _ in seq: + for req_id in grammar_output.structured_output_request_ids: + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in struct_out_req_batch_indices: logit_index = struct_out_req_batch_indices[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get( - req_id, [])) for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] + sorted_bitmask[logit_index + + i] = grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens - else: - for req_id in scheduler_output.structured_output_request_ids: - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get( - req_id, [])) - if req_id in struct_out_req_batch_indices: - logit_index = struct_out_req_batch_indices[req_id] - for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + - i] = grammar_bitmask[cumulative_index + - i] - out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens + cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask # Serialization of np.ndarray is much more efficient than a tensor, @@ -2232,7 +2132,7 @@ def apply_grammar_bitmask( def propose_draft_token_ids( self, - valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]], + valid_sampled_token_ids: Union[torch.Tensor, list[np.ndarray]], sampling_metadata: SamplingMetadata, scheduler_output: "SchedulerOutput", spec_decode_metadata: SpecDecodeMetadata, @@ -2299,8 +2199,8 @@ def _pool( kv_connector_output=kv_connector_output, ) - def _select_moe_comm_method(self, num_tokens: int, - with_prefill: bool) -> Optional[MoECommType]: + def _select_moe_comm_method(self, + num_tokens: int) -> Optional[MoECommType]: """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all are designed for expert parallelism. 2. 
If expert parallel is enabled, we need to consider the soc version and the @@ -2326,14 +2226,14 @@ def _select_moe_comm_method(self, num_tokens: int, if not is_moe_model(self.vllm_config): return None - soc_version = get_ascend_soc_version() + soc_version = get_ascend_device_type() quant_type = getattr(self.vllm_config.model_config.hf_config, 'moe_quantize', None) model_type = self.vllm_config.model_config.hf_config.model_type if not self.parallel_config.enable_expert_parallel: moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendSocVersion.A2}: + elif soc_version in {AscendDeviceType._910B}: if (num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16): moe_comm_type = MoECommType.MC2 @@ -2344,19 +2244,13 @@ def _select_moe_comm_method(self, num_tokens: int, else: moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendSocVersion.A3}: + elif soc_version in {AscendDeviceType._910_93}: moe_comm_type = (MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}") - if moe_comm_type == MoECommType.ALLGATHER and with_prefill: - if enable_sp(): - moe_comm_type = MoECommType.ALLGATHER - else: - moe_comm_type = MoECommType.NAIVE_MULTICAST - # PanguProMoE only supports allgather if model_type == "PanguProMoE": moe_comm_type = MoECommType.ALLGATHER @@ -2371,7 +2265,11 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + ) -> Union[ModelRunnerOutput, IntermediateTensors] | None: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called " + "after execute_model() returns None.") + with ProfileExecuteDuration().capture_async("prepare input"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: @@ -2396,8 +2294,19 @@ def execute_model( if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() - moe_comm_type = self._select_moe_comm_method(num_input_tokens, - self.with_prefill) + moe_comm_type = self._select_moe_comm_method(num_input_tokens) + # prevent debugger is None + need_dump = self.dump_enable and self.debugger is not None + if need_dump: + assert self.debugger is not None + dbg_cfg = getattr(self.debugger, "config", None) + dump_level = str( + getattr(dbg_cfg, "level", + "L1")).upper() if dbg_cfg is not None else "L1" + if dump_level in ("L0", "MIX"): + self.debugger.start(model=self.model) + else: + self.debugger.start() uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( scheduler_output.total_num_scheduled_tokens @@ -2455,6 +2364,10 @@ def execute_model( # For mid-pipeline stages, return the hidden states. 
if not broadcast_pp_output: hidden_states.kv_connector_output = kv_connector_output + if need_dump: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict( @@ -2462,11 +2375,16 @@ def execute_model( logits = None else: if self.input_batch.pooling_params: - return self._pool( + pool_output = self._pool( hidden_states, scheduler_output.total_num_scheduled_tokens, num_scheduled_tokens_np, finished_sending, finished_recving, kv_connector_output) + if need_dump: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() + return pool_output sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: @@ -2480,14 +2398,46 @@ def execute_model( logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present - if vllm_version_is("0.11.0"): - if scheduler_output.grammar_bitmask is not None: - logits = self.apply_grammar_bitmask( - scheduler_output, logits) - else: - if scheduler_output.structured_output_request_ids: - logits = self.apply_grammar_bitmask( - scheduler_output, logits) + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ) + return None + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + return None # noqa + need_dump = self.dump_enable and self.debugger is not None + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. + if grammar_output is not None: + logits = self.apply_grammar_bitmask(scheduler_output, + grammar_output, logits) with ProfileExecuteDuration().capture_async("Sample"): # Sample the next token and get logprobs if needed. @@ -2562,17 +2512,19 @@ def execute_model( # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + # No spec decode tokens. It's a tensor. + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in sampled_token_ids.cpu().numpy() + ] else: - # Includes spec decode tokens. + # Includes spec decode tokens. It's a numpy array valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() + valid_sampled_token_ids[int(i)] = np.array([]) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist( @@ -2598,16 +2550,17 @@ def execute_model( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. 
for req_idx in range(num_sampled_tokens): + sampled_ids: np.ndarray | None if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = (np.array([-1]) if req_idx + not in invalid_req_indices_set else None) else: sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: + if sampled_ids is None or sampled_ids.shape[0] == 0: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) + end_idx = start_idx + sampled_ids.shape[0] assert end_idx <= self.model_config.max_model_len, ( "Sampled token IDs exceed the max model length. " f"Total number of tokens: {end_idx} > max_model_len: " @@ -2621,7 +2574,7 @@ def execute_model( self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) + req_state.output_token_ids.extend(sampled_ids.tolist()) def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None @@ -2678,8 +2631,16 @@ def propose_draft_token_ids(sampled_token_ids): if self.dynamic_eplb: self.eplb_updator.forward_end() if not self.use_async_scheduling: + if need_dump: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() return model_runner_output + if need_dump: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() return AsyncNPUModelRunnerOutput( model_runner_output=model_runner_output, sampled_token_ids=sampled_token_ids, @@ -2827,6 +2788,9 @@ def _build_dummy_attn_metadata( sin=self.sin, prefill_context_parallel_metadata=long_seq_metadata, ) + if self.pcp_size > 1: + common_attn_metadata.block_table_tensor = \ + block_table_tensor[:num_reqs * self.decode_threshold] attn_state = AscendAttentionState.DecodeOnly if self.speculative_config and \ self.speculative_config.method == "deepseek_mtp": @@ -2886,8 +2850,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill, else: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, - positions.shape[0], - self.speculative_config) + num_tokens, self.speculative_config) else: if self.pcp_size * self.dcp_size > 1: update_attn_dcp_pcp_params(self.update_stream, @@ -2895,7 +2858,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill, positions.shape[0]) else: update_attn_params(self.update_stream, forward_context, - positions.shape[0]) + num_tokens) if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states @@ -2933,7 +2896,7 @@ def _dummy_run( with_prefill) = self._sync_metadata_across_dp(num_tokens, with_prefill) - moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + moe_comm_type = self._select_moe_comm_method(num_tokens) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). 
This means that we are using @@ -2975,12 +2938,14 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) if not self.in_profile_run and self.dynamic_eplb: self.eplb_updator.forward_before() with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens): + num_scheduled_tokens, + num_sampled_tokens): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] @@ -3075,7 +3040,6 @@ def dummy_compute_logits(hidden_states): self.drafter.dummy_run( num_tokens=num_tokens, with_prefill=with_prefill, - skip_attn=True, num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp, aclgraph_runtime_mode=aclgraph_runtime_mode, @@ -3110,8 +3074,7 @@ def profile_run(self) -> None: # allowing vLLM to correctly estimate the maximum memory required. if self.max_num_tokens > self.mc2_tokens_capacity and \ self._select_moe_comm_method( - self.mc2_tokens_capacity, - with_prefill=True) == MoECommType.MC2: + self.mc2_tokens_capacity) == MoECommType.MC2: self._dummy_run(self.mc2_tokens_capacity, with_prefill=True) output = None @@ -3220,7 +3183,7 @@ def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) if self.dynamic_eplb: model_register(self.model, self.model_config) - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -3737,9 +3700,9 @@ def get_attn_backends_for_group( for k, v in attn_backend_layers.items() } - def create_attn_groups( - attn_backends_map: dict[AttentionBackend, list[str]], - ) -> list[AttentionGroup]: + def create_attn_groups(attn_backends_map: dict[AttentionBackend, + list[str]], + kv_cache_group_id: int) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): @@ -3750,16 +3713,17 @@ def create_attn_groups( self.vllm_config, self.device, )) - attn_group = AttentionGroup(attn_backend, - attn_metadata_builders, - layer_names, kv_cache_spec) + attn_group = AttentionGroup(attn_backend, layer_names, + kv_cache_spec, kv_cache_group_id, + attn_metadata_builders) attn_groups.append(attn_group) return attn_groups - for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): attn_backends = get_attn_backends_for_group( # type: ignore kv_cache_group_spec) - self.attn_groups.append(create_attn_groups(attn_backends)) + self.attn_groups.append(create_attn_groups(attn_backends, i)) # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() @@ -3797,95 +3761,6 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i - def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. 
- """ - - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - use_sparse = self.use_sparse - kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue - if isinstance(attn_module, AscendMultiHeadLatentAttention): - continue - - # TODO: Support other attention modules, e.g., cross-attention - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if use_mla and not use_sparse: - kv_cache_spec[layer_name] = MLAAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - cache_dtype_str=self.cache_config.cache_dtype) - else: - # TODO(cmq): This is a hack way to fix deepseek kvcache when - # using DSA. Fix the spec in vLLM is a finnal way. - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len - - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) - - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. - for layer_name, mamba_module in mamba_layers.items(): - kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, - page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, - num_speculative_blocks=( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), - ) - - return kv_cache_spec - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3894,9 +3769,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: KVCacheSpec: A dictionary mapping layer names to their KV cache format. 
Layers that do not need KV cache are not included. """ - if vllm_version_is("0.11.0"): - return self.get_kv_cache_spec_v0110() - block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} @@ -3994,8 +3866,8 @@ def initialize_aclgraph_capture(self) -> None: graph_support = builder.aclgraph_support.value builder_aclgraph = builder.aclgraph_support else: - graph_support = builder.cudagraph_support.value - builder_aclgraph = builder.cudagraph_support + graph_support = builder._cudagraph_support.value + builder_aclgraph = builder._cudagraph_support if graph_support < min_ag_support.value: min_ag_support = builder_aclgraph min_ag_builder_name = builder.__class__.__name__ @@ -4101,7 +3973,8 @@ def _capture_model(self): if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE: aclgraph_runtime_mode = aclgraph_mode.mixed_mode() - compilation_cases = sorted(self.aclgraph_batch_sizes) + # make sure we capture the largest batch size first + compilation_cases = list(reversed(self.aclgraph_batch_sizes)) try: self._capture_aclgraphs( @@ -4592,3 +4465,18 @@ def _generate_pcp_mtp_input( self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full], non_blocking=True, ) + + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. + pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return [row for row in pinned.numpy()] diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 846a4b29bc1..471c150ba62 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -30,6 +30,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, @@ -39,14 +40,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice -from vllm_ascend.utils import vllm_version_is from vllm_ascend.worker.block_table import MultiGroupBlockTable -if vllm_version_is("0.11.0"): - from vllm.utils import swap_dict_values -else: - from vllm.utils.collection_utils import swap_dict_values - @dataclass class CachedRequestState: @@ -834,7 +829,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: non_blocking=True) def make_lora_inputs( - self, num_scheduled_tokens: np.ndarray + self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: """ Given the num_scheduled_tokens for each request in the batch, return diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 58ac27a0d27..df7fec602d0 100644 --- a/vllm_ascend/worker/worker_v1.py +++ 
b/vllm_ascend/worker/worker_v1.py @@ -18,7 +18,8 @@ # import copy -from typing import Optional, Union +from types import NoneType +from typing import Optional import torch import torch.nn as nn @@ -35,7 +36,9 @@ from vllm.lora.request import LoRARequest from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) @@ -47,10 +50,10 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz, +from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz, prefill_context_parallel_enable, register_ascend_customop, sleep_mode_enabled, - try_register_lib, vllm_version_is) + try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 @@ -65,12 +68,6 @@ torch._dynamo.trace_rules.torch_name_rule_map.append( torch_non_c_binding_in_graph_functions_npu) # noqa: E402 -if vllm_version_is("0.11.0"): - from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes -else: - from vllm.utils.mem_constants import GiB_bytes - from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE - class NPUWorker(WorkerBase): @@ -87,7 +84,6 @@ def __init__( # register patch for vllm from vllm_ascend.utils import adapt_patch adapt_patch() - is_enable_nz(vllm_config) # Register ops when worker init. from vllm_ascend import ops ops.register_dummy_fusion_op() @@ -95,7 +91,7 @@ def __init__( register_ascend_customop(vllm_config) # init ascend config and soc version init_ascend_config(vllm_config) - init_ascend_soc_version() + check_ascend_device_type() use_sparse = False if vllm_config.model_config is not None: use_sparse = hasattr(vllm_config.model_config.hf_config, @@ -142,10 +138,7 @@ def __init__( if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing - if vllm_version_is("0.11.0"): - from vllm.utils import init_cached_hf_modules - else: - from vllm.utils.import_utils import init_cached_hf_modules + from vllm.utils.import_utils import init_cached_hf_modules init_cached_hf_modules() @@ -214,6 +207,20 @@ def _init_device(self): device = torch.device(f"npu:{self.local_rank}") NPUPlatform.set_device(device) NPUPlatform.empty_cache() + + if (self.parallel_config.data_parallel_size > 1 + and self.parallel_config.data_parallel_size_local > 0 + and self.parallel_config.distributed_executor_backend + not in ["ray", "external_launcher"] and + self.vllm_config.parallel_config.data_parallel_backend != "ray" + and self.vllm_config.parallel_config.nnodes_within_dp == 1): + visible_device_count = (torch.npu.device_count() + if torch.npu.is_available() else 0) + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must " + f"be less than or equal to the number of visible devices " + f"({visible_device_count}).") + self.init_npu_memory = NPUPlatform.mem_get_info()[0] # Initialize the distributed environment. 
self._init_worker_distributed_environment() @@ -274,7 +281,7 @@ def determine_available_memory(self) -> int: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: + ) -> ModelRunnerOutput | None: # enable msMonitor to monitor the performance of vllm-ascend if envs_ascend.MSMONITOR_USE_DAEMON: dp.step() @@ -288,7 +295,7 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): + if isinstance(output, (ModelRunnerOutput, NoneType)): return output assert isinstance(output, IntermediateTensors) @@ -312,6 +319,12 @@ def execute_model( output.kv_connector_output = kv_connector_output return output + @torch.inference_mode() + def sample_tokens( + self, grammar_output: "GrammarOutput" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output) + def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: allocator = CaMemAllocator.get_instance() From 2e902004b3e3cb967927716f90f790f285a11cde Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Sat, 29 Nov 2025 15:13:05 +0800 Subject: [PATCH 07/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- .../kvpool/ascend_store_connector.py | 2 - vllm_ascend/distributed/kvpool/config_data.py | 10 +++++ vllm_ascend/distributed/kvpool/kv_transfer.py | 43 ++++++++++++------- .../distributed/kvpool/pool_scheduler.py | 7 +++ vllm_ascend/distributed/kvpool/pool_worker.py | 33 ++++++++++++-- 5 files changed, 74 insertions(+), 21 deletions(-) diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py index 9f4833555db..4107afdfab5 100644 --- a/vllm_ascend/distributed/kvpool/ascend_store_connector.py +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -43,8 +43,6 @@ def __init__(self, self.kv_caches: dict[str, torch.Tensor] = {} - self._block_size = vllm_config.cache_config.block_size - self.sended_but_unfinished_reqs: set[str] = set() if role == KVConnectorRole.SCHEDULER: diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py index e3b0873d686..0d89021bb3a 100644 --- a/vllm_ascend/distributed/kvpool/config_data.py +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -17,6 +17,10 @@ class KeyMetadata: model_name: str """ worker id when running under a distributed setting """ head_or_tp_rank: int + """ Initialize the current prefill context model parallel rank """ + pcp_rank: int + """ Initialize the current decode context model parallel rank """ + dcp_rank: int @dataclass(order=True) @@ -28,12 +32,15 @@ def __hash__(self): return hash(( self.key_metadata.model_name, self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, self.chunk_hash, )) def to_string(self): return ( f"{self.key_metadata.model_name}" + f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}" ) @@ -60,6 +67,8 @@ def __hash__(self): return hash(( self.key_metadata.model_name, self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, self.chunk_hash, self.layer_id, )) @@ -67,6 +76,7 @@ def __hash__(self): def to_string(self): return ( f"{self.key_metadata.model_name}" + 
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}" ) diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py index b30158ae8c2..46f37d36953 100644 --- a/vllm_ascend/distributed/kvpool/kv_transfer.py +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -19,11 +19,12 @@ class KVTransferThread(threading.Thread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event, name: str): + tp_rank: int, dcp_size: int, ready_event: threading.Event, name: str): super().__init__(daemon=True, name=name) self.m_store = m_store self.ready_event = ready_event self.tp_rank = tp_rank + self.dcp_size = dcp_size self.token_database = token_database self.done_task_lock = threading.Lock() self.request_queue: queue.Queue[Any] = queue.Queue() @@ -87,10 +88,11 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, put_step: int, ready_event: threading.Event): + tp_rank: int, dcp_size: int, put_step: int, ready_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheSendingThread") self.put_step = put_step @@ -112,12 +114,16 @@ def _handle_request(self, req_meta: dict[str, Any]): key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] - addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] - size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] - if key_list_tp: + if self.dcp_size > 1 : torch.npu.current_stream().synchronize() - self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + self.m_store.put(key_list, addr_list, size_list) + else: + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) if is_last_chunk: self.set_finished_request(req_id) self.request_queue.task_done() @@ -126,10 +132,11 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreRecvingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event): + tp_rank: int, dcp_size: int, ready_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreRecvingThread") @@ -166,11 +173,12 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreLayerSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, put_step: int, ready_event: threading.Event, + tp_rank: int, dcp_size: int, put_step: int, ready_event: threading.Event, num_layers: int): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreLayerSendingThread") self.final_layer_id = num_layers - 1 @@ -192,12 +200,16 @@ def _handle_request( # type: ignore[override] key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_tp = key_list[self.tp_rank % 
self.put_step::self.put_step] - addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] - size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] - if key_list_tp: + if self.dcp_size > 1 : torch.npu.current_stream().synchronize() - self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + self.m_store.put(key_list, addr_list, size_list) + else: + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk: self.set_finished_request(req_meta.req_id) self.request_queue.task_done() @@ -206,11 +218,12 @@ def _handle_request( # type: ignore[override] class KVCacheStoreLayerRecvingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event, + tp_rank: int, dcp_size: int, ready_event: threading.Event, get_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread") self.get_event = get_event diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py index 06041b5a6e5..d1564ce7ec0 100644 --- a/vllm_ascend/distributed/kvpool/pool_scheduler.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -29,7 +29,14 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): "load_async", False) # request_id -> (vllm cached tokes, kvpool cached tokens) self.load_specs: dict[str, LoadSpec] = {} + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size + self._block_size = vllm_config.cache_config.block_size + if self.pcp_size > 1: + self._block_size *= self.pcp_size + if self.dcp_size > 1: + self._block_size *= self.dcp_size # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} # Whether to discard partial chunks diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index b03d2808928..31ad68d559b 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -6,6 +6,10 @@ # Third Party import torch from vllm.config import VllmConfig +from vllm.distributed import (get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.utils import logger from vllm.v1.core.kv_cache_utils import BlockHash @@ -20,6 +24,12 @@ from vllm_ascend.distributed.kvpool.kv_transfer import ( KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) +from vllm_ascend.utils import prefill_context_parallel_enable + +if prefill_context_parallel_enable(): + from vllm.distributed import ( + get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size) backend_map: Dict[str, Type[Backend]] = { "mooncake": MooncakeBackend, @@ -44,17 +54,30 @@ def __init__( and model_config.use_mla): self.use_mla = True self.use_layerwise = use_layerwize - self.tp_rank = parallel_config.rank - self.tp_size = 
parallel_config.tensor_parallel_size + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + self.pcp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "backend", "mooncake") self.block_size = vllm_config.cache_config.block_size + + if self.pcp_size > 1: + self.block_size *= self.pcp_size + if self.dcp_size > 1: + self.block_size *= self.dcp_size self.current_layer = 0 self.num_layers = model_config.get_num_layers(parallel_config) - self.block_size = vllm_config.cache_config.block_size if self.use_mla: self.num_kv_head = 1 @@ -69,8 +92,10 @@ def __init__( self.put_step = 1 self.metadata = KeyMetadata( - model_config.model, + model_config.model.split('/')[-1], self.head_or_tp_rank, + self.pcp_rank, + self.dcp_rank, ) self.token_database = ChunkedTokenDatabase(self.metadata, From 9ff81be704167792a9dfef6e95f3b2fed58d839f Mon Sep 17 00:00:00 2001 From: fjw <2270923832@qq.com> Date: Sat, 29 Nov 2025 17:50:19 +0800 Subject: [PATCH 08/13] Pooling Features and PCP Adaptation Signed-off-by: fjw <2270923832@qq.com> --- vllm_ascend/distributed/kvpool/kv_transfer.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py index 46f37d36953..f9330d37646 100644 --- a/vllm_ascend/distributed/kvpool/kv_transfer.py +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -19,7 +19,8 @@ class KVTransferThread(threading.Thread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, dcp_size: int, ready_event: threading.Event, name: str): + tp_rank: int, dcp_size: int, ready_event: threading.Event, + name: str): super().__init__(daemon=True, name=name) self.m_store = m_store self.ready_event = ready_event @@ -88,7 +89,8 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, dcp_size: int, put_step: int, ready_event: threading.Event): + tp_rank: int, dcp_size: int, put_step: int, + ready_event: threading.Event): super().__init__(m_store, token_database, tp_rank, @@ -114,13 +116,15 @@ def _handle_request(self, req_meta: dict[str, Any]): key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - if self.dcp_size > 1 : + if self.dcp_size > 1: torch.npu.current_stream().synchronize() self.m_store.put(key_list, addr_list, size_list) else: key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] - addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] - size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % + self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % + self.put_step::self.put_step] if key_list_tp: torch.npu.current_stream().synchronize() self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) @@ -173,8 +177,8 
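
The functional change PATCH 07 makes in kv_transfer.py is a choice between two put paths: when decode context parallelism is active (dcp_size > 1) a rank writes its complete key/address/size lists, otherwise the lists are striped across ranks by put_step as before. A minimal, self-contained sketch of that selection logic (the helper name select_put_lists and the sample values are illustrative, not part of the patch):

from typing import List, Tuple


def select_put_lists(
        key_list: List[str], addr_list: List[int], size_list: List[int],
        tp_rank: int, put_step: int,
        dcp_size: int) -> Tuple[List[str], List[int], List[int]]:
    # Mirrors the branch added in kv_transfer.py: under DCP every rank
    # writes its full lists; otherwise writes are striped by put_step.
    if dcp_size > 1:
        return key_list, addr_list, size_list
    stripe = slice(tp_rank % put_step, None, put_step)
    return key_list[stripe], addr_list[stripe], size_list[stripe]


# Example: 8 chunks, put_step=2, tp_rank=1 -> this rank writes chunks 1, 3, 5, 7.
keys, addrs, sizes = select_put_lists([f"k{i}" for i in range(8)],
                                      list(range(8)), [16] * 8,
                                      tp_rank=1, put_step=2, dcp_size=1)
assert keys == ["k1", "k3", "k5", "k7"]

The synchronize-before-put ordering from the patch is unchanged by this sketch; it only isolates which entries a given rank sends.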

From 9ff81be704167792a9dfef6e95f3b2fed58d839f Mon Sep 17 00:00:00 2001
From: fjw <2270923832@qq.com>
Date: Sat, 29 Nov 2025 17:50:19 +0800
Subject: [PATCH 08/13] Pooling Features and PCP Adaptation

Signed-off-by: fjw <2270923832@qq.com>
---
 vllm_ascend/distributed/kvpool/kv_transfer.py | 26 ++++++++++++-------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py
index 46f37d36953..f9330d37646 100644
--- a/vllm_ascend/distributed/kvpool/kv_transfer.py
+++ b/vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -19,7 +19,8 @@ class KVTransferThread(threading.Thread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, ready_event: threading.Event, name: str):
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
+                 name: str):
         super().__init__(daemon=True, name=name)
         self.m_store = m_store
         self.ready_event = ready_event
@@ -88,7 +89,8 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, put_step: int, ready_event: threading.Event):
+                tp_rank: int, dcp_size: int, put_step: int,
+                ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
@@ -114,13 +116,15 @@ def _handle_request(self, req_meta: dict[str, Any]):
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        if self.dcp_size > 1 :
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-            addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-            size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                 self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                 self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
@@ -173,8 +177,8 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreLayerSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, dcp_size: int, put_step: int, ready_event: threading.Event,
-                 num_layers: int):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event, num_layers: int):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
@@ -200,13 +204,15 @@ def _handle_request( # type: ignore[override]
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        if self.dcp_size > 1 :
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-            addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-            size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                 self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                 self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)

From 369060b4435c6467294e3f5eab6944ee8f05c033 Mon Sep 17 00:00:00 2001
From: fjw <2270923832@qq.com>
Date: Sat, 29 Nov 2025 18:07:38 +0800
Subject: [PATCH 09/13] Pooling Features and PCP Adaptation

Signed-off-by: fjw <2270923832@qq.com>
---
 vllm_ascend/distributed/kvpool/kv_transfer.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py
index f9330d37646..f956e5e1264 100644
--- a/vllm_ascend/distributed/kvpool/kv_transfer.py
+++ b/vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -89,8 +89,8 @@ class KVCacheStoreSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                tp_rank: int, dcp_size: int, put_step: int,
-                ready_event: threading.Event):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
@@ -121,10 +121,10 @@ def _handle_request(self, req_meta: dict[str, Any]):
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
             addr_list_tp = addr_list[self.tp_rank %
-                                 self.put_step::self.put_step]
+                                    self.put_step::self.put_step]
             size_list_tp = size_list[self.tp_rank %
-                                 self.put_step::self.put_step]
+                                    self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
@@ -209,10 +209,10 @@ def _handle_request( # type: ignore[override]
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
             addr_list_tp = addr_list[self.tp_rank %
-                                 self.put_step::self.put_step]
+                                    self.put_step::self.put_step]
             size_list_tp = size_list[self.tp_rank %
-                                 self.put_step::self.put_step]
+                                    self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
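
Patches 08 and 09 above only re-wrap lines introduced earlier. The functional rule they preserve, from the pool_worker.py and pool_scheduler.py hunks of PATCH 07, is that the pooled block size is the engine block size scaled by the PCP and DCP world sizes, with the ranks defaulting to 0 when the corresponding size is 1. A small sketch of that scaling (the function name effective_block_size and the sample numbers are illustrative, not from the patches):

def effective_block_size(base_block_size: int, pcp_size: int,
                         dcp_size: int) -> int:
    # Same scaling as in pool_worker.py and pool_scheduler.py.
    block_size = base_block_size
    if pcp_size > 1:
        block_size *= pcp_size
    if dcp_size > 1:
        block_size *= dcp_size
    return block_size


# Example: engine block size 128 with pcp_size=2 and dcp_size=4
# gives 1024 tokens per pooled block.
assert effective_block_size(128, pcp_size=2, dcp_size=4) == 1024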

From 10a6d2d1c8884f17e443a73916aa83c5b8faf064 Mon Sep 17 00:00:00 2001
From: fjw <2270923832@qq.com>
Date: Sat, 29 Nov 2025 18:19:22 +0800
Subject: [PATCH 10/13] Pooling Features and PCP Adaptation

Signed-off-by: fjw <2270923832@qq.com>
---
 vllm_ascend/distributed/kvpool/kv_transfer.py | 10 +++++-----
 vllm_ascend/distributed/kvpool/pool_worker.py |  8 ++++----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py
index f956e5e1264..cd5245364c4 100644
--- a/vllm_ascend/distributed/kvpool/kv_transfer.py
+++ b/vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -90,7 +90,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                  tp_rank: int, dcp_size: int, put_step: int,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
@@ -122,9 +122,9 @@ def _handle_request(self, req_meta: dict[str, Any]):
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
             addr_list_tp = addr_list[self.tp_rank %
-                                    self.put_step::self.put_step]
+                                     self.put_step::self.put_step]
             size_list_tp = size_list[self.tp_rank %
-                                    self.put_step::self.put_step]
+                                     self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
@@ -210,9 +210,9 @@ def _handle_request( # type: ignore[override]
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
             addr_list_tp = addr_list[self.tp_rank %
-                                    self.put_step::self.put_step]
+                                     self.put_step::self.put_step]
             size_list_tp = size_list[self.tp_rank %
-                                    self.put_step::self.put_step]
+                                     self.put_step::self.put_step]
             if key_list_tp:
                 torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index 31ad68d559b..fc479c2fe63 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -171,27 +171,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreLayerSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
+                    self.m_store, self.token_database, self.tp_rank, self.dcp_size,
                     self.put_step, ready_event_sending, self.num_layers)
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
                 self.m_store, self.token_database, self.tp_rank, ready_event,
-                self.get_event)
+                self.dcp_size, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
-                    self.m_store, self.token_database, self.tp_rank,
+                    self.m_store, self.token_database, self.tp_rank, self.dcp_size,
                     self.put_step, ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()
                 self.kv_recv_thread = KVCacheStoreRecvingThread(
                     self.m_store, self.token_database, self.tp_rank,
-                    ready_event)
+                    self.dcp_size, ready_event)
                 self.kv_recv_thread.start()
                 ready_event.wait()

From d9153ce34dd06470fd4b55fc5cf60e61fdfd90a2 Mon Sep 17 00:00:00 2001
From: SlightwindSec
Date: Sat, 29 Nov 2025 18:30:59 +0800
Subject: [PATCH 11/13] fix lint

Signed-off-by: SlightwindSec
---
 vllm_ascend/distributed/kvpool/pool_worker.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index fc479c2fe63..e1116803fa2 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -171,8 +171,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreLayerSendingThread(
-                    self.m_store, self.token_database, self.tp_rank, self.dcp_size,
-                    self.put_step, ready_event_sending, self.num_layers)
+                    self.m_store, self.token_database, self.tp_rank,
+                    self.dcp_size, self.put_step, ready_event_sending,
+                    self.num_layers)
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
@@ -184,8 +185,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
-                    self.m_store, self.token_database, self.tp_rank, self.dcp_size,
-                    self.put_step, ready_event_sending)
+                    self.m_store, self.token_database, self.tp_rank,
+                    self.dcp_size, self.put_step, ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()

From 5a98628b166513c30c0c12fc91d2fd59deb913ca Mon Sep 17 00:00:00 2001
From: SlightwindSec
Date: Sat, 29 Nov 2025 18:34:55 +0800
Subject: [PATCH 12/13] fix lint

Signed-off-by: SlightwindSec
---
 vllm_ascend/distributed/kvpool/pool_worker.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index e1116803fa2..d15e2e2578a 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -1,9 +1,7 @@
-# Standard
 import math
 import threading
 from typing import Dict, Generator, Optional, Type
 
-# Third Party
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed import (get_decode_context_model_parallel_rank,
@@ -27,9 +25,11 @@
 from vllm_ascend.utils import prefill_context_parallel_enable
 
 if prefill_context_parallel_enable():
-    from vllm.distributed import (
-        get_prefill_context_model_parallel_rank,
-        get_prefill_context_model_parallel_world_size)
+    # isort: off
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+    # isort: on
 
 backend_map: Dict[str, Type[Backend]] = {
     "mooncake": MooncakeBackend,
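
The lint fixes in PATCH 11 and PATCH 12 only re-wrap call arguments and regroup imports; the pattern they preserve is the guarded import of the PCP helpers, so pool_worker.py still loads on vLLM builds without prefill context parallelism. A minimal sketch of that guard with an explicit fallback (the module-level pcp_size/pcp_rank names and the fallback values are assumptions, not part of the patches):

from vllm_ascend.utils import prefill_context_parallel_enable

if prefill_context_parallel_enable():
    # isort: off
    from vllm.distributed import (get_prefill_context_model_parallel_rank,
                                  get_prefill_context_model_parallel_world_size
                                  )
    # isort: on
    pcp_size = get_prefill_context_model_parallel_world_size()
    pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size > 1 else 0
else:
    # Hypothetical fallback mirroring the worker's defaults when PCP is absent.
    pcp_size, pcp_rank = 1, 0

PATCH 13 below then reorders the KVCacheStoreLayerRecvingThread call so dcp_size is passed ahead of ready_event, matching the thread's updated signature.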

From 3c7e87a973c6be036d5b0854f458b884ba892a3a Mon Sep 17 00:00:00 2001
From: fjw <2270923832@qq.com>
Date: Sat, 29 Nov 2025 18:53:59 +0800
Subject: [PATCH 13/13] Pooling Features and PCP Adaptation

Signed-off-by: fjw <2270923832@qq.com>
---
 vllm_ascend/distributed/kvpool/pool_worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index d15e2e2578a..25322c5f75d 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -177,8 +177,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
-                self.m_store, self.token_database, self.tp_rank, ready_event,
-                self.dcp_size, self.get_event)
+                self.m_store, self.token_database, self.tp_rank, self.dcp_size,
+                ready_event, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else: