106 changes: 68 additions & 38 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -326,36 +326,22 @@ def yarn_get_mscale(scale=1, mscale=1):
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

def forward(
@paddle.jit.marker.capture_control_flow
def prefill_or_decode(
Member:
Are you sure you want to call it this name?

DrRyanHuang (Collaborator, Author) commented on Nov 4, 2025:
@chang-wenbin what would be a more suitable name for this?
Between prefill and decode, you chose "or"? 🤪

Collaborator:
I suggest just renaming it to mla_attention.
But what is the difference between this change and the previous code? Won't there still be control flow this way?

DrRyanHuang (Collaborator, Author) commented on Nov 5, 2025:
This applies AST-based partial dynamic-to-static conversion to the mla_attention function, so its if control-flow statements can be converted into if operators (SOT conversion does not support a Tensor as an if condition). With this, DeepSeek V3 can be converted to a static graph as a whole.

The figure below shows the if operators in the IR graph: inside the red boxes are the two if Ops from mla_attention, and between the two red boxes is the CUDAGraph Op.

[screenshot: IR graph showing the two if Ops and the CUDAGraph Op between them]
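A minimal sketch of the pattern this diff relies on, assuming the paddle.jit.marker.capture_control_flow decorator behaves as described in the thread above (the decorated method is captured via AST so that Tensor-valued if conditions become if operators in the static graph); the layer, sizes, and method names below are made up for illustration:

import paddle

class TinyBlock(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(8, 8)

    # AST capture of this method, mirroring prefill_or_decode in this diff.
    @paddle.jit.marker.capture_control_flow
    def branch(self, x, has_prefill, has_decode):
        # has_prefill / has_decode are Tensors; a plain SOT trace cannot use
        # them as if conditions, but with AST capture each branch is lowered
        # to an if operator in the IR.
        if has_prefill:
            out = self.proj(x)
        else:
            out = paddle.zeros_like(x)
        if has_decode:
            out = out + self.proj(x)
        return out

    def forward(self, x, has_prefill, has_decode):
        # Shared computation stays outside the captured method, just as
        # forward() in this diff keeps the common qkv projection out of
        # the branched prefill/decode paths.
        return self.branch(x, has_prefill, has_decode)

In eager mode has_prefill / has_decode would simply be 0-D tensors evaluated as Python bools; under whole-graph conversion the marker is what keeps these branches as if operators instead of failing the trace.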

self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
forward_meta,
max_enc_len_this_time,
max_dec_len_this_time,
compressed_kv,
query,
query_pe,
key_pe,
mask_encoder_batch,
query_nope,
):
""" """

# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
fmha_out = None

# NOTE: (changwenbin) qkv_a_proj horizontal fusion
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
query, compressed_kv, key_pe = qkv_a_out.split(
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)

query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)

query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
key_value = self.kv_b_proj(compressed_kv)
if max_enc_len_this_time:
key_value = self.kv_b_proj(compressed_kv) # this part
key_value = key_value.reshape(
[
-1,
@@ -380,15 +366,16 @@ def forward(
k_pe=key_pe,
forward_meta=forward_meta,
)
else:
fmha_out_prefill = paddle.zeros_like(query)

fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
Collaborator:
Has it been checked here that the output shape is already [-1, self.num_attention_heads_tp, self.qk_head_dim]?

DrRyanHuang (Collaborator, Author):
Yes, query, key, and fmha_out_prefill here all have the same shape,
[bs, self.num_attention_heads_tp, self.qk_head_dim].
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

fmha_out = fmha_out_prefill
# TODO(drryanhuang): rm this redundant reshape when fmha_out_prefill is zero
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
if max_dec_len_this_time:
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -417,10 +404,53 @@ def forward(
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
)
if fmha_out is None:
fmha_out = fmha_out_decode
else:
fmha_out = fmha_out + fmha_out_decode
fmha_out = fmha_out_prefill + fmha_out_decode
else:
fmha_out = fmha_out_prefill

return fmha_out

def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """

# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.

# NOTE: (changwenbin) qkv_a_proj horizontal fusion
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
query, compressed_kv, key_pe = qkv_a_out.split(
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)

query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)

query_pe = paddle.assign(query_pe)
key_pe = paddle.assign(key_pe)

query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

fmha_out = self.prefill_or_decode(
forward_meta,
forward_meta.max_len_tensor_cpu[1], # max_enc_len_this_time
forward_meta.max_len_tensor_cpu[2], # max_dec_len_this_time
compressed_kv,
query,
query_pe,
key_pe,
mask_encoder_batch,
query_nope,
)

output = self.o_proj(fmha_out)
return output