Skip to content

Commit abeeefe

Browse files
committed
WIP
1 parent 3f1b7d9 commit abeeefe

File tree

3 files changed

+121
-5
lines changed

3 files changed

+121
-5
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def Exp2Op : UnaryWaveOp<"exp2"> {
110110

111111
def MmaOp : WaveOp<"mma",
112112
[DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
113+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
113114
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface,
114115
["initializeIndexExprsForward", "initializeIndexExprsBackward"]>]>,
115116
WaveArithmeticOpDoc {
@@ -140,6 +141,7 @@ def IterateOp : Op<WaveDialect, "iterate", [
140141
AttrSizedOperandSegments,
141142
DeclareOpInterfaceMethods<RegionBranchOpInterface,
142143
["areTypesCompatible", "getEntrySuccessorOperands"]>,
144+
NoOpElementsPerThreadOpTrait,
143145
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
144146
let summary = "Executes the body repeatedly";
145147
let description = [{
@@ -249,7 +251,7 @@ def AllocateOp : WaveOp<"allocate"> {
249251
let hasVerifier = 1;
250252
}
251253

252-
def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait, CompatibleOperandsAndResultsOpTrait]> {
254+
def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait, CompatibleOperandsAndResultsOpTrait, NoOpElementsPerThreadOpTrait]> {
253255
let summary = "Extracts a subvector from an n-D tensor";
254256
let description = [{
255257
Extracts an n-D subvector from an n-D tensor using k-D offset, size, and
@@ -274,7 +276,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity
274276

275277
def ReadOp : WaveOp<"read", [
276278
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
277-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
279+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
278280
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
279281
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
280282
let summary = "Reads from memory";
@@ -328,7 +330,7 @@ def RegisterOp : WaveOp<"register", [
328330

329331
def WriteOp : WaveOp<"write", [
330332
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
331-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
333+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
332334
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
333335
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
334336
let summary = "Writes into memory";

water/lib/Dialect/Wave/IR/WaveOps.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,57 @@ llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateBackward(
292292
"result", "accumulator", errs);
293293
}
294294

295+
// Get the MMA result elements per thread for a given MMA kind
296+
static unsigned getMmaResultElementsPerThreadForMmaKind(mlir::MLIRContext *context, wave::WaveMmaKind kind) {
297+
// Get the MMA specification (M, N, K dimensions and element types)
298+
wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(context, kind);
299+
300+
// Elements per thread = (M × N) / threads_per_wave
301+
// AMD GPU waves have 64 threads
302+
constexpr unsigned threadsPerWave = 64;
303+
unsigned totalElements = spec.m * spec.n;
304+
return totalElements / threadsPerWave;
305+
}
306+
307+
llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateElementsPerThreadForward(
308+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
309+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
310+
llvm::raw_ostream &errs) {
311+
// For MMA, the result elements per thread is determined by the MMA kind, not the operands
312+
unsigned expectedElementsPerThread = getMmaResultElementsPerThreadForMmaKind(getContext(), getKind());
313+
wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread);
314+
315+
// Propagate to result
316+
auto joined = wave::ElementsPerThreadLatticeValue::join(expectedResult, resultElements[0]);
317+
if (joined.isTop() && !expectedResult.isTop() && !resultElements[0].isTop()) {
318+
errs << "mismatched elements per thread for MMA result: expected " << expectedElementsPerThread
319+
<< " elements per thread for MMA kind " << getKind() << " but got (";
320+
resultElements[0].print(errs);
321+
errs << ")";
322+
return mlir::failure();
323+
}
324+
325+
if (joined != resultElements[0]) {
326+
resultElements[0] = joined;
327+
return mlir::ChangeResult::Change;
328+
}
329+
330+
return mlir::ChangeResult::NoChange;
331+
}
332+
333+
llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateElementsPerThreadBackward(
334+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
335+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
336+
llvm::raw_ostream &errs) {
337+
// For MMA, operands and result elements per thread may be different
338+
// The result is determined by the MMA kind, operands can have their own values
339+
// We don't propagate backwards for MMA since operands and result have independent constraints
340+
(void)operandElements; // Avoid unused parameter warning
341+
(void)resultElements;
342+
(void)errs;
343+
return mlir::ChangeResult::NoChange;
344+
}
345+
295346
// Set the value of `lattice` to `newLattice` and return whether a change
296347
// happened. Note that this does NOT verify whether the lattice change goes into
297348
// the direction of top or bottom.
@@ -1331,6 +1382,41 @@ LogicalResult ReadOp::verify() {
13311382
bounds.getMapping());
13321383
}
13331384

1385+
llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadForward(
1386+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1387+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1388+
llvm::raw_ostream &errs) {
1389+
// ReadOp: use AttrBasedElementsPerThreadOpTrait logic for register result
1390+
// but ignore memory operand (operands[0])
1391+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1392+
if (!elementsPerThread)
1393+
return mlir::ChangeResult::NoChange;
1394+
1395+
// Only propagate to results (register), not from memory operand
1396+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1397+
wave::ElementsPerThreadLatticeValue(*elementsPerThread),
1398+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty immutable (ignore memory operand)
1399+
resultElements, "elements_per_thread attribute", "", "result", errs);
1400+
}
1401+
1402+
llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadBackward(
1403+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1404+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1405+
llvm::raw_ostream &errs) {
1406+
// ReadOp: use AttrBasedElementsPerThreadOpTrait logic for register result
1407+
// but ignore memory operand (operandElements[0])
1408+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1409+
if (!elementsPerThread)
1410+
return mlir::ChangeResult::NoChange;
1411+
1412+
// Only check consistency with results (register), not memory operand
1413+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1414+
wave::ElementsPerThreadLatticeValue(*elementsPerThread),
1415+
resultElements,
1416+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty mutable (ignore memory operand)
1417+
"elements_per_thread attribute", "result", "", errs);
1418+
}
1419+
13341420
//-----------------------------------------------------------------------------
13351421
// RegisterOp
13361422
//-----------------------------------------------------------------------------
@@ -1456,6 +1542,34 @@ llvm::LogicalResult wave::WriteOp::setIndexFromLattices(
14561542
return llvm::success();
14571543
}
14581544

1545+
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadForward(
1546+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1547+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1548+
llvm::raw_ostream &errs) {
1549+
// WriteOp has no results, so forward propagation is NoChange
1550+
return mlir::ChangeResult::NoChange;
1551+
}
1552+
1553+
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadBackward(
1554+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1555+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1556+
llvm::raw_ostream &errs) {
1557+
// WriteOp: use AttrBasedElementsPerThreadOpTrait logic for register operand
1558+
// but ignore memory operand (operandElements[1])
1559+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1560+
if (!elementsPerThread)
1561+
return mlir::ChangeResult::NoChange;
1562+
1563+
// Only propagate to operands[0] (register), not operands[1] (memory)
1564+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> registerOperand =
1565+
operandElements.take_front(1); // Only operands[0]
1566+
1567+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1568+
wave::ElementsPerThreadLatticeValue(*elementsPerThread),
1569+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty immutable (no results)
1570+
registerOperand, "elements_per_thread attribute", "", "register operand", errs);
1571+
}
1572+
14591573
//-----------------------------------------------------------------------------
14601574
// YieldOp
14611575
//-----------------------------------------------------------------------------

water/test/Dialect/Wave/propagate-elements-per-thread.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func.func @missing_elements_per_thread(%mem: !wave.tensor<[@M] of f16, <global>>
100100
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
101101
func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
102102
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
103-
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
103+
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
104104
wave.write %reg, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
105105
return
106106
}
@@ -112,7 +112,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
112112
func.func @read_write_conflict_indirect(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
113113
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
114114
%val = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>
115-
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
115+
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
116116
wave.write %reg, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
117117
return
118118
}

0 commit comments

Comments
 (0)