
Commit f24fddc

use block 128
1 parent 1819e09 commit f24fddc

File tree: 1 file changed, +3 -0 lines changed
  • python/triton_kernels/triton_kernels/matmul_details


python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 3 additions & 0 deletions
@@ -210,6 +210,9 @@ def make_default_opt_flags_nvidia(
             block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
         else:
             block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
+            if block_m == 64 and precision_config.out_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
+                # with both a fused activation and an mxfp8 downcast in the epilogue, block_m=64 causes a shared memory overflow
+                block_m = 128
     else:
         block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
     # block n
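The heuristic in the hunk above can be sketched in isolation. This is a minimal, hypothetical stand-in: `choose_block_m` is not the real function name, `next_power_of_2` replaces `triton.next_power_of_2`, the boolean `fp4_with_out_scale` collapses the `precision_config.out_scale is not None and rhs_dtype == FP4` check, and `device_major` replaces `torch.cuda.get_device_capability()[0]`, so the sketch runs without triton or torch installed.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()

def choose_block_m(slice_size: int, fp4_with_out_scale: bool, device_major: int) -> int:
    """Sketch of the block_m selection branch touched by this commit."""
    # base heuristic: scale with slice_size, clamp to [16, 64]
    block_m = max(16, min(next_power_of_2(2 * slice_size), 64))
    # commit f24fddc: with both a fused activation and an mxfp8 downcast in the
    # epilogue, block_m=64 overflows shared memory on compute capability >= 10,
    # so bump the tile to 128 in that configuration
    if block_m == 64 and fp4_with_out_scale and device_major >= 10:
        block_m = 128
    return block_m
```

For example, a slice size of 32 clamps to `block_m = 64`, which the new guard promotes to 128 only when the FP4-with-out-scale path runs on a capability-10+ (Blackwell-class) device; older devices or other dtypes keep 64.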
