Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6d49d88
Add quick-tune support for Attention.
mirza-halilcevic Dec 11, 2025
651f797
Remove usage of InitParamsAttn and simplify approach.
mirza-halilcevic Dec 12, 2025
4bbb0f8
Merge branch 'develop' into attn-quick-tune
dorde-antic Dec 12, 2025
d89f1df
Remove InitParamsAttn and store serialized perf configs in inc file.
mirza-halilcevic Dec 13, 2025
aa15aee
Merge remote-tracking branch 'origin/develop' into attn-quick-tune
mirza-halilcevic Dec 14, 2025
f3df95a
Merge remote-tracking branch 'origin/attn-quick-tune' into attn-quick…
mirza-halilcevic Dec 14, 2025
fd3bd77
Merge remote-tracking branch 'origin/develop' into attn-quick-tune
mirza-halilcevic Dec 15, 2025
89278d9
Store quick-tune perf configs in serialized format.
mirza-halilcevic Dec 15, 2025
d54d54b
Refactor tuning parameters:
mirza-halilcevic Dec 15, 2025
17eba09
Merge remote-tracking branch 'origin/develop' into attn-quick-tune
mirza-halilcevic Dec 15, 2025
bc2da92
Fix tests.
mirza-halilcevic Dec 16, 2025
fa861d9
Fix formatting.
mirza-halilcevic Dec 16, 2025
e189c11
Merge branch 'develop' into attn-quick-tune
mirza-halilcevic Dec 16, 2025
f2cbc5e
Merge remote-tracking branch 'origin/attn-quick-tune' into attn-quick…
mirza-halilcevic Dec 16, 2025
4fdb68f
Merge remote-tracking branch 'origin/develop' into attn-quick-tune
mirza-halilcevic Dec 16, 2025
690b333
Merge branch 'develop' into attn-quick-tune
dhernandez0 Dec 17, 2025
6cdbb14
Merge remote-tracking branch 'origin/develop' into attn-quick-tune
mirza-halilcevic Dec 27, 2025
ecddeba
Rename Attn to GemmGemm for generic constructs.
mirza-halilcevic Dec 27, 2025
c6a00d6
Improve naming.
mirza-halilcevic Dec 27, 2025
b986135
Update and refactor quickTuningGen.py.
mirza-halilcevic Dec 28, 2025
1cc3688
Merge branch 'develop' into attn-quick-tune
mirza-halilcevic Dec 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 105 additions & 76 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam
- scheduleVersion: Param to select GEMM schedule
- outputSwizzle: Whether to enable/disable output swizzle or use heuristics
}];

let parameters = (ins
"uint32_t":$blockSize,
"int64_t":$kPerBlock,
Expand All @@ -251,7 +252,8 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam

let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
("v3:" + Twine(getBlockSize()) + ","
("v3:"
+ Twine(getBlockSize()) + ","
+ Twine(getMPerBlock()) + ","
+ Twine(getNPerBlock()) + ","
+ Twine(getKPerBlock()) + ","
Expand All @@ -263,74 +265,58 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam
.toVector(perfStr);
}
bool getForceUnroll() { return true; }
}];

let assemblyFormat = [{
`<` struct(params) `>`
}];
}

def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterface]> {
let mnemonic = "attn_perf_config";
let description = [{
The perf configs for rock.attention operator.
}];

let parameters = (ins "int64_t":$mPerBlockG0, "int64_t":$mPerBlockG1,
"int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack,
"int64_t":$splitKFactor, "int64_t":$scheduleVersion,
"int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
("attn:v3:"
+ Twine(getMPerBlockG0()) + ","
+ Twine(getMPerBlockG1()) + ","
+ Twine(getNPerBlockG0()) + ","
+ Twine(getKpackPerBlock()) + ","
+ Twine(getMPerWave()) + ","
+ Twine(getNPerWave()) + ","
+ Twine(getMnPerXdl()) + ","
+ Twine(getKpack()) + ","
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getForceUnroll())).toVector(perfStr);
GeneralGemmParamsAttr withScheduleVersion(int64_t newScheduleVersion) const {
return GeneralGemmParamsAttr::get(
getContext(), getBlockSize(), getKPerBlock(), getMPerBlock(),
getNPerBlock(), getKPerThread(), getMPerThread(), getNPerThread(),
getKpack(), getSplitKFactor(), newScheduleVersion, getOutputSwizzle());
}
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
return AttnPerfConfigAttr::get(
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(), getKpack(),
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getWavesPerEU(), getForceUnroll());
}
}];

let builders = [AttrBuilderWithInferredContext<(
ins "StringAttr":$perfConfigStr, "bool":$isWmma)>];
let builders = [AttrBuilderWithInferredContext<(ins
"StringAttr":$perfConfigStr)>];

let assemblyFormat = [{
`<` struct(params) `>`
}];
}

def Rock_MfmaGemmParamsAttr
: Rock_Attr<"MfmaGemmParams", [RockTuningParamAttrInterface,
RockAccelTuningParamAttrInterface]> {
let mnemonic = "mfma_gemm_params";
def Rock_AccelGemmParamsAttr
: Rock_Attr<"AccelGemmParams", [RockTuningParamAttrInterface,
RockAccelTuningParamAttrInterface]> {
let mnemonic = "accel_gemm_params";
let description = [{
The tuning parameters for an mfma-based matrix multiplication.
The tuning parameters for an mfma or wmma-based matrix multiplication.

- kpackPerBlock: The number of kpack units to process during each main loop
iteration within a workgroup
- mPerBlock: The number of values of m to process in each workgroup
- nPerBlock: The number of values of n to process in each workgroup
- kpack: The number of values of k to pack contiguously into the shared
buffer
- mPerWave: The number of values of m to process in each wavefront
- nPerWave: The number of values of n to process in each wavefront
- mnPerXdl: The size of the m/n dimension in the accelerator (XDL)
instruction
- splitKFactor: Split-k factor for the Split-k GEMM algorithm
- scheduleVersion: Param to select GEMM schedule
- outputSwizzle: Whether to enable/disable output swizzle or use heuristics
- wavesPerEU: Hint to backend compiler for wavefronts per execution unit
- gridGroupSize: Number of chiplets to group together to perform a spatially
local tile
- forceUnroll: Whether to force loop unrolling
}];

let parameters = (ins "int64_t":$kpackPerBlock, "int64_t":$mPerBlock,
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
("v4:" + Twine(getMPerBlock()) + ","
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
("v4:"
+ Twine(getMPerBlock()) + ","
+ Twine(getNPerBlock()) + ","
+ Twine(getKpackPerBlock()) + ","
+ Twine(getMPerWave()) + ","
Expand All @@ -343,47 +329,90 @@ def Rock_MfmaGemmParamsAttr
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
}
AccelGemmParamsAttr withScheduleVersion(int64_t newScheduleVersion) const {
return AccelGemmParamsAttr::get(
getContext(), getKpackPerBlock(), getMPerBlock(), getNPerBlock(),
getKpack(), getMPerWave(), getNPerWave(), getMnPerXdl(),
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(),
getWavesPerEU(), getGridGroupSize(), getForceUnroll());
}
}];

let builders = [AttrBuilderWithInferredContext<(
ins "StringAttr":$perfConfigStr, "bool":$isWmma)>];

let assemblyFormat = [{
`<` struct(params) `>`
}];
}

def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrInterface, RockAccelTuningParamAttrInterface]> {
let mnemonic = "wmma_gemm_params";
def Rock_GemmGemmParamsAttr
: Rock_Attr<"GemmGemmParams", [RockTuningParamAttrInterface]> {
let mnemonic = "gemm_gemm_params";
let description = [{
The tuning parameters for an wmma-based matrix multiplication.
The tuning parameters for an mfma or wmma-based gemm+gemm operation
(e.g., attention).

- mPerBlockG0: The number of values of m to process in each workgroup for
the first GEMM
- mPerBlockG1: The number of values of m to process in each workgroup for
the second GEMM
- nPerBlockG0: The number of values of n to process in each workgroup for
the first GEMM
- kpackPerBlock: The number of kpack units to process during each main loop
iteration within a workgroup
- mPerWave: The number of values of m to process in each wavefront
- nPerWave: The number of values of n to process in each wavefront
- mnPerXdl: The size of the m/n dimension in the accelerator (XDL)
instruction
- kpack: The number of values of k to pack contiguously into the shared
buffer
- splitKFactor: Split-k factor for the Split-k GEMM algorithm
- scheduleVersion: Param to select GEMM schedule
- outputSwizzle: Whether to enable/disable output swizzle or use heuristics
- wavesPerEU: Hint to backend compiler for wavefronts per execution unit
- forceUnroll: Whether to force loop unrolling
}];
let parameters = (ins "int64_t":$kpackPerBlock, "int64_t":$mPerBlock,
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let parameters = (ins "int64_t":$mPerBlockG0, "int64_t":$mPerBlockG1,
"int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack,
"int64_t":$splitKFactor, "int64_t":$scheduleVersion,
"int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
("v4:" + Twine(getMPerBlock()) + ","
+ Twine(getNPerBlock()) + ","
+ Twine(getKpackPerBlock()) + ","
+ Twine(getMPerWave()) + ","
+ Twine(getNPerWave()) + ","
+ Twine(getMnPerXdl()) + ","
+ Twine(getKpack()) + ","
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
("attn:v3:"
+ Twine(getMPerBlockG0()) + ","
+ Twine(getMPerBlockG1()) + ","
+ Twine(getNPerBlockG0()) + ","
+ Twine(getKpackPerBlock()) + ","
+ Twine(getMPerWave()) + ","
+ Twine(getNPerWave()) + ","
+ Twine(getMnPerXdl()) + ","
+ Twine(getKpack()) + ","
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getForceUnroll()))
.toVector(perfStr);
}
GemmGemmParamsAttr withScheduleVersion(int64_t newScheduleVersion) const {
return GemmGemmParamsAttr::get(
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(),
getKpack(), getSplitKFactor(), newScheduleVersion, getOutputSwizzle(),
getWavesPerEU(), getForceUnroll());
}
}];

let builders = [AttrBuilderWithInferredContext<(
ins "StringAttr":$perfConfigStr, "bool":$isWmma)>];

let assemblyFormat = [{
`<` struct(params) `>`
}];
Expand Down
62 changes: 62 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmGemmParams.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- GridwiseGemmGemmParams.h - MLIR tuning parameter generation --------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the MLIR tuning parameter generation interface for
// gemm+gemm (e.g. attention) ops.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ROCK_GRIDWISE_GEMM_GEMM_PARAMS_H
#define MLIR_DIALECT_ROCK_GRIDWISE_GEMM_GEMM_PARAMS_H

#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
#include "mlir/Dialect/Rock/Tuning/ParamLookupTable.h"

namespace mlir {
namespace rock {

/// Static helpers that produce and inspect GemmGemmParamsAttr tuning
/// parameters for ops implementing RockGemmGemmWrapperInterface.
///
/// Every member is static; the class only groups the helpers and controls
/// their visibility. The member definitions are not in this header.
class PopulateParamsGemmGemm {
public:
/// Returns a list of GemmGemmParamsAttr for `op`, built with `b`.
/// NOTE(review): presumably the candidate tuning configurations to try for
/// `op` -- confirm against the definition.
static std::vector<GemmGemmParamsAttr>
getTuningParameters(OpBuilder &b, RockGemmGemmWrapperInterface op);

/// Reports, as a LogicalResult, whether `params` is acceptable for `op`.
/// NOTE(review): the name suggests a heuristic check rather than a
/// guarantee -- confirm what success() actually promises.
static LogicalResult paramsProbablyValid(OpBuilder &b,
RockGemmGemmWrapperInterface op,
GemmGemmParamsAttr params);

/// Derives a pair of AccelGemmParamsAttr from `params` for `op`, or a
/// failure.
/// NOTE(review): presumably {first gemm, second gemm}, matching
/// getGemm0Params / getGemm1Params below -- confirm the pair order.
static FailureOr<std::pair<AccelGemmParamsAttr, AccelGemmParamsAttr>>
getAccelGemmParams(OpBuilder &b, RockGemmGemmWrapperInterface op,
GemmGemmParamsAttr params);

protected:
/// Builds a single GemmGemmParamsAttr from the textual `config`.
/// NOTE(review): expected string format and the behavior on malformed input
/// are not visible here -- check the definition.
static GemmGemmParamsAttr
deserializePerfConfig(OpBuilder &b, RockGemmGemmWrapperInterface op,
StringRef config);

/// List form of deserializePerfConfig: takes several textual `configs` and
/// returns a vector of GemmGemmParamsAttr.
/// NOTE(review): whether entries can be dropped or reordered is not visible
/// here.
static std::vector<GemmGemmParamsAttr>
deserializePerfConfigs(OpBuilder &b, RockGemmGemmWrapperInterface op,
ArrayRef<StringRef> configs);

/// Returns the AccelGemmParamsAttr derived from `params`.
/// NOTE(review): presumably for the first gemm of the gemm+gemm pair.
static AccelGemmParamsAttr getGemm0Params(OpBuilder &b,
GemmGemmParamsAttr params);

/// Returns the AccelGemmParamsAttr derived from `params`.
/// NOTE(review): presumably for the second gemm of the gemm+gemm pair.
static AccelGemmParamsAttr getGemm1Params(OpBuilder &b,
GemmGemmParamsAttr params);

private:
// Pull in the GemmGemm section of the shared .inc file as private members.
// GemmGemm_DECLARATIONS_GEN selects that section and is undefined right
// after the include so it does not leak into other includers.
// NOTE(review): the .inc contents are not visible here; presumably the
// quick-tuning perf-config tables.
#define GemmGemm_DECLARATIONS_GEN
#include "mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc"
#undef GemmGemm_DECLARATIONS_GEN

// Lets ParamLookupTable<GemmGemmParamsAttr> reach the protected and private
// members declared above.
friend class ParamLookupTable<GemmGemmParamsAttr>;
};

} // namespace rock
} // namespace mlir

#endif // MLIR_DIALECT_ROCK_GRIDWISE_GEMM_GEMM_PARAMS_H
Loading