Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ struct AccelEmitter {
/// Return a wrapped view of the LDS buffer tailored for the accelerator
/// load pattern. This is similar to wrapLDSBufferForStore, but while storing
/// in LDS follows a similar pattern among accelerators, loading from LDS
/// is dependent on the type of accelerator we are targeting
/// is dependent on the type of accelerator we are targeting.
/// When useLdsTransposeLoad is true, a special K access pattern
/// is used that is compatible with LDS transpose load on the other operand.
virtual Value
wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
const BlockwiseMatrixParamsAttr &matrixParams,
int64_t blockSize, StringRef dName) const = 0;
int64_t blockSize, StringRef dName,
bool useLdsTransposeLoad = false) const = 0;

/// This function creates the following subtile views:
/// 1) gridSubTileView :
Expand All @@ -113,11 +116,14 @@ struct AccelEmitter {
/// 3) threadSubTileView :
/// iter --> ... --> [KPerThread, DPerThread]
/// for each operand tile to be used with gemm accelerators.
/// When otherOperandUsesLdsTranspose is true, a special K access pattern
/// is used that is compatible with LDS transpose load on the other operand.
virtual FailureOr<RegsAsMatrixSubTiles> createAccelGemmOperandTransforms(
OpBuilder &b, Location loc, int64_t kIters,
ArrayRef<int64_t> bidGridLengths, int64_t blockSize,
int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim,
bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false) const = 0;
bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false,
bool otherOperandUsesLdsTranspose = false) const = 0;

/// Validate the accelerator structure
virtual LogicalResult validateAcceleratorProperties() { return success(); };
Expand Down Expand Up @@ -187,14 +193,15 @@ struct MfmaEmitter : public AccelEmitter {

Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
const BlockwiseMatrixParamsAttr &matrixParams,
int64_t blockSize, StringRef dName) const override;
int64_t blockSize, StringRef dName,
bool useLdsTransposeLoad = false) const override;

FailureOr<RegsAsMatrixSubTiles> createAccelGemmOperandTransforms(
OpBuilder &b, Location loc, int64_t kIters,
ArrayRef<int64_t> bidGridLengths, int64_t blockSize,
int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim,
bool rotateDWithK,
bool doSplitKAcrossThreadsFirst = false) const override;
bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false,
bool otherOperandUsesLdsTranspose = false) const override;

FailureOr<RegsAsMatrixSubTiles> computeOutputTransforms(
OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize,
Expand Down Expand Up @@ -240,14 +247,15 @@ struct WmmaEmitter : public AccelEmitter {

Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
const BlockwiseMatrixParamsAttr &matrixParams,
int64_t blockSize, StringRef dName) const override;
int64_t blockSize, StringRef dName,
bool useLdsTransposeLoad = false) const override;

FailureOr<RegsAsMatrixSubTiles> createAccelGemmOperandTransforms(
OpBuilder &b, Location loc, int64_t kIters,
ArrayRef<int64_t> bidGridLengths, int64_t blockSize,
int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim,
bool rotateDWithK,
bool doSplitKAcrossThreadsFirst = false) const override;
bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false,
bool otherOperandUsesLdsTranspose = false) const override;

FailureOr<RegsAsMatrixSubTiles> computeOutputTransforms(
OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize,
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,15 @@ def Rock_BlockwiseMatrixParamsAttr : Rock_Attr<"BlockwiseMatrixParams", []> {
- g: gemm parameter G
- d: gemm parameter D (could be M or N)
- inDPerThread: How many elements of D (M or N) each thread is going to load from memory.
- accelDDim: Accelerator instruction D dimension (for LDS transpose support, typically 16 or 32).
- accelKDim: Accelerator instruction K dimension (for LDS transpose support).
}];
let parameters = (ins "Type":$elementType, "Type":$elementTypeLoad,
"bool":$rotateDWithK, "bool":$swapThreadIterSubDims, "bool":$LDSLayoutDxK,
"bool":$directToLDS, "bool":$splitKAcrossThreadsFirst, "int64_t":$g,
"int64_t":$d, "int64_t":$inDPerThread,
DefaultValuedParameter<"bool", "false">:$ldsTransposeEnabled,
DefaultValuedParameter<"int64_t", "0">:$accelDDim,
DefaultValuedParameter<"int64_t", "0">:$accelKDim);

let assemblyFormat = [{
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,17 @@ struct LDSTransposeDecision {

// Decides whether to enable LDS transpose for operands A and B
// based on architecture, MFMA geometry, kpack constraints, and layout config.
// Parameters:
// - bLoadsFromLDS: Whether operand B actually loads from LDS.
// If false (e.g., Q matrix prefetched to registers), B will be disabled
// for LDS transpose regardless of other constraints.
LDSTransposeDecision decideLDSTransposeForOperands(
const rock::accel::AccelEmitter *accelEmitter, StringRef arch,
Type elementTypeA, Type elementTypeB, bool directToLDS,
const LDSLayoutConfigDim &ldsLayoutConfigA,
const LDSLayoutConfigDim &ldsLayoutConfigB, int64_t mPerBlock,
int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, int64_t nPerWave,
int64_t kpack, bool doubleBuffering);
int64_t kpack, bool doubleBuffering, bool bLoadsFromLDS = true);

} // namespace mlir::rock::hwtranspose

Expand Down
24 changes: 18 additions & 6 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,9 @@ struct BlockwiseGemmAccelRewritePattern
return nullptr;

// Get accelerator dimensions from matrix params and tuning params
// accelDDim = mnPerXdl (for MFMA instructions with blocksMfma=1)
// accelDDim = accelDDim from BlockwiseMatrixParamsAttr (for MFMA instructions with blocksMfma=1)
// accelKDim = accelKDim from BlockwiseMatrixParamsAttr
int64_t accelDDim = tuningParams.getMnPerXdl();
int64_t accelDDim = matrixParams.getAccelDDim();
int64_t accelKDim = matrixParams.getAccelKDim();

if (accelDDim <= 0 || accelKDim <= 0)
Expand Down Expand Up @@ -507,24 +507,36 @@ struct BlockwiseGemmAccelRewritePattern
// considered a temporary hack until we have a proper way of "searching"
// through different schedules (either heuristically or automatically)

// Determine if the other operand uses LDS transpose load
// This is needed to select the correct K access pattern for regular loads
bool bUsesLdsTranspose = matrixParamsB.getLdsTransposeEnabled();
bool aUsesLdsTranspose = matrixParamsA.getLdsTransposeEnabled();

Value wrappedLDSBufferForLoadA, wrappedLDSBufferForLoadB;
if (loadAFromLDS) {
// When loading A, check if B uses transpose load
wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, op.getMatrixA(), matrixParamsA, op.getBlockSize(), "m");
b, loc, op.getMatrixA(), matrixParamsA, op.getBlockSize(), "m",
/*useLdsTransposeLoad=*/bUsesLdsTranspose);
}
if (loadBFromLDS) {
// When loading B, check if A uses transpose load
wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, op.getMatrixB(), matrixParamsB, op.getBlockSize(), "n");
b, loc, op.getMatrixB(), matrixParamsB, op.getBlockSize(), "n",
/*useLdsTransposeLoad=*/aUsesLdsTranspose);
}
Value wrappedLDSBufferForScaleA, wrappedLDSBufferForScaleB;
if (isScaledGemm) {
// Scaled GEMM (FP4) doesn't support LDS transpose load yet
if (loadAFromLDS) {
wrappedLDSBufferForScaleA = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, op.getScaleA(), matrixParamsA, op.getBlockSize(), "m");
b, loc, op.getScaleA(), matrixParamsA, op.getBlockSize(), "m",
/*useLdsTransposeLoad=*/false);
}
if (loadBFromLDS) {
wrappedLDSBufferForScaleB = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, op.getScaleB(), matrixParamsB, op.getBlockSize(), "n");
b, loc, op.getScaleB(), matrixParamsB, op.getBlockSize(), "n",
/*useLdsTransposeLoad=*/false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,13 @@ class LoweringBlockwiseLoadTileOp final
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value tid, StringRef dName, Value ldsView, Value regs, int64_t blockSize,
bool forceUnroll, const BlockwiseMatrixParamsAttr &matrixParams,
LDSTransposeConfigAttr transposeAttr = nullptr) const {
LDSTransposeConfigAttr transposeAttr = nullptr,
bool useLdsTransposeLoad = false) const {

// wrapLDSBufferForLoad is reading a single set of Ks into private memory
// A/B[m/n, 0:kBasePerThread]
Value ldsViewForLoad = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, ldsView, matrixParams, blockSize, dName);
b, loc, ldsView, matrixParams, blockSize, dName, useLdsTransposeLoad);

// We enhance the transformation from wrapLDSBufferForLoad using a builder
// that, given a single index, splits it into "m"("n") and "k" and lets
Expand Down Expand Up @@ -207,9 +208,9 @@ class LoweringBlockwiseLoadTileOp final
LDSTransposeConfigAttr transposeAttr = nullptr;
if (ldsTransposeEnabled) {
// Get accelerator dimensions from matrix params and tuning params
// accelDDim = mnPerXdl (for MFMA instructions with blocksMfma=1)
// accelDDim = accelDDim from BlockwiseMatrixParamsAttr (for MFMA instructions with blocksMfma=1)
// accelKDim = accelKDim from BlockwiseMatrixParamsAttr
int64_t accelDDim = tuningParams.getMnPerXdl();
int64_t accelDDim = matrixParams.getAccelDDim();
int64_t accelKDim = matrixParams.getAccelKDim();
assert(accelDDim > 0 && accelKDim > 0 &&
"ldsTranspose=true requires valid accel geometry in params");
Expand Down Expand Up @@ -276,9 +277,14 @@ class LoweringBlockwiseLoadTileOp final

FailureOr<RegsAsMatrixSubTiles> maybeBufferViews;
if (loadType == GemmLoadTileType::BypassLDS) {
// Check if the other operand uses LDS transpose load
bool otherOperandUsesLdsTranspose =
isA ? matrixParamsB.getLdsTransposeEnabled()
: matrixParamsA.getLdsTransposeEnabled();
maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms(
b, loc, kIters, bidGridLengths, blockSize, vecDimInfo.inDPerThread,
dName, isKContiguousDim, false);
dName, isKContiguousDim, false,
/*doSplitKAcrossThreadsFirst=*/false, otherOperandUsesLdsTranspose);
} else {
maybeBufferViews = getLoadRegsAsTileViews(
b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize,
Expand Down Expand Up @@ -338,10 +344,16 @@ class LoweringBlockwiseLoadTileOp final
subview = createSliceOfFirstDim(b, loc, subview, di);
}

// Check if the other operand uses LDS transpose load
bool otherOperandUsesLdsTranspose =
isA ? matrixParamsB.getLdsTransposeEnabled()
: matrixParamsA.getLdsTransposeEnabled();
FailureOr<RegsAsMatrixSubTiles> maybeBufferViews =
accelEmitterPtr->createAccelGemmOperandTransforms(
b, loc, kIters, bidGridLengths, blockSize,
vecDimInfo.inDPerThread, dName, isKContiguousDim, false);
vecDimInfo.inDPerThread, dName, isKContiguousDim, false,
/*doSplitKAcrossThreadsFirst=*/false,
otherOperandUsesLdsTranspose);
if (failed(maybeBufferViews))
return failure();
// InBufferViews provide --> K x D subtile views.
Expand Down Expand Up @@ -452,9 +464,15 @@ class LoweringBlockwiseLoadTileOp final
ldsViewForGemm = viewBufferAs(b, ldsByteBuffer, ldsReadType);
}

// Determine if the other operand uses LDS transpose load
// If we're loading A, check if B uses transpose; if loading B, check
// A
bool useLdsTransposeLoad =
isA ? matrixParamsB.getLdsTransposeEnabled()
: matrixParamsA.getLdsTransposeEnabled();
generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm,
destRegisters, blockSize, forceUnroll, matrixParams,
transposeAttr);
transposeAttr, useLdsTransposeLoad);
if (stageLDSReadNew)
rock::YieldOp::create(b, loc);
}
Expand Down
56 changes: 48 additions & 8 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,10 @@ struct GridwiseAttentionAccelRewritePattern
bool directToLDSQ = loadTypeQ == GemmLoadTileType::DirectToLDSDefault ||
loadTypeQ == GemmLoadTileType::DirectToLDSDoubleBuffer;

// Determine if Q loads from LDS (for LDS transpose decision)
// Q bypasses LDS only when loadTypeQ is BypassLDS
bool qLoadsFromLDS = loadTypeQ != GemmLoadTileType::BypassLDS;

// Note that kPerBlock for Gemm1B is mPerBlock of Gemm0 out
// Note that mPerBlock for Gemm1A is mPerBlock of Gemm0 out
// Note that nPerBlock for Gemm1B is nPerBlock of Gemm0 out
Expand Down Expand Up @@ -2412,34 +2416,71 @@ struct GridwiseAttentionAccelRewritePattern
runEarlyExit(rewriter, loc, start, end, splitKV, gemm0MPerBlock,
op.getPrePadG0M(), isCausal, isKVCache);

// create matrix params (LDS transpose not supported for attention)
// LDS Transpose Decision for GEMM0 (K x Q)
// Pass qLoadsFromLDS to disable LDS transpose for Q when it's prefetched
hwtranspose::LDSTransposeDecision ldsDecisionGemm0 =
hwtranspose::decideLDSTransposeForOperands(
accelEmitterPtrGemm0.get(), arch, elemTypeK, elemTypeQ, directToLDS,
ldsLayoutCfgMG0, ldsLayoutCfgNG0, gemm0MPerBlock, gemm0NPerBlock,
gemm0KPerBlock, gemm0TuningParams.getMPerWave(),
gemm0TuningParams.getNPerWave(), gemm0kpack,
/*doubleBuffering=*/false, /*bLoadsFromLDS=*/qLoadsFromLDS);

// create matrix params
BlockwiseMatrixParamsAttr matrixParamsK = BlockwiseMatrixParamsAttr::get(
rewriter.getContext(), elemTypeK, elemTypeKLoad,
ldsLayoutCfgMG0.doRotateWithK, ldsLayoutCfgMG0.doSwapThreadIterSubDims,
ldsLayoutCfgMG0.ldsLayoutDxK, directToLDS,
/*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0M, gemm0InMPerThread,
/*ldsTransposeEnabled=*/false, /*accelKDim=*/0);
/*ldsTransposeEnabled=*/ldsDecisionGemm0.enableA,
/*accelDDim=*/ldsDecisionGemm0.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm0.mfmaKDim);

BlockwiseMatrixParamsAttr matrixParamsQ = BlockwiseMatrixParamsAttr::get(
rewriter.getContext(), elemTypeQ, elemTypeQLoad,
ldsLayoutCfgNG0.doRotateWithK, ldsLayoutCfgNG0.doSwapThreadIterSubDims,
ldsLayoutCfgNG0.ldsLayoutDxK, directToLDSQ,
/*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0N, gemm0InNPerThread,
/*ldsTransposeEnabled=*/false, /*accelKDim=*/0);
/*ldsTransposeEnabled=*/ldsDecisionGemm0.enableB,
/*accelDDim=*/ldsDecisionGemm0.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm0.mfmaKDim);

// LDS Transpose Decision for GEMM1 (V x P)
// Note: LDS transpose for V is ONLY enabled when P is prefetched
// (doBypassLDSSecondGemm = true).
hwtranspose::LDSTransposeDecision ldsDecisionGemm1 =
hwtranspose::decideLDSTransposeForOperands(
accelEmitterPtrGemm1.get(), arch, elemTypeV, elemTypeV, directToLDS,
ldsLayoutCfgMG1, ldsLayoutCfgMG1, gemm1MPerBlock, gemm1NPerBlock,
gemm1KPerBlock, gemm1TuningParams.getMPerWave(),
gemm1TuningParams.getNPerWave(), gemm1kpack,
/*doubleBuffering=*/false,
/*bLoadsFromLDS=*/!doBypassLDSSecondGemm);

// Enable LDS transpose for V only when P is prefetched
bool enableLdsTransposeForV =
doBypassLDSSecondGemm && ldsDecisionGemm1.enableA;

BlockwiseMatrixParamsAttr matrixParamsV = BlockwiseMatrixParamsAttr::get(
rewriter.getContext(), elemTypeV, elemTypeVLoad,
ldsLayoutCfgMG1.doRotateWithK, ldsLayoutCfgMG1.doSwapThreadIterSubDims,
ldsLayoutCfgMG1.ldsLayoutDxK, directToLDS, doBypassLDSSecondGemm,
gemm0G, gemm1M, gemm1InMPerThread,
/*ldsTransposeEnabled=*/false, /*accelKDim=*/0);
/*ldsTransposeEnabled=*/enableLdsTransposeForV,
/*accelDDim=*/ldsDecisionGemm1.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm1.mfmaKDim);

// P matrix (operand B) - when prefetched, uses LDS transpose compatible
// K formula via otherOperandUsesLdsTranspose in
// createAccelGemmOperandTransforms
BlockwiseMatrixParamsAttr matrixParamsKxQ = BlockwiseMatrixParamsAttr::get(
rewriter.getContext(), elemTypeV, elemTypeVLoad, /*rotateDWithK=*/false,
/*swapThreadIterSubDims=*/false, /*LDSLayoutDxK=*/false,
/*directToLDS=*/false, /*splitKAcrossThreadsFirst=*/false, gemm0G,
gemm1N, gemm1InMPerThread,
/*ldsTransposeEnabled=*/false, /*accelKDim=*/0);
/*ldsTransposeEnabled=*/false,
/*accelDDim=*/ldsDecisionGemm1.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm1.mfmaKDim);

// If gemm0K is equal to gemm0KPerBlock that means
// effectively there is no K loop. Therefore, we
Expand Down Expand Up @@ -3256,9 +3297,6 @@ struct GridwiseGemmAccelRewritePattern
directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock,
nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering);

// Note: LDS transpose geometry (accelKDim) is now stored in
// BlockwiseMatrixParamsAttr, not in tuning params

LLVM_DEBUG(llvm::dbgs()
<< "M: " << M << "\n"
<< "N: " << N << "\n"
Expand Down Expand Up @@ -3306,6 +3344,7 @@ struct GridwiseGemmAccelRewritePattern
ldsLayoutConfigA.doSwapThreadIterSubDims, ldsLayoutConfigA.ldsLayoutDxK,
directToLDS, /*splitKAcrossThreadsFirst=*/false, G, M, copyMPerThread,
/*ldsTranspose=*/ldsDecision.enableA,
/*accelDDim=*/ldsDecision.mfmaDDim,
/*accelKDim=*/ldsDecision.mfmaKDim);

BlockwiseMatrixParamsAttr matrixParamsB = BlockwiseMatrixParamsAttr::get(
Expand All @@ -3314,6 +3353,7 @@ struct GridwiseGemmAccelRewritePattern
ldsLayoutConfigB.doSwapThreadIterSubDims, ldsLayoutConfigB.ldsLayoutDxK,
directToLDS, /*splitKAcrossThreadsFirst=*/false, G, N, copyNPerThread,
/*ldsTranspose=*/ldsDecision.enableB,
/*accelDDim=*/ldsDecision.mfmaDDim,
/*accelKDim=*/ldsDecision.mfmaKDim);

// Allocate LDS.
Expand Down
Loading