[Graph Optimization] Support deepseekV3 SOT Dy2St && CUDAGraph #4785
Diff against base `develop`:

```diff
@@ -326,36 +326,22 @@ def yarn_get_mscale(scale=1, mscale=1):
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0
 
-    def forward(
+    @paddle.jit.marker.capture_control_flow
+    def prefill_or_decode(
         self,
-        forward_meta: ForwardMeta,
-        hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
+        forward_meta,
+        max_enc_len_this_time,
+        max_dec_len_this_time,
+        compressed_kv,
+        query,
+        query_pe,
+        key_pe,
+        mask_encoder_batch,
+        query_nope,
     ):
         """ """
-
-        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
-        fmha_out = None
-
-        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
-        query, compressed_kv, key_pe = qkv_a_out.split(
-            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
-        )
-
-        query = self.q_a_layernorm(query)
-        query = self.q_b_proj(query)
-        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-
-        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-        compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            key_value = self.kv_b_proj(compressed_kv)
+        if max_enc_len_this_time:
+            key_value = self.kv_b_proj(compressed_kv)  # this part
             key_value = key_value.reshape(
                 [
                     -1,
```
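The substantive change in this hunk: the branch conditions are no longer read from `forward_meta.max_len_tensor_cpu` inside the body. The caller evaluates them once and passes them in as plain arguments, and the function is marked with `@paddle.jit.marker.capture_control_flow` so its Tensor-conditioned `if` can be lowered rather than traced. Below is a minimal sketch of that pattern; `prefill_branch` and `run_prefill` are hypothetical stand-ins for `prefill_or_decode` and `max_enc_len_this_time`, not FastDeploy code:

```python
import paddle

def prefill_branch(run_prefill: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor:
    # A single-element Tensor condition: legal in dynamic mode, and exactly the
    # shape of `if` that AST dynamic-to-static can lower to an `if` operator.
    if run_prefill:
        return x * 2.0               # stands in for the prefill attention path
    return paddle.zeros_like(x)      # stands in for "no prefill tokens this step"

x = paddle.ones([4])
print(prefill_branch(paddle.to_tensor(1), x))  # takes the prefill path
print(prefill_branch(paddle.to_tensor(0), x))  # takes the decode-only path
```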
```diff
@@ -380,15 +366,16 @@ def forward(
                 k_pe=key_pe,
                 forward_meta=forward_meta,
             )
+        else:
+            fmha_out_prefill = paddle.zeros_like(query)
 
-            fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
-            fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
-            fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
-
-            fmha_out = fmha_out_prefill
+        # TODO(drryanhuang): rm this redundant reshape when fmha_out_prefill is zero
+        fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
+        fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
+        fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
+        if max_dec_len_this_time:
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
```
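Two adjustments in this hunk make the prefill branch lowerable: the `if` gains an `else` arm that materializes `fmha_out_prefill` as zeros, so both sub-blocks define the variable, and the reshape-and-mask tail moves out of the branch so it runs unconditionally. A toy sketch of why the zero arm plus `mask_encoder_batch` is safe; `prefill_arm` and the shapes are hypothetical, not the real attention kernel:

```python
import paddle

def prefill_arm(run_prefill, query, mask_encoder_batch):
    if run_prefill:
        fmha_out_prefill = query * 3.0                # placeholder for real prefill attention
    else:
        fmha_out_prefill = paddle.zeros_like(query)   # both arms must define the output
    # Hoisted out of the branch: decode-only rows are zeroed by the encoder mask,
    # so adding fmha_out_prefill to the decode output later is a no-op for them.
    return fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

query = paddle.ones([2, 4])
mask = paddle.to_tensor([[1.0], [0.0]])  # row 0 is a prefill token, row 1 decode
print(prefill_arm(paddle.to_tensor(1), query, mask))
```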
```diff
@@ -417,10 +404,53 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            if fmha_out is None:
-                fmha_out = fmha_out_decode
-            else:
-                fmha_out = fmha_out + fmha_out_decode
+            fmha_out = fmha_out_prefill + fmha_out_decode
+        else:
+            fmha_out = fmha_out_prefill
+
+        return fmha_out
+
+    def forward(
+        self,
+        forward_meta: ForwardMeta,
+        hidden_states: paddle.Tensor,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
+    ):
+        """ """
+
+        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
+
+        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
+        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
+        query, compressed_kv, key_pe = qkv_a_out.split(
+            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
+        )
+
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+
+        query_pe = paddle.assign(query_pe)
+        key_pe = paddle.assign(key_pe)
+
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+
+        fmha_out = self.prefill_or_decode(
+            forward_meta,
+            forward_meta.max_len_tensor_cpu[1],  # max_enc_len_this_time
+            forward_meta.max_len_tensor_cpu[2],  # max_dec_len_this_time
+            compressed_kv,
+            query,
+            query_pe,
+            key_pe,
+            mask_encoder_batch,
+            query_nope,
+        )
 
         output = self.o_proj(fmha_out)
         return output
```
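The new `forward` keeps all the projection math, reads the two `max_len_tensor_cpu` entries once, and funnels everything through `prefill_or_decode`. Two details are worth calling out. First, the `fmha_out = None` sentinel and its `is None` check are gone: a Python sentinel cannot live inside a graph `if` operator, so the combination is now pure tensor arithmetic (`fmha_out_prefill + fmha_out_decode`, correct because prefill rows were already masked). Second, `query_pe` and `key_pe` pass through `paddle.assign` before the captured region; the PR does not say why, but `assign` makes a fresh copy, which is the kind of thing one does to hand a captured subgraph tensors with their own storage rather than views. A small sketch of the copy semantics (toy shapes, the motivation above is an assumption):

```python
import paddle

src = paddle.ones([2, 3])
copied = paddle.assign(src)  # paddle.assign returns a copy of its input
src[:] = 0.0                 # mutate the original in place
print(copied.sum().item())   # 6.0 -- the copy kept its own storage
```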
Review comments:

> Are you sure you want to call it this?
> @chang-wenbin, what would be a more suitable name for this?
> Between prefill and decode, you picked or? 🤪
> I suggest just renaming it to mla_attention. But what is the actual difference between this change and the previous code? Won't there still be control flow?
> This applies local AST-based dynamic-to-static conversion to the mla_attention function, so that its `if` control-flow statements can be converted into `if` operators (SOT conversion does not support using a Tensor as an `if` condition). With that, DeepSeek V3 can be converted to a static graph as one whole graph. The attached figure shows the `if` operators in the IR graph: inside the red boxes are the two `if` Ops from mla_attention, and between the two red boxes is the CUDAGraph Op.
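For readers unfamiliar with the lowering described above: an `if` operator in Paddle's IR is the same construct that the public `paddle.static.nn.cond` API builds, a predicate plus two sub-blocks. A rough sketch of what the two branches in `mla_attention` correspond to after lowering; the function name and the branch math are placeholders, not the real kernels:

```python
import paddle

def mla_attention_like(max_enc_len_this_time, x):
    # What the AST pass lowers a Tensor-conditioned `if` into: a cond/if operator
    # whose true and false sub-blocks are the two branch bodies.
    return paddle.static.nn.cond(
        max_enc_len_this_time > 0,
        lambda: x * 2.0,                # prefill sub-block (placeholder math)
        lambda: paddle.zeros_like(x),   # empty-prefill sub-block
    )

print(mla_attention_like(paddle.to_tensor(1), paddle.ones([4])))
```

Because both sub-blocks are part of one operator, the surrounding graph stays whole, which is what lets the CUDAGraph Op sit between the two `if` Ops in the captured IR.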