
@ppetrovi-amd

Fix Flash Attention precision loss on RDNA3 (gfx11xx)

Problem

Flash Attention on RDNA3 (gfx1100/Navi31) GPUs was failing precision tests, with error ratios exceeding 2.0x relative to the PyTorch reference. The failures showed up in the test_flash_attn_triton_amd.py tests.

Root Cause

The WMMA instructions on RDNA3 accumulate in FP32, but the intermediate GEMM0 output buffer was allocated as FP16 (matching the input element type), and preSoftmaxBody immediately widens it back to FP32 via arith.extf for the softmax. The result is an unnecessary FP32→FP16→FP32 roundtrip that truncates the accumulator and loses precision.
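
As a minimal standalone illustration of the truncation (not rocMLIR code; assumes a compiler with _Float16 support, such as recent Clang), narrowing an FP32 accumulator value to FP16 and widening it back drops the mantissa bits below FP16 precision:

```cpp
#include <cstdio>

int main() {
  // Accumulator value with significant bits below FP16 precision
  // (FP16 keeps a 10-bit mantissa vs. 23 bits for FP32).
  float acc = 104.737f;

  // FP32 -> FP16 store of the GEMM0 output (the old lowering) ...
  _Float16 stored = static_cast<_Float16>(acc);

  // ... followed by the FP16 -> FP32 arith.extf before the softmax.
  float reloaded = static_cast<float>(stored);

  std::printf("acc      = %.6f\n", acc);
  std::printf("reloaded = %.6f\n", reloaded);
  std::printf("lost     = %.6f\n", acc - reloaded);
  return 0;
}
```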

Solution

In GridwiseGemmToBlockwise.cpp, detect when the preSoftmaxBody contains an immediate arith.extf operation on the GEMM0 output (a hedged sketch follows below). When detected:

  1. Allocate gemm0OutBuffer in FP32 (the softmax precision type) instead of FP16
  2. Update linalg.generic block argument types to match the new buffer type
  3. Replace the now-redundant arith.extf ops (FP32→FP32) with their input
  4. Extend the other operands of the affected arithmetic ops to FP32 to keep types consistent

This keeps the WMMA FP32 result in FP32 until after softmax, eliminating the precision loss.
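
As a hedged sketch (illustrative only, not the actual rocMLIR change), steps 1 and 3 could look roughly like the following against the upstream MLIR C++ API. The helper names, the gemm0Out value, and the preSoftmaxBody block are assumptions based on the description above; steps 2 and 4 (retyping the linalg.generic block arguments and widening the remaining operands) are omitted:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Step 1 precondition: the GEMM0 output is immediately widened to f32
// inside the pre-softmax body, so the FP16 store can be elided.
static bool isImmediatelyExtendedToF32(Value gemm0Out) {
  for (Operation *user : gemm0Out.getUsers())
    if (auto ext = dyn_cast<arith::ExtFOp>(user))
      if (ext.getType().isF32())
        return true;
  return false;
}

// Step 3: once the buffer and block-argument types are switched to f32,
// every arith.extf whose input type equals its result type is a no-op
// and is replaced by its input.
static void foldRedundantExtF(Block &preSoftmaxBody) {
  llvm::SmallVector<arith::ExtFOp> redundant;
  preSoftmaxBody.walk([&](arith::ExtFOp ext) {
    if (ext.getIn().getType() == ext.getType())
      redundant.push_back(ext);
  });
  for (arith::ExtFOp ext : redundant) {
    ext.getResult().replaceAllUsesWith(ext.getIn());
    ext->erase();
  }
}
```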

NOTE: This is still just a draft to confirm that it fixes the problem.

Keep GEMM0 output in FP32 when preSoftmaxBody contains arith.extf
to avoid unnecessary FP32->FP16->FP32 roundtrip truncation.

The WMMA accumulator is FP32, and storing to FP16 before the
softmax (which needs FP32) caused significant precision loss.