4 changes: 4 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -143,6 +143,10 @@ class ForwardMeta:
# Flag of profile run
is_dummy_or_profile_run: bool = False

# Prefill and decode flag
needs_prefill: Optional[paddle.Tensor] = None
needs_decode: Optional[paddle.Tensor] = None

def clear_caches(self):
"""Safely clean up the caches"""
if self.caches:
@@ -126,7 +126,7 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
).__get__(self.runnable.__self__)

self.cudagraph_switch_threshold = (
1024 if self.fd_config.graph_opt_config.graph_opt_level > 0 else self.max_captre_size
512 if self.fd_config.graph_opt_config.graph_opt_level > 0 else self.max_captre_size
)

def __call__(self, **kwargs):
@@ -166,6 +166,16 @@ def __init__(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)

if fd_config.graph_opt_config.use_cudagraph:
if fd_config.graph_opt_config.full_cuda_graph:
print(
"[Warning] Full graph capture with CUDAGraph is not supported in the presence of control flow; "
"`full_cuda_graph` has been automatically set to False."
)

flag = "FLAGS_cuda_graph_blacklist"
paddle.set_flags({flag: ",".join(list(set(paddle.get_flags(flag)[flag].split(",") + ["pd_op.if"])))})
Member

Does the user still need to manually set FLAGS_cuda_graph_blacklist externally when setting full_cuda_graph=false? Can it run with only full_cuda_graph=false set?

Collaborator Author
@DrRyanHuang DrRyanHuang Nov 6, 2025

Now setting full_cuda_graph=false is enough; pd_op.if is included in the blacklist by default.

Member

That probably doesn't hold for other models; the op there likely isn't pd_op.if.

Collaborator Author
@DrRyanHuang DrRyanHuang Nov 6, 2025

This is the MLA backend, which contains an if statement and is currently only used by Deepseek V3. Other attention backends do not necessarily contain an if statement.

Member

So other models currently cannot run out of the box just by setting full_cuda_graph=false, right?

Collaborator Author

They can run. For most models this parameter has no effect; only the append attention backend (ERNIE4.5Turbo) is affected.
A follow-up PR will add this paddle.set/get flags handling to the append attention backend.
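For reference, a self-contained sketch of the blacklist merge performed in the backend __init__ above. It only restates the diff's one-liner in a readable form; paddle.get_flags/paddle.set_flags are used exactly as in the diff, and dropping the empty entry is an extra assumption added here for clarity.

import paddle

# Merge "pd_op.if" into FLAGS_cuda_graph_blacklist without dropping any
# entries the user may have set already (sketch of the diff's one-liner).
flag = "FLAGS_cuda_graph_blacklist"
current = paddle.get_flags(flag)[flag]            # e.g. "" or "pd_op.while"
entries = set(current.split(",")) | {"pd_op.if"}  # keep existing entries, add pd_op.if
entries.discard("")                               # assumption: strip the empty entry from an unset flag
paddle.set_flags({flag: ",".join(sorted(entries))})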


def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
@@ -205,6 +215,8 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.group_size,
self.block_size,
)
forward_meta.needs_prefill = forward_meta.max_len_tensor_cpu[1] > 0
Collaborator

Could these two parameters simply reuse the previous logic?

Collaborator Author
@DrRyanHuang DrRyanHuang Nov 5, 2025

If CUDAGraph is disabled and only SOT dynamic-to-static conversion is enabled, the previous logic can be reused.
With CUDAGraph plus SOT dynamic-to-static conversion, the previous logic cannot be reused, for the following reason:

The previous code was:

if forward_meta.max_len_tensor_cpu[0]:
    ...

This hides two implicit operations:

  • indexing the CPU int tensor max_len_tensor_cpu
  • casting the CPU int scalar max_len_tensor_cpu[0] to a bool scalar

Both are CPU-kernel operations. Since they run before the if op, they become sub-ops of the CUDAGraph op, and as CPU kernels they cannot be captured by CUDAGraph.

As a result, even when max_len_tensor_cpu[0] == 0, execution still enters the prefill branch.

After changing it to:

forward_meta.needs_prefill = forward_meta.max_len_tensor_cpu[1] > 0

if forward_meta.needs_prefill:
    ...

the if op is guaranteed to select the correct branch.

forward_meta.needs_decode = forward_meta.max_len_tensor_cpu[2] > 0

# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
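As an aside, here is a self-contained sketch of the pattern discussed in the review thread above, with SimpleNamespace standing in for ForwardMeta and placeholder functions standing in for the real prefill/decode attention kernels:

import paddle
from types import SimpleNamespace

# Placeholders for the real prefill/decode attention paths.
def run_prefill_branch():
    return paddle.ones([1])

def run_decode_branch():
    return paddle.ones([1])

# Previous pattern: indexing the CPU tensor and casting the int scalar to bool
# are CPU-kernel ops; placed inside the captured region they become sub-ops of
# the CUDAGraph op and cannot be captured, so the wrong branch can be taken.
def old_style(forward_meta):
    if forward_meta.max_len_tensor_cpu[1]:
        return run_prefill_branch()

# New pattern: compute the booleans outside the captured region ...
def prepare_flags(forward_meta):
    forward_meta.needs_prefill = forward_meta.max_len_tensor_cpu[1] > 0
    forward_meta.needs_decode = forward_meta.max_len_tensor_cpu[2] > 0

# ... and only consume them inside it, so each if op picks the right branch.
def captured_body(forward_meta):
    out = paddle.zeros([1])
    if forward_meta.needs_prefill:
        out = out + run_prefill_branch()
    if forward_meta.needs_decode:
        out = out + run_decode_branch()
    return out

# Decode-only step: index 1 (max encoder len) is 0, index 2 (max decoder len) is 64.
meta = SimpleNamespace(max_len_tensor_cpu=paddle.to_tensor([0, 0, 64], dtype="int32").cpu())
prepare_flags(meta)
print(captured_body(meta))  # only the decode placeholder contributes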
101 changes: 66 additions & 35 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -327,35 +327,22 @@ def yarn_get_mscale(scale=1, mscale=1):
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

def forward(
@paddle.jit.marker.capture_control_flow
def mla_attention(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
forward_meta,
needs_prefill,
needs_decode,
compressed_kv,
query,
query_pe,
key_pe,
mask_encoder_batch,
query_nope,
output,
):
""" """

# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
fmha_out = None

# NOTE: (changwenbin) qkv_a_proj horizontal fusion
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
query, compressed_kv, key_pe = qkv_a_out.split(
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)

query = self.q_a_layernorm(query)[0]
query = self.q_b_proj(query)
query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
if needs_prefill:
key_value = self.kv_b_proj(compressed_kv)
key_value.reshape_(
[
@@ -381,15 +368,14 @@ def forward(
k_pe=key_pe,
forward_meta=forward_meta,
)

fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
else:
fmha_out_prefill = paddle.zeros_like(output)

fmha_out = fmha_out_prefill

if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
if needs_decode:
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -418,10 +404,55 @@ def forward(
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
)
if fmha_out is None:
fmha_out = fmha_out_decode
else:
fmha_out = fmha_out + fmha_out_decode
output = paddle.assign(fmha_out_prefill + fmha_out_decode, output)
else:
output = paddle.assign(fmha_out_prefill, output)

return output

def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """

# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
query, compressed_kv, key_pe = qkv_a_out.split(
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)

query = self.q_a_layernorm(query)[0]
query = self.q_b_proj(query)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

query_pe = paddle.assign(query_pe)
key_pe = paddle.assign(key_pe)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

bs = query.shape[0]
fmha_out = paddle.zeros([bs, self.num_attention_heads_tp * self.v_head_dim], dtype=query.dtype)

fmha_out = self.mla_attention(
forward_meta,
forward_meta.needs_prefill,
forward_meta.needs_decode,
compressed_kv,
query,
query_pe,
key_pe,
mask_encoder_batch,
query_nope,
fmha_out,
)

output = self.o_proj(fmha_out)
return output
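To summarize the refactor in this file: forward() now performs only the shared projections and preallocates the output buffer, while the branchy prefill/decode logic lives in mla_attention, which is marked for control-flow capture and writes its result into the buffer via paddle.assign. A condensed illustration follows; the decorator is used as in the diff (it may require the Paddle build this PR targets), and the Linear layers are stand-ins for the real MLA computation.

import paddle

class MLAControlFlowSketch(paddle.nn.Layer):
    """Condensed illustration of the structure above, not the real layer."""

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        self.prefill_proj = paddle.nn.Linear(hidden_size, out_size)  # stand-in for the prefill path
        self.decode_proj = paddle.nn.Linear(hidden_size, out_size)   # stand-in for the decode path

    @paddle.jit.marker.capture_control_flow  # decorator as used in the diff
    def mla_attention(self, needs_prefill, needs_decode, x, output):
        # Flags arrive precomputed (see init_attention_metadata), so the captured
        # if ops branch correctly; results go into the caller-provided buffer
        # via paddle.assign.
        if needs_prefill:
            fmha_out_prefill = self.prefill_proj(x)
        else:
            fmha_out_prefill = paddle.zeros_like(output)
        if needs_decode:
            output = paddle.assign(fmha_out_prefill + self.decode_proj(x), output)
        else:
            output = paddle.assign(fmha_out_prefill, output)
        return output

    def forward(self, x, needs_prefill, needs_decode):
        # Shared work and the output buffer live outside the captured method.
        fmha_out = paddle.zeros([x.shape[0], self.decode_proj.weight.shape[1]], dtype=x.dtype)
        return self.mla_attention(needs_prefill, needs_decode, x, fmha_out)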