diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index 427dc5c679b..c88bd3b0601 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -529,6 +529,8 @@ def forward_decode(
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
 
+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
@@ -567,6 +569,7 @@ def forward_decode(
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
         )
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
@@ -586,6 +589,9 @@ def forward_extend(
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
+
+        if self.data_type == torch.float8_e4m3fn:
+            q = q.to(torch.float8_e4m3fn)
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -625,6 +631,7 @@ def forward_extend(
             window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
             sinks=attention_sink,
+            out_dtype=self.q_data_type,  # model_runner.dtype
        )
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
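For readers skimming the hunks: on the FP8 KV-cache path the query is cast to `torch.float8_e4m3fn` before the kernel call, and `out_dtype=self.q_data_type` presumably keeps the attention output in the model's activation dtype rather than FP8. Below is a minimal, standalone sketch of that dtype gating; the helper name `cast_q_for_fp8_kv` and the `kv_cache_dtype` argument are illustrative, not attributes of the backend.

```python
import torch


def cast_q_for_fp8_kv(q: torch.Tensor, kv_cache_dtype: torch.dtype) -> torch.Tensor:
    """Cast the query to FP8 only when the KV cache is stored in FP8 (illustrative helper)."""
    if kv_cache_dtype == torch.float8_e4m3fn:
        # Match the query dtype to the FP8 KV cache so the kernel sees uniform input dtypes.
        q = q.to(torch.float8_e4m3fn)
    return q


# Example: bfloat16 activations with an FP8 KV cache.
q = torch.randn(2, 8, 128, dtype=torch.bfloat16)
q = cast_q_for_fp8_kv(q, torch.float8_e4m3fn)
print(q.dtype)  # torch.float8_e4m3fn; out_dtype then requests the output back in bf16
```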