- 
                Notifications
    You must be signed in to change notification settings 
- Fork 3.2k
feat: support trtllm_mha FP8 query attention kernel #12307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support trtllm_mha FP8 query attention kernel #12307
Conversation
| Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| ) | ||
|  | ||
| if self.data_type == torch.float8_e4m3fn: | ||
| q = q.to(torch.float8_e4m3fn) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more question. In this case is self.q_data_type different from self.dtype?
It's a little bit confusing... What's your command for launching model with fp8 query?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Running with --kv-cache-dtype fp8_e4m3. From the code, self.q_data_type will always be model dtype. The original behavior for FP8 kv will use model dtype query and fp8 kv. This PR makes q/k/v be the same type as --kv-cache-dtype set.
sglang/python/sglang/srt/layers/attention/trtllm_mha_backend.py
Lines 75 to 76 in 7ed8ba0
| self.data_type = model_runner.kv_cache_dtype | |
| self.q_data_type = model_runner.dtype | 
This is also what other attention backend do, e.g. flashmla
sglang/python/sglang/srt/layers/attention/flashmla_backend.py
Lines 356 to 357 in 7ed8ba0
| if self.data_type == torch.float8_e4m3fn: | |
| reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@elvischenv Sorry for the unrelated question, but: FP8 kv with trtllm mha was supported before this PR, right? Because, I run into some error lately like #12372, I'm not sure if you experienced the same. I may be wrong, but I believe it to be supported before, so maybe SGLang has started passing the wrong params recently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@b8zhong on current main if using --kv-cache-dtype fp8_e4m3, it will use BF16-query FP8-kv kernel. Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.
With this PR, --kv-cache-dtype fp8_e4m3 will always use FP8 q, this is good since Flashinfer has FP8-qkv kernel for both prefill and decode kernel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@elvischenv Thanks!! That makes a lot of sense & resolves my issue. Thank you very much.
Flashinfer only has BF16-q FP8-kv decode kernel BUT does NOT have BF16-q FP8-kv prefill kernel.
I did not know this, it's good information.
Motivation
TRTLLM_MHA already supports FP8-qkv BF16-out attention kernel. This will achieve better performance compared to original BF16-q FP8-kv kernel.
#9782
Modifications
This PR convert query to FP8 dtype if kv cache type is also FP8.
Accuracy Tests
lm_eval --model local-completions --tasks gsm8k --model_args model=openai/gpt-oss-120b,base_url=http://127.0.0.1:18000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5PR:
main:
Benchmarking and Profiling
python3 -m sglang.bench_serving --model openai/gpt-oss-120b --host 127.0.0.1 --port 18000 --backend sglang-oai --dataset-name random --random-range-ratio 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 512 --num-prompts 2560PR(5% perf improvement):
main:
Checklist