@@ -530,7 +530,6 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
-        self.block_size = VllmConfig().cache_config.block_size
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
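
Note: with this change, block_size is no longer captured from VllmConfig at construction time; later hunks read it off the paged KV-cache tensor itself. A minimal sketch, assuming the usual [num_blocks, block_size, num_kv_heads, head_size] cache layout (shapes hypothetical):

    import torch

    key_cache = torch.zeros(16, 128, 4, 64)  # hypothetical paged KV cache
    block_size = key_cache.shape[1]          # 128, recovered at runtime, no config object needed
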
@@ -575,12 +574,12 @@ def _forward_prefill_no_cache(
         output, _ = torch_npu.npu_fused_infer_attention_score_v2(
             query[:num_tokens],
             key[:num_tokens],
-            atten_mask=maks.to(torch.bool),
+            atten_mask=mask.to(torch.bool),
             actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
-            actual_seq_kvlen=attn_metadata.seq_lens.cumcum(0),
+            actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
             num_query_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            imput_layout="TND",
+            input_layout="TND",
             softmax_scale=self.scale
         )
         assert output is not None
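
Note: the cumsum(0) calls convert per-sequence lengths into the cumulative end offsets the kernel expects for the packed TND layout. A minimal sketch with made-up lengths:

    import torch

    query_lens = torch.tensor([3, 5, 2])  # new tokens per request
    seq_lens = torch.tensor([7, 9, 4])    # total KV length per request
    print(query_lens.cumsum(0))  # tensor([ 3,  8, 10]) -> actual_seq_qlen
    print(seq_lens.cumsum(0))    # tensor([ 7, 16, 20]) -> actual_seq_kvlen
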
@@ -615,24 +614,24 @@ def _forward_prefill_cache_hit(
615614
616615 if is_A5 ():
617616 compress_mask = compress_mask .to (torch .bool )
618- key = self .key_cache .transpos (1 ,2 )
619- value = slef .value_cache .transpose (1 ,2 )
617+ key = self .key_cache .transpose (1 ,2 )
618+ value = self .value_cache .transpose (1 ,2 )
620619 block_size = self .block_size
621620
622- output , _ = troch_npu .npu_fused_infer_attention_score_v2 (
621+ output , _ = torch_npu .npu_fused_infer_attention_score_v2 (
623622 query = query ,
624623 key = key ,
625624 value = value ,
626625 block_table = block_table ,
627- atten_mask = mask ,
626+ atten_mask = compress_mask ,
628627 actual_seq_qlen = attn_metadata .query_lens .cumsum (0 ),
629628 actual_seq_kvlen = attn_metadata .seq_lens ,
630629 num_query_heads = self .num_heads ,
631630 num_key_value_heads = self .num_kv_heads ,
632631 softmax_scale = self .scale ,
633632 spare_mode = 2 , #spare_mode=2时,代表leftupCausal模式的mask
634633 block_size = block_size ,
635- imput_layout = "TND"
634+ input_layout = "TND"
636635 )
637636 return output
638637
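
Note: a sketch of the transpose in the A5 branch, assuming the paged cache is laid out as [num_blocks, block_size, num_kv_heads, head_size]; transpose(1, 2) swaps it into the [num_blocks, num_kv_heads, block_size, head_size] form this kernel variant consumes:

    import torch

    key_cache = torch.zeros(16, 128, 4, 64)  # hypothetical shape
    key = key_cache.transpose(1, 2)
    print(key.shape)  # torch.Size([16, 4, 128, 64])
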
@@ -768,23 +767,24 @@ def _forward_decode_only(
         else:
             if is_A5():
                 batch_size = attn_metadata.query_lens.shape[0]
-                hidden_szie = self.num_heads * self.head_size
-                query = query[:batch_szie]
+                hidden_size = self.num_heads * self.head_size
+                query = query[:batch_size]
                 query = query.view(batch_size, 1, hidden_size)
                 block_size = self.key_cache.shape[1]
                 key = self.key_cache.flatten(2, 3).contiguous()
+                value = self.value_cache.flatten(2, 3).contiguous()
                 ori_output = output
-                output, _ = torch_nup.npu_fused_infer_attention_score_v2(
+                output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                     query=query,
                     key=key,
                     value=value,
-                    actual_seq_kvlen=attn_metadata.seq_len,
+                    actual_seq_kvlen=attn_metadata.seq_lens,
                     num_query_heads=self.num_heads,
                     num_key_value_heads=self.num_kv_heads,
-                    block_table=attn_metadata.block_tables[:batch_szie],
+                    block_table=attn_metadata.block_tables[:batch_size],
                     block_size=block_size,
                     softmax_scale=self.scale,
-                    inpt_layout="BSH"
+                    input_layout="BSH"
                 )
                 output = output.view(-1, self.num_heads, self.head_size)
                 ori_output[:batch_size] = output[:batch_size]
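
Note: a sketch of the BSH ("batch, seq, hidden") reshaping in the decode-only A5 path, with hypothetical sizes; each decode request contributes one query token, so query becomes [batch, 1, H] and the cache's head dimensions are flattened into a hidden dimension:

    import torch

    num_heads, head_size, batch_size = 8, 64, 4
    query = torch.zeros(batch_size, num_heads, head_size)
    query = query.view(batch_size, 1, num_heads * head_size)  # [4, 1, 512]
    key_cache = torch.zeros(16, 128, 2, head_size)            # paged KV cache
    key = key_cache.flatten(2, 3).contiguous()                # [16, 128, 128]
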
@@ -859,9 +859,9 @@ def _forward_v1_style(
                 num_query_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 block_table=attn_metadata.block_tables[:attn_metadata.query_lens.shape[0]],
-            block_size=self.key_cache.shape[1],
+                block_size=self.key_cache.shape[1],
                 softmax_scale=self.scale,
-                imput_layout="TND"
+                input_layout="TND"
             )
             return output
         output, _ = torch_npu.npu_fused_infer_attention_score(
@@ -1611,23 +1611,24 @@ def forward(
             if is_A5():  # the code here changed substantially and needs re-adaptation
                 num_token = slots.shape[0]
                 torch_npu.npu_scatter_a_kv_cache(
-                    key=key[:num_tokens],
-                    value=value[:num_tokens],
-                    slot_mapping=slots,
+                    key=key[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded],
+                    slot_mapping=slot_mapping[self.pcp_size * num_decode_tokens:attn_metadata.num_actual_tokens_pcp_padded],
                     out=(self.key_cache, self.value_cache)
                 )
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size * num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         if self.pcp_size * self.dcp_size > 1:
             intermediate_output = self._forward_pcp_dcp(
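
Note: an illustrative pure-torch reference (not the torch_npu API) for what the slot-mapping cache write above does: each new token's key/value row lands at a flat slot index into the (num_blocks * block_size) cache. The slices starting at self.pcp_size * num_decode_tokens appear to restrict the write to the prefill tokens under prefill context parallelism.

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    key = torch.randn(3, num_kv_heads, head_size)  # 3 new tokens to cache
    slot_mapping = torch.tensor([5, 6, 17])        # flat slot per token
    key_cache.view(-1, num_kv_heads, head_size)[slot_mapping] = key
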