[BugFix] Fix mask_id bug in block_sparse_sage2_attn_cuda for sm90 #54
On the sm90 architecture, `sparge_mask_convert()` repeats the mask along the `query_idx` dimension. This fix corrects the `mask_id` indexing in `block_sparse_sage2_attn_cuda()` accordingly.
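The indexing issue can be illustrated with a minimal, hypothetical sketch. The shapes, names, and the `repeat` factor below are illustrative only, not the actual kernel code; the point is that once the converted mask is tiled along `query_idx`, each head owns `repeat` consecutive rows, and an index computed against the un-tiled layout can run past the end of the mask (which surfaces on the GPU as an illegal memory access):

```python
num_heads = 4
repeat = 3                    # query_idx repetition factor on sm90 (assumed)
kv_blocks = 8
rows = num_heads * repeat     # rows of the converted mask (assumed layout)
mask = [[1] * kv_blocks for _ in range(rows)]

def mask_id_fixed(head, q_rep):
    # each head owns `repeat` consecutive rows of the converted mask
    return head * repeat + q_rep

def mask_id_buggy(head, q_rep):
    # BUG: computes the offset as if the mask were not repeated,
    # overshooting the converted mask for larger head indices
    return head * kv_blocks + q_rep

# the buggy index runs past the mask for the last head:
print(mask_id_buggy(num_heads - 1, 0) >= rows)   # True -> out-of-bounds read
print(mask_id_fixed(num_heads - 1, repeat - 1))  # 11, the last valid row
```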
Before the fix, the query's `mask_id` in `block_sparse_sage2_attn_cuda()` was wrong, and the following error occurred:

```
  File "sp-radial-attention/radial_attn/attn_mask.py", line 363, in RadialAttention
    return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "sp-radial-attention/radial_attn/attn_mask.py", line 276, in SpargeSageAttnBackend
    k=key[:pre_defined_mask[0].sum(), :, :],
      ~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

After the fix, HunyuanVideo + Radial + SageAttention works successfully.
hunyuan_radial_sage_sp4.mp4
But Wan2.1 (no text tokens in attention) produces higher-quality results:
wan_radial_sp4_sage.mp4
Limitations