Skip to content

Conversation

@stefankoncarevic
Copy link
Contributor

@stefankoncarevic stefankoncarevic commented Dec 23, 2025

Motivation

Enable lds transpose load for attention kernel, extended wave configs, regular + transpose load
Resolves:
https://github.com/ROCm/rocMLIR-internal/issues/2207
https://github.com/ROCm/rocMLIR-internal/issues/2175
https://github.com/ROCm/rocMLIR-internal/issues/2155

Technical Details

1. Adds LDS transpose load optimization support for the attention kernel's GEMM operations.

GEMM0 (K × Q):

  • Added decideLDSTransposeForOperands() call for GEMM0
  • K matrix (operand A): Can use LDS transpose when directToLDS is enabled
  • Q matrix (operand B): Can use LDS transpose only when NOT prefetched to registers. Added bLoadsFromLDS parameter to decideLDSTransposeForOperands() to handle the prefetch case correctly.

GEMM1 (V × P):

  • Added decideLDSTransposeForOperands() call for GEMM1
  • V matrix (operand A): Uses LDS transpose when P is prefetched
    (doBypassLDSSecondGemm = true) and directToLDS is enabled
  • P matrix (operand B): Never uses LDS transpose (comes from registers).

API changes:

  • Extended decideLDSTransposeForOperands() with optional bLoadsFromLDS parameter (default=true). When false, operand B is immediately marked as NOT USABLE for LDS transpose, regardless of other constraints.

BlockwiseMatrixParamsAttr changes:

  • Replaced mnPerXdl with accelDDim parameter. The MFMA instruction geometry (e.g., 16x16, 16x32, 32x16) requires the D dimension (M or N) to correctly configure the LDS transpose load. mnPerXdl in some cases can be different from accelDDim value, that produce the function isValidMfmaGeometry have wrong decision.

2. Extend LDS transpose load support for 8 and 16 wave configurations

Changes:

  • Extended numWaves limit from 4 to 16 in decideLDSTransposeForOperands()
  • Added wave grid layout computation for 8 waves:
    • 2×4, 4×2 (preferred balanced layouts)
    • 1×8, 8×1 (fallback layouts)
  • Added wave grid layout computation for 16 waves:
    • 4×4 (preferred balanced layout)
    • 2×8, 8×2 (semi-balanced layouts)
    • 1×16, 16×1 (fallback layouts)

Hybrid Load Support (Regular + LDS Transpose)

3. Fix K-access formula for hybrid LDS transpose load scenario

When one operand uses regular load and the other uses LDS transpose load, the regular load must use a K-access pattern compatible with the transpose load's data layout.

Changes:

  • Added useLdsTransposeLoad parameter to wrapLDSBufferForLoad() in AccelEmitter.cpp. This flag indicates when the current operand should use a K-access pattern compatible with LDS transpose load on the other operand.

  • New K-access formula for hybrid scenario: When useLdsTransposeLoad && kVec >= kBase:

    • Split blk_id into blk_d (D dimension) and blk_k (K dimension) based on MFMA geometry
    • Split k_vec into k_mfma (which MFMA within kpack) and k_base (element within kBase)
    • K access: k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * kBase + k_base
  • When kVec < kBase (e.g., kpack=1, kBase=4), we use previously pattern load pattern.

  • Parameter propagation: Updated BlockwiseGemmToThreadwise.cpp and BlockwiseLoadTileToThreadwise.cpp to pass the LDS transpose state through the lowering pipeline.

  • This ensures correct data alignment between regular and transpose loads for all supported MFMA geometries (16×16, 32×8, 16×32, 32×16) and various kpack/kpackPerBlock configurations.

Test Plan

Updated tests:

  • lds_transpose_attributes_toblockwise.mlir: Changed CHECK-NOT to
    CHECK for 8 and 16 wave tests, confirming LDS transpose is now
    enabled for these configurations
  • PrLdsTransposeLoad.toml: Added e2e test cases for 8-wave (4×2, 1×8)
    and 16-wave (8×2, 1×16) grid configurations

New tests for hybrid load scenarios with various kpack sizes:

  • lds_transpose_A_only:

    • Added tests for hybrid scenario where A uses LDS transpose load and B uses regular load with various kpack values (4, 16, 32) to verify the new K-access formula:
  • lds_transpose_B_only:

    • Added tests for hybrid scenario where B uses LDS transpose load and A uses regular load with various kpack values (16, 32) to ensure correct data alignment:
  • These tests cover various kpack (4, 8, 16, 32) and kpackPerBlock (2, 4, 8, 16, 32) combinations across different MFMA geometries (16×16, 32×8, 16×32, 32×16) to validate the hybrid K-access formula.

Test Result

Submission Checklist

This commit adds LDS transpose load optimization support for the
attention kernel's GEMM operations.

GEMM0 (K × Q):
- Added decideLDSTransposeForOperands() call for GEMM0
- K matrix (operand A): Can use LDS transpose when directToLDS is enabled
- Q matrix (operand B): Can use LDS transpose only when NOT prefetched
  to registers. Added bLoadsFromLDS parameter to decideLDSTransposeForOperands()
  to handle the prefetch case correctly.

GEMM1 (V × P):
- Added decideLDSTransposeForOperands() call for GEMM1
- V matrix (operand A): Can use LDS transpose when directToLDS is enabled
- P matrix (operand B): Never uses LDS transpose since it comes from
  registers (softmax output), not from global memory. directToLDS is
  always false for P.

API changes:
- Extended decideLDSTransposeForOperands() with optional bLoadsFromLDS
  parameter (default=true). When false, operand B is immediately marked
  as NOT USABLE for LDS transpose, regardless of other constraints.

BlockwiseMatrixParamsAttr changes:
- Replaced mnPerXdl with accelDDim and accelKDim parameters. The MFMA
  instruction geometry (e.g., 16x16, 16x32, 32x16) requires both the
  D dimension (M or N) and K dimension to correctly configure the
  LDS transpose load. mnPerXdl only captured one dimension, which was
  insufficient for determining the correct hardware transpose behavior
  and offset calculations.
@stefankoncarevic stefankoncarevic self-assigned this Dec 23, 2025
@stefankoncarevic stefankoncarevic changed the title Add LDS transpose load support for attention kernel [WIP] Add LDS transpose load support for attention kernel Dec 23, 2025
This commit extends the LDS transpose load optimization to support
workgroups with 8 waves (blockSize=512) and 16 waves (blockSize=1024).

Previously, the optimization was limited to 1-4 waves only. This
restriction has been lifted to enable LDS transpose load for larger
workgroup sizes commonly used in high-performance GEMM configurations.

Changes:
- Extended numWaves limit from 4 to 16 in decideLDSTransposeForOperands()
- Added wave grid layout computation for 8 waves:
  - 2×4, 4×2 (preferred balanced layouts)
  - 1×8, 8×1 (fallback layouts)
- Added wave grid layout computation for 16 waves:
  - 4×4 (preferred balanced layout)
  - 2×8, 8×2 (semi-balanced layouts)
  - 1×16, 16×1 (fallback layouts)

Updated tests:
- lds_transpose_attributes_toblockwise.mlir: Changed CHECK-NOT to
  CHECK for 8 and 16 wave tests, confirming LDS transpose is now
  enabled for these configurations
- PrLdsTransposeLoad.toml: Added e2e test cases for 8-wave (4×2, 1×8)
  and 16-wave (8×2, 1×16) grid configurations
When one operand uses regular load and the other uses LDS transpose
load, the regular load must use a compatible K-access pattern.

The new formula is only applied when:
- useLdsTransposeLoad is true (hybrid scenario)
- kVec >= kBase (enough elements to decompose)

This ensures correct data alignment between regular and transpose
loads for MFMA operations, and prevents assertion failures when
kpack < kBase.

Changes:
- Add useLdsTransposeLoad parameter to wrapLDSBufferForLoad
- Implement hybrid K-access formula with blk_d/blk_k split
- Pass LDS transpose state from BlockwiseGemmToThreadwise
- Update tests in PrLdsTransposeLoad.toml
@stefankoncarevic stefankoncarevic changed the title [WIP] Add LDS transpose load support for attention kernel LDS transpose load: attention, extended wave configs, regular + transpose load Dec 26, 2025
Copy link
Contributor

@pabloantoniom pabloantoniom left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description of the PR is super helpful but I would suggest splitting the "Technical Details" section into multiple smaller sections. For example, instead of:

Extend LDS transpose load support for 8 and 16 wave configurations

Changes:
...

Fix K-access formula for hybrid LDS transpose load scenario

Changes:
...

you could have:

1. Extend LDS transpose load support for 8 and 16 wave configurations

Changes:

2. Fix K-access formula for hybrid LDS transpose load scenario

Changes:
...

const BlockwiseMatrixParamsAttr &matrixParams, int64_t blockSize,
StringRef dName) const {
StringRef dName, bool useLdsTransposeLoad) const {
// Note: WMMA does not support LDS transpose load, so the parameter is unused.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to create a ticket to give support for gfx1250

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we already have ticket for enabling this inside epic label with other tickets related for gfx1250.

// Use LDS transpose compatible K formula only when:
// 1. Other operand uses LDS transpose load (hybrid scenario)
// 2. kVec >= kBase (enough elements per load to decompose)
int64_t kBase = accelEmitterParams.kBase;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we move this together with the other local variable declarations (ie line 442 or 450?)

if (useLdsTransposeLoad && kVec >= kBase) {
// K access pattern must match the transpose load's pattern.
// For double-rate MFMA, properly distribute K across threads
MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have mfmaAttr right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will deleted this one. Thanks


// Determine if Q loads from LDS (for LDS transpose decision)
// Q bypasses LDS only when prefetch is active
bool qLoadsFromLDS = !prefetchQTile;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do bool qLoadsFromLDS = loadTypeQ == GemmLoadTileType::BypassLDS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll make change.


LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Decision for operand B: "
<< (decB.usable ? "USABLE" : "NOT USABLE") << "\n");
// If B doesn't load from LDS (e.g., prefetched Q matrix), it can't use
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Why don't we have for A as well? Does A always loads from LDS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, A is always in LDS. In the current attention implementation:

  • Q matrix (operand B in Gemm0) can bypass LDS when prefetch is enabled
  • K matrix (operand A in Gemm0) always loads from LDS

// ensures compatibility when mixing regular load with transpose load.
bool anyUsable = decA.usable || decB.usable;

if (bothUsable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we have both A and B use transpose load anymore?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I just saw "Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false)" in the E2E test. Would that work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both A and B CAN use transpose load simultaneously, but it depends on the LDS layout requirements:

LDS Layout Requirements:

  • Matrix A requires column-major layout in LDS (K×M) for transpose load to work
  • Matrix B requires row-major layout in LDS (K×N) for transpose load to work

For GEMM:
This corresponds to transA=true, transB=false configuration:

  • When transA=true: A is stored column-major in LDS → transpose load usable for A
  • When transB=false: B is stored row-major in LDS → transpose load usable for B

The E2E test "Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false)" confirms this combination works correctly.

For Attention:

  • Q matrix (Gemm0 operand B): Layout depends on how Q tiles are loaded. When Q loads from LDS, it uses K×N layout (row-major), making transpose load usable.
  • K matrix (Gemm0 operand A): K is stored in LDS with K×M layout (column-major for the transposed access pattern), making transpose load usable.
  • The attention kernel naturally has layouts compatible with transpose load for both operands when the tile dimensions align with MFMA requirements.

The anyUsable check is for early exit optimization - when at least one operand can benefit from transpose load, we proceed with the transformation. When both are usable (as in Suite 1), both will use it.

The E2E tests cover:

  • Suite 1: Both A and B use transpose load
  • Suite 2: Only A uses transpose load
  • Suite 3: Only B uses transpose load

All combinations are supported and tested.

// TODO: support 32 waves for WMMA
int64_t numWaves = (mPerBlock * nPerBlock) / (mPerWave * nPerWave);
if (numWaves > 4) {
if (numWaves > 16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to create a new ticket to make this work for 32 (gfx1250). Is it really that hard to make it in a different ticket or can we do it here? If it's just a couple of lines I mean. I can test it on the emulator if you want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could add the change here (it would be a few lines), but I think it's better to handle this as part of the full gfx1250 enablement. There are likely other adjustments needed for gfx1250 that should be tested together.

I'll create a follow-up ticket to extend LDS transpose load support for gfx1250 with numWaves up to 32. That way we can properly test all the gfx1250-specific changes together.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want some E2E test for attention as well, not only gemm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add E2E tests for attention with LDS transpose load. However, this functionality is already well covered by existing direct-to-LDS attention tests.

Currently, LDS transpose load is always enabled when direct-to-LDS is active, and the direct-to-LDS test suite includes extensive attention test coverage with various:

  • Head sizes and sequence lengths
  • Different MFMA configurations
  • Q/K/V matrix layouts

These existing tests implicitly validate LDS transpose load behavior for attention.

In the future, we may evaluate whether LDS transpose load should remain always-on for direct-to-LDS cases, or be selectively enabled based on performance characteristics. Additional performance testing will guide this decision.

For now, I'll add a few explicit attention tests to the nightly suite to have dedicated coverage, but the core functionality is already being tested through the direct-to-LDS attention tests.

// CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xbf16, #gpu.address_space<workgroup>> -> vector<4xbf16>
%v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space<workgroup>> -> vector<4xbf16>
return %v : vector<4xbf16>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add also a test on other arch, e.g., gfx942, to check that we generate an error message saying that transpose load is not supported for that arch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a negative test case that verifies for gfx942

- Move kBase variable declaration earlier in wrapLDSBufferForLoad
- Remove duplicate MfmaInsnAttr declaration, reuse existing mfmaAttr
- Add negative test for LDS transpose load on unsupported arch (gfx942)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants