We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1819e09 commit f24fddcCopy full SHA for f24fddc
python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
@@ -210,6 +210,9 @@ def make_default_opt_flags_nvidia(
210
block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
211
else:
212
block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
213
+ if block_m == 64 and precision_config.out_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
214
+ # when both fused_activation and the mxfp8 downcast are in the epilogue, block_m=64 causes shared memory overflow
215
+ block_m = 128
216
217
block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
218
# block n
0 commit comments