Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ def forward_decode(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

if self.data_type == torch.float8_e4m3fn:
q = q.to(torch.float8_e4m3fn)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more question. In this case is self.q_data_type different from self.dtype?
It's a little bit confusing... What's your command for launching model with fp8 query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running with --kv-cache-dtype fp8_e4m3. From the code, self.q_data_type will always be model dtype. The original behavior for FP8 kv will use model dtype query and fp8 kv. This PR makes q/k/v be the same type as --kv-cache-dtype set.

self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype

This is also what other attention backend do, e.g. flashmla

if self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)

Copy link
Collaborator

@b8zhong b8zhong Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elvischenv Sorry for the unrelated question, but: FP8 kv with trtllm mha was supported before this PR, right? Because, I run into some error lately like #12372, I'm not sure if you experienced the same. I may be wrong, but I believe it to be supported before, so maybe SGLang has started passing the wrong params recently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@b8zhong on current main if using --kv-cache-dtype fp8_e4m3, it will use BF16-query FP8-kv kernel. Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.

With this PR, --kv-cache-dtype fp8_e4m3 will always use FP8 q, this is good since Flashinfer has FP8-qkv kernel for both prefill and decode kernel.

Copy link
Collaborator

@b8zhong b8zhong Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elvischenv Thanks!! That makes a lot of sense & resolves my issue. Thank you very much.

Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.
I did not know this, it's good information.

q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
# shape conversion:
Expand Down Expand Up @@ -567,6 +569,7 @@ def forward_decode(
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
out_dtype=self.q_data_type, # model_runner.dtype
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
Expand All @@ -586,6 +589,9 @@ def forward_extend(
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

if self.data_type == torch.float8_e4m3fn:
q = q.to(torch.float8_e4m3fn)
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
Expand Down Expand Up @@ -625,6 +631,7 @@ def forward_extend(
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
out_dtype=self.q_data_type, # model_runner.dtype
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
Expand Down
Loading