diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index e38cf3ad381..33961c70a98 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -143,6 +143,10 @@ class ForwardMeta:
     # Flag of profile run
     is_dummy_or_profile_run: bool = False
 
+    # Prefill and decode flag
+    needs_prefill: Optional[paddle.Tensor] = None
+    needs_decode: Optional[paddle.Tensor] = None
+
     def clear_caches(self):
         """Safely clean up the caches"""
         if self.caches:
diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
index 98057c9e212..cdaeacb6873 100644
--- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
@@ -126,7 +126,7 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
         ).__get__(self.runnable.__self__)
 
         self.cudagraph_switch_threshold = (
-            1024 if self.fd_config.graph_opt_config.graph_opt_level > 0 else self.max_captre_size
+            512 if self.fd_config.graph_opt_config.graph_opt_level > 0 else self.max_captre_size
         )
 
     def __call__(self, **kwargs):
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index cda5684e604..360d198520e 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -166,6 +166,16 @@ def __init__(
                 "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
             )
 
+        if fd_config.graph_opt_config.use_cudagraph:
+            if fd_config.graph_opt_config.full_cuda_graph:
+                print(
+                    "[Warning] Full graph capture with CUDAGraph is not supported in the presence of control flow; "
+                    "`full_cuda_graph` has been automatically set to False."
+                )
+
+            flag = "FLAGS_cuda_graph_blacklist"
+            paddle.set_flags({flag: ",".join(list(set(paddle.get_flags(flag)[flag].split(",") + ["pd_op.if"])))})
+
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
@@ -205,6 +215,8 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.group_size,
             self.block_size,
         )
+        forward_meta.needs_prefill = forward_meta.max_len_tensor_cpu[1] > 0
+        forward_meta.needs_decode = forward_meta.max_len_tensor_cpu[2] > 0
 
         # MLA
         metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 04fa0abd09b..3aadaaa040a 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -327,35 +327,22 @@ def yarn_get_mscale(scale=1, mscale=1):
             return 1.0
         return 0.1 * mscale * math.log(scale) + 1.0
 
-    def forward(
+    @paddle.jit.marker.capture_control_flow
+    def mla_attention(
         self,
-        forward_meta: ForwardMeta,
-        hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
+        forward_meta,
+        needs_prefill,
+        needs_decode,
+        compressed_kv,
+        query,
+        query_pe,
+        key_pe,
+        mask_encoder_batch,
+        query_nope,
+        output,
     ):
-        """ """
-
-        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
-        fmha_out = None
-
-        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
-        query, compressed_kv, key_pe = qkv_a_out.split(
-            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
-        )
-        query = self.q_a_layernorm(query)[0]
-        query = self.q_b_proj(query)
-        query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
-        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-
-        key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-
-        compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
+        if needs_prefill:
             key_value = self.kv_b_proj(compressed_kv)
             key_value.reshape_(
                 [
@@ -381,15 +368,14 @@ def forward(
                 k_pe=key_pe,
                 forward_meta=forward_meta,
             )
-
-            fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
+            fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
+        else:
+            fmha_out_prefill = paddle.zeros_like(output)
 
-            fmha_out = fmha_out_prefill
-
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
+        if needs_decode:
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -418,10 +404,55 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            if fmha_out is None:
-                fmha_out = fmha_out_decode
-            else:
-                fmha_out = fmha_out + fmha_out_decode
+            output = paddle.assign(fmha_out_prefill + fmha_out_decode, output)
+        else:
+            output = paddle.assign(fmha_out_prefill, output)
+
+        return output
+
+    def forward(
+        self,
+        forward_meta: ForwardMeta,
+        hidden_states: paddle.Tensor,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
+    ):
+        """ """
+
+        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
+        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
+        query, compressed_kv, key_pe = qkv_a_out.split(
+            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
+        )
+
+        query = self.q_a_layernorm(query)[0]
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
+
+        query_pe = paddle.assign(query_pe)
+        key_pe = paddle.assign(key_pe)
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+
+        bs = query.shape[0]
+        fmha_out = paddle.zeros([bs, self.num_attention_heads_tp * self.v_head_dim], dtype=query.dtype)
+
+        fmha_out = self.mla_attention(
+            forward_meta,
+            forward_meta.needs_prefill,
+            forward_meta.needs_decode,
+            compressed_kv,
+            query,
+            query_pe,
+            key_pe,
+            mask_encoder_batch,
+            query_nope,
+            fmha_out,
+        )
 
         output = self.o_proj(fmha_out)
         return output
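
For reference, a minimal sketch of the `FLAGS_cuda_graph_blacklist` update performed one-liner style in `mla_attention_backend.py` above, written out as a standalone helper. The function name `add_to_cuda_graph_blacklist` is illustrative and not part of the patch; only `paddle.get_flags`/`paddle.set_flags` and the flag name come from the diff. The set union preserves any ops that were already blacklisted before appending `pd_op.if`.

```python
import paddle

def add_to_cuda_graph_blacklist(op_name: str) -> None:
    """Merge op_name (e.g. "pd_op.if") into FLAGS_cuda_graph_blacklist, keeping existing entries."""
    flag = "FLAGS_cuda_graph_blacklist"
    current = paddle.get_flags(flag)[flag].split(",")  # existing comma-separated blacklist
    merged = sorted(set(current + [op_name]))          # de-duplicate while adding the new op
    paddle.set_flags({flag: ",".join(merged)})

# Usage (as in the patch): keep conditional regions out of CUDA graph capture.
# add_to_cuda_graph_blacklist("pd_op.if")
```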