Commit 739002f

fix
1 parent 5bafd8a commit 739002f

7 files changed: +63 −50 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 32 additions & 31 deletions
@@ -530,7 +530,6 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
-        self.block_size = VllmConfig().cache_config.block_size
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
@@ -575,12 +574,12 @@ def _forward_prefill_no_cache(
         output,_ = torch_npu.npu_fused_infer_attention_score_v2(
             query[:num_tokens],
             key[:num_tokens],
-            atten_mask=maks.to(torch.bool),
+            atten_mask=mask.to(torch.bool),
             actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
-            actual_seq_kvlen=attn_metadata.seq_lens.cumcum(0),
+            actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
             num_query_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            imput_layout="TND",
+            input_layout="TND",
             softmax_scale=self.scale
         )
         assert output is not None
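
Note: besides the `maks`/`imput_layout` typos, the old `cumcum` call would have raised an AttributeError, since tensors only have `cumsum`. The two length arguments are prefix sums over the per-request lengths in the packed TND layout; a minimal sketch, with made-up lengths, of what they contain:

import torch

# Hypothetical per-request lengths, for illustration only.
query_lens = torch.tensor([3, 5, 2])
seq_lens = torch.tensor([3, 5, 2])       # no cache hit, so KV lengths equal query lengths

actual_seq_qlen = query_lens.cumsum(0)   # tensor([ 3,  8, 10])
actual_seq_kvlen = seq_lens.cumsum(0)    # tensor([ 3,  8, 10])
# The fused kernel uses these running totals to find each request's token span
# inside the packed query/key tensors.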
@@ -615,24 +614,24 @@ def _forward_prefill_cache_hit(

         if is_A5():
             compress_mask = compress_mask.to(torch.bool)
-            key = self.key_cache.transpos(1,2)
-            value = slef.value_cache.transpose(1,2)
+            key = self.key_cache.transpose(1,2)
+            value = self.value_cache.transpose(1,2)
             block_size = self.block_size

-            output, _ = troch_npu.npu_fused_infer_attention_score_v2(
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                 query=query,
                 key=key,
                 value=value,
                 block_table=block_table,
-                atten_mask=mask,
+                atten_mask=compress_mask,
                 actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
                 actual_seq_kvlen=attn_metadata.seq_lens,
                 num_query_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 softmax_scale=self.scale,
                 spare_mode=2,  # spare_mode=2 means the mask uses the leftupCausal mode
                 block_size=block_size,
-                imput_layout="TND"
+                input_layout="TND"
             )
             return output

@@ -768,23 +767,24 @@ def _forward_decode_only(
         else:
             if is_A5():
                 batch_size = attn_metadata.query_lens.shape[0]
-                hidden_szie = self.num_heads * self.head_size
-                query = query[:batch_szie]
+                hidden_size = self.num_heads * self.head_size
+                query = query[:batch_size]
                 query = query.view(batch_size, 1, hidden_size)
                 block_size = self.key_cache.shape[1]
                 key = self.key_cache.flatten(2, 3).contiguous()
+                value = self.value_cache.flatten(2, 3).contiguous()
                 ori_output = output
-                output, _ = torch_nup.npu_fused_infer_attention_score_v2(
+                output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                     query=query,
                     key=key,
                     value=value,
-                    actual_seq_kvlen=attn_metadata.seq_len,
+                    actual_seq_kvlen=attn_metadata.seq_lens,
                     num_query_heads=self.num_heads,
                     num_key_value_heads=self.num_kv_heads,
-                    block_table=attn_metadata.block_tables[:batch_szie],
+                    block_table=attn_metadata.block_tables[:batch_size],
                     block_size=block_size,
                     softmax_scale=self.scale,
-                    inpt_layout="BSH"
+                    input_layout="BSH"
                 )
                 output = output.view(-1, self.num_heads, self.head_size)
                 ori_output[:batch_size] = output[:batch_size]
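
Note: decode handles exactly one new token per request, so the A5 branch reshapes the first `batch_size` query rows into a (B, S=1, H) tensor to match the "BSH" layout and then views the kernel output back to per-head form. A minimal sketch of those reshapes, with made-up sizes:

import torch

batch_size, num_heads, head_size = 4, 8, 16               # hypothetical sizes
hidden_size = num_heads * head_size

query = torch.randn(batch_size, num_heads, head_size)     # (tokens, heads, head_dim)
q_bsh = query.view(batch_size, 1, hidden_size)            # (B, S=1, H) for the decode kernel
out_bsh = q_bsh                                            # stand-in for the kernel output
out = out_bsh.view(-1, num_heads, head_size)               # back to (tokens, heads, head_dim)
assert out.shape == (batch_size, num_heads, head_size)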
@@ -859,9 +859,9 @@ def _forward_v1_style(
                 num_query_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 block_table=attn_metadata.block_tables[:attn_metadata.query_lens.shape[0]],
-                block_size=self.key_cache.shape[1],
+                block_size=self.key_cache.shape[1],
                 softmax_scale=self.scale,
-                imput_layout="TND"
+                input_layout="TND"
             )
             return output
         output, _ = torch_npu.npu_fused_infer_attention_score(
@@ -1611,23 +1611,24 @@ def forward(
         if is_A5():  # the code here has changed substantially and needs to be re-adapted
             num_token = slots.shape[0]
             torch_npu.npu_scatter_a_kv_cache(
-                key=key[:num_tokens],
-                value=value[:num_tokens],
-                slot_mapping=slots,
+                key=key[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded],
+                value=value[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded],
+                slot_mapping=slot_mapping[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded]
                 out=(self.key_cache, slef.value_cache)
             )
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
+        else:
+            torch_npu._npu_reshape_and_cache(
+                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
                         num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                value=value[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                slot_indices=attn_metadata.
+                slot_mapping[self.pcp_size *
+                             num_decode_tokens:attn_metadata.
+                             num_actual_tokens_pcp_padded])

         if self.pcp_size * self.dcp_size > 1:
             intermediate_output = self._forward_pcp_dcp(
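
Note: after the fix, both the A5 and non-A5 branches write the same prefill slice of key/value into the paged cache, skipping the decode tokens (handled elsewhere) and stopping at the PCP-padded token count. A sketch, with made-up sizes, of which rows that slice selects:

import torch

pcp_size, num_decode_tokens = 2, 3                 # hypothetical values
num_actual_tokens_pcp_padded = 12

key = torch.randn(16, 8, 64)                       # (tokens, kv_heads, head_dim)
start = pcp_size * num_decode_tokens               # first prefill row = 6
prefill_key = key[start:num_actual_tokens_pcp_padded]
print(prefill_key.shape)                           # torch.Size([6, 8, 64]) -> rows 6..11 get cached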

vllm_ascend/attention/mla_v1.py

Lines changed: 23 additions & 11 deletions
@@ -932,6 +932,8 @@ def _compute_prefill_context(
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
         latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+        all_prefix_lse = [prefix_lse.view(-1)]
+        all_prefix_output = [prefix_output.view(-1, q_nope.shape[-1])]
         # token -> request mapping for building per-token masks when CP>1
         seq_len1 = torch.tensor(prefill_metadata.query_lens,
                                 dtype=torch.int32,
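
Note: the two new lists seed a running collection of per-chunk attention results for the chunked-prefill context path. Assuming the chunks are later combined flash-attention style from their log-sum-exp statistics (an assumption; the merge itself is not shown in this hunk), the combination looks roughly like this self-contained sketch:

import torch

def merge_attn_chunks(outputs, lses):
    # outputs: list of (tokens, dim) attention outputs, one per KV chunk
    # lses:    list of (tokens,) log-sum-exp values from the same chunks
    lse = torch.stack(lses)                         # (chunks, tokens)
    out = torch.stack(outputs)                      # (chunks, tokens, dim)
    max_lse = lse.max(dim=0, keepdim=True).values   # subtract the max for numerical stability
    w = torch.exp(lse - max_lse)                    # per-chunk weights
    merged = (out * w.unsqueeze(-1)).sum(0) / w.sum(0).unsqueeze(-1)
    merged_lse = max_lse.squeeze(0) + w.sum(0).log()
    return merged, merged_lse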
@@ -990,17 +992,27 @@ def _compute_prefill_context(
                     rope_dim,
                     dtype=q_nope.dtype,
                     device=q_nope.device)
-
-                torch_npu.atb.npu_paged_cache_load(
-                    cache_kv_c,
-                    cache_k_pe,
-                    prefill_metadata.block_table,
-                    seq_len2_rank.to(q_nope.device),
-                    seq_starts=
-                    context_starts_rank,  # slot offsets of current chunk in current iteration
-                    key=kv_c_normed,
-                    value=k_pe,
-                )
+                if is_A5():
+                    torch_npu.npu_gather_pa_kv_cache(
+                        cache_kv_c,
+                        cache_k_pe,
+                        prefill_metadata.block_table,
+                        context_seq_len_npu,
+                        key=kv_c_normed,
+                        value=k_pe,
+                        seq_offset=prefill_metadata.chunked_context.starts[i],
+                    )
+                else:
+                    torch_npu.atb.npu_paged_cache_load(
+                        cache_kv_c,
+                        cache_k_pe,
+                        prefill_metadata.block_table,
+                        seq_len2_rank.to(q_nope.device),
+                        seq_starts=
+                        context_starts_rank,  # slot offsets of current chunk in current iteration
+                        key=kv_c_normed,
+                        value=k_pe,
+                    )
                 seq_len2 = seq_len2_rank.to(q_nope.device)
             else:
                 # If current rank has no tokens to process, create empty tensors
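
Note: both branches load the already-cached KV of the current context chunk out of the paged (block-table-indexed) cache into contiguous buffers before attending over it; only the operator differs by platform. A plain-PyTorch sketch of what such a paged gather means, with made-up sizes (the real NPU kernels do this without a Python loop):

import torch

num_blocks, block_size, dim = 8, 4, 16
cache = torch.randn(num_blocks, block_size, dim)   # paged cache: (blocks, block_size, dim)
block_table = torch.tensor([5, 2, 7])              # physical blocks owned by one request
context_len = 10                                   # prefix tokens to load for this chunk

rows = []
for pos in range(context_len):
    blk = block_table[pos // block_size]           # which block holds this token
    rows.append(cache[blk, pos % block_size])      # offset of the token inside the block
context_kv = torch.stack(rows)                     # (context_len, dim), contiguous copy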

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 2 additions & 2 deletions
@@ -503,7 +503,7 @@ def get_device_info(self, global_rank_table, device_filter, device_type):
             and device_filter(d.get("device_id", ""))
         ]
         if len(device_list) <= self.pcp_rank * self.tp_size + self.tp_rank:
-            retunr None
+            return None
         device_info = device_list[self.pcp_rank * self.tp_size + self.tp_rank]
         return device_info
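
Note: besides the `retunr` typo, the guard exists because the device table is indexed with a flattened (pcp_rank, tp_rank) position; when the table is shorter than that index, the method now returns None instead of raising an IndexError. A quick illustration with made-up ranks:

pcp_rank, tp_size, tp_rank = 1, 4, 2
idx = pcp_rank * tp_size + tp_rank                  # 6: row-major (pcp, tp) -> flat index
device_list = ["d0", "d1", "d2", "d3", "d4"]        # hypothetical table
device_info = None if len(device_list) <= idx else device_list[idx]
print(device_info)                                  # None -> table too short for this rank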

@@ -531,7 +531,7 @@ def read_agent_metadata(self, global_rank_table):
             agent_metadata = LLMDataDistCMgrAgentMetadataA5(
                 server_id=server_id_,
                 device_id=device_id_,
-                device_ip=device_ip_,
+                device_ip=device_id_,
                 cluster_id=cluster_id_,
                 level_list = level_list_,
             )

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def _select_experts_with_fusion_ops(
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
         if is_A5():
             # A5 MOCK
-            new_shape = router_logits.shape[-1] + (topk,)
+            new_shape = router_logits.shape[:-1] + (topk,)
             topk_weights = torch.ones(new_shape, dtype=router_logits.dtype, device=router_logits.device)
             topk_ids = torch.zeros(topk_weights.shape, dtype=torch.int32, device=router_logits.device)
         else :
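
Note: the one-character change matters because `router_logits.shape[-1]` is a plain int, while `router_logits.shape[:-1]` is a `torch.Size` (a tuple subclass), so only the latter can be extended with `(topk,)` to build the mock output shape. A quick illustration with made-up sizes:

import torch

router_logits = torch.randn(6, 32)                  # (tokens, num_experts)
topk = 4
new_shape = router_logits.shape[:-1] + (topk,)      # torch.Size([6, 4])
# Old code: router_logits.shape[-1] + (topk,)
#   -> TypeError: unsupported operand type(s) for +: 'int' and 'tuple'
topk_weights = torch.ones(new_shape, dtype=router_logits.dtype, device=router_logits.device)
topk_ids = torch.zeros(new_shape, dtype=torch.int32, device=router_logits.device)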

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -421,8 +421,8 @@ def forward_oot(
            query.dtype)  # type: ignore

    if is_A5():  # A5 does not support the npu_mrope operator; replace it with smaller ops here
-        return
-
+        return query, key
+
    query, key = torch_npu.npu_mrope(positions,
                                     query.contiguous(),
                                     key.contiguous(),
vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -1076,8 +1076,8 @@ def _make_attention_mask(self, seq_lens, position,
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             if is_A5():
-                mas_seq_len = max(seq_lens, default=0)
-                max_seq_len = (max_seq_len + self.block_szie - 1) // self.block_size * self.block_size
+                max_seq_len = max(seq_lens, default=0)
+                max_seq_len = (max_seq_len + self.block_size - 1) // self.block_size * self.block_size
                 new_element = torch.tensor([max_seq_len])
                 seq_lens = torch.cat([seq_lens, new_element], dim =0)
             return self.attn_mask_builder.get_attn_mask(max_seq_len, self.dtype, self.device).to(torch.bool)
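
Note: besides fixing the `mas_seq_len`/`block_szie` typos, the intent of the expression is to round the longest sequence up to the next multiple of the cache block size before building the mask. A quick check of the round-up idiom with made-up numbers:

block_size = 128
max_seq_len = 300
padded = (max_seq_len + block_size - 1) // block_size * block_size
print(padded)    # 384, the next multiple of 128 at or above 300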

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ def compile_or_warm_up_model(self) -> None:
            self.model_runner.capture_model()
        # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
        # may cause performance degradation at runtime.
-        if ~is_A5():
+        if not is_A5():
            self._warm_up_atb()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
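
Note: `~` is bitwise NOT, not logical negation. On a Python bool it operates on the underlying int, so `~is_A5()` is always truthy and the warm-up ran on A5 as well; `not is_A5()` gives the intended boolean. A two-line demonstration:

print(~True, ~False)          # -2 -1   -> both truthy
print(not True, not False)    # False True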
