Skip to content

Conversation

@elvischenv
Copy link
Contributor

@elvischenv elvischenv commented Oct 29, 2025

Motivation

TRTLLM_MHA already supports FP8-qkv BF16-out attention kernel. This will achieve better performance compared to original BF16-q FP8-kv kernel.
#9782

Modifications

This PR convert query to FP8 dtype if kv cache type is also FP8.

Accuracy Tests

lm_eval --model local-completions --tasks gsm8k --model_args model=openai/gpt-oss-120b,base_url=http://127.0.0.1:18000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8682|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6091|±  |0.0190|

main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8394|±  |0.0143|
|     |       |strict-match    |     5|exact_match|↑  |0.6318|±  |0.0188|

Benchmarking and Profiling

python3 -m sglang.bench_serving --model openai/gpt-oss-120b --host 127.0.0.1 --port 18000 --backend sglang-oai --dataset-name random --random-range-ratio 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 512 --num-prompts 2560

PR(5% perf improvement):

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf
Max request concurrency:                 512
Successful requests:                     2560
Benchmark duration (s):                  153.18
Total input tokens:                      2621440
Total input text tokens:                 2621440
Total input vision tokens:               0
Total generated tokens:                  2621440
Total generated tokens (retokenized):    2552374
Request throughput (req/s):              16.71
Input token throughput (tok/s):          17113.94
Output token throughput (tok/s):         17113.94
Total token throughput (tok/s):          34227.89
Concurrency:                             510.27
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   30531.42
Median E2E Latency (ms):                 30537.68
---------------Time to First Token----------------
Mean TTFT (ms):                          2720.10
Median TTFT (ms):                        2655.91
P99 TTFT (ms):                           4974.16
---------------Inter-Token Latency----------------
Mean ITL (ms):                           27.24
Median ITL (ms):                         24.75
P95 ITL (ms):                            26.51
P99 ITL (ms):                            153.28
Max ITL (ms):                            4132.87
==================================================

main:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf
Max request concurrency:                 512
Successful requests:                     2560
Benchmark duration (s):                  161.57
Total input tokens:                      2621440
Total input text tokens:                 2621440
Total input vision tokens:               0
Total generated tokens:                  2621440
Total generated tokens (retokenized):    2539762
Request throughput (req/s):              15.84
Input token throughput (tok/s):          16225.25
Output token throughput (tok/s):         16225.25
Total token throughput (tok/s):          32450.49
Concurrency:                             510.02
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   32188.29
Median E2E Latency (ms):                 32219.19
---------------Time to First Token----------------
Mean TTFT (ms):                          2632.67
Median TTFT (ms):                        2608.20
P99 TTFT (ms):                           4823.79
---------------Inter-Token Latency----------------
Mean ITL (ms):                           28.98
Median ITL (ms):                         26.49
P95 ITL (ms):                            28.89
P99 ITL (ms):                            154.62
Max ITL (ms):                            4009.70
==================================================

Checklist

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@b8zhong b8zhong added the run-ci label Oct 29, 2025
Copy link
Collaborator

@Fridge003 Fridge003 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

)

if self.data_type == torch.float8_e4m3fn:
q = q.to(torch.float8_e4m3fn)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more question. In this case is self.q_data_type different from self.dtype?
It's a little bit confusing... What's your command for launching model with fp8 query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running with --kv-cache-dtype fp8_e4m3. From the code, self.q_data_type will always be model dtype. The original behavior for FP8 kv will use model dtype query and fp8 kv. This PR makes q/k/v be the same type as --kv-cache-dtype set.

self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype

This is also what other attention backend do, e.g. flashmla

if self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)

Copy link
Collaborator

@b8zhong b8zhong Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elvischenv Sorry for the unrelated question, but: FP8 kv with trtllm mha was supported before this PR, right? Because, I run into some error lately like #12372, I'm not sure if you experienced the same. I may be wrong, but I believe it to be supported before, so maybe SGLang has started passing the wrong params recently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@b8zhong on current main if using --kv-cache-dtype fp8_e4m3, it will use BF16-query FP8-kv kernel. Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.

With this PR, --kv-cache-dtype fp8_e4m3 will always use FP8 q, this is good since Flashinfer has FP8-qkv kernel for both prefill and decode kernel.

Copy link
Collaborator

@b8zhong b8zhong Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elvischenv Thanks!! That makes a lot of sense & resolves my issue. Thank you very much.

Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.
I did not know this, it's good information.

@Fridge003 Fridge003 merged commit 069e490 into sgl-project:main Oct 31, 2025
53 of 73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants