@@ -292,6 +292,57 @@ llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateBackward(
292292 " result" , " accumulator" , errs);
293293}
294294
295+ // Get the MMA result elements per thread for a given MMA kind
296+ static unsigned getMmaResultElementsPerThreadForMmaKind (mlir::MLIRContext *context, wave::WaveMmaKind kind) {
297+ // Get the MMA specification (M, N, K dimensions and element types)
298+ wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec (context, kind);
299+
300+ // Elements per thread = (M × N) / threads_per_wave
301+ // AMD GPU waves have 64 threads
302+ constexpr unsigned threadsPerWave = 64 ;
303+ unsigned totalElements = spec.m * spec.n ;
304+ return totalElements / threadsPerWave;
305+ }
306+
307+ llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateElementsPerThreadForward (
308+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
309+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
310+ llvm::raw_ostream &errs) {
311+ // For MMA, the result elements per thread is determined by the MMA kind, not the operands
312+ unsigned expectedElementsPerThread = getMmaResultElementsPerThreadForMmaKind (getContext (), getKind ());
313+ wave::ElementsPerThreadLatticeValue expectedResult (expectedElementsPerThread);
314+
315+ // Propagate to result
316+ auto joined = wave::ElementsPerThreadLatticeValue::join (expectedResult, resultElements[0 ]);
317+ if (joined.isTop () && !expectedResult.isTop () && !resultElements[0 ].isTop ()) {
318+ errs << " mismatched elements per thread for MMA result: expected " << expectedElementsPerThread
319+ << " elements per thread for MMA kind " << getKind () << " but got (" ;
320+ resultElements[0 ].print (errs);
321+ errs << " )" ;
322+ return mlir::failure ();
323+ }
324+
325+ if (joined != resultElements[0 ]) {
326+ resultElements[0 ] = joined;
327+ return mlir::ChangeResult::Change;
328+ }
329+
330+ return mlir::ChangeResult::NoChange;
331+ }
332+
333+ llvm::FailureOr<mlir::ChangeResult> wave::MmaOp::propagateElementsPerThreadBackward (
334+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
335+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
336+ llvm::raw_ostream &errs) {
337+ // For MMA, operands and result elements per thread may be different
338+ // The result is determined by the MMA kind, operands can have their own values
339+ // We don't propagate backwards for MMA since operands and result have independent constraints
340+ (void )operandElements; // Avoid unused parameter warning
341+ (void )resultElements;
342+ (void )errs;
343+ return mlir::ChangeResult::NoChange;
344+ }
345+
295346// Set the value of `lattice` to `newLattice` and return whether a change
296347// happened. Note that this does NOT verify whether the lattice change goes into
297348// the direction of top or bottom.
@@ -1331,6 +1382,41 @@ LogicalResult ReadOp::verify() {
13311382 bounds.getMapping ());
13321383}
13331384
1385+ llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadForward (
1386+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1387+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1388+ llvm::raw_ostream &errs) {
1389+ // ReadOp: use AttrBasedElementsPerThreadOpTrait logic for register result
1390+ // but ignore memory operand (operands[0])
1391+ std::optional<int64_t > elementsPerThread = getElementsPerThread ();
1392+ if (!elementsPerThread)
1393+ return mlir::ChangeResult::NoChange;
1394+
1395+ // Only propagate to results (register), not from memory operand
1396+ return wave::detail::checkAndPropagateElementsPerThreadFromConstant (
1397+ wave::ElementsPerThreadLatticeValue (*elementsPerThread),
1398+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty immutable (ignore memory operand)
1399+ resultElements, " elements_per_thread attribute" , " " , " result" , errs);
1400+ }
1401+
1402+ llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadBackward (
1403+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1404+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1405+ llvm::raw_ostream &errs) {
1406+ // ReadOp: use AttrBasedElementsPerThreadOpTrait logic for register result
1407+ // but ignore memory operand (operandElements[0])
1408+ std::optional<int64_t > elementsPerThread = getElementsPerThread ();
1409+ if (!elementsPerThread)
1410+ return mlir::ChangeResult::NoChange;
1411+
1412+ // Only check consistency with results (register), not memory operand
1413+ return wave::detail::checkAndPropagateElementsPerThreadFromConstant (
1414+ wave::ElementsPerThreadLatticeValue (*elementsPerThread),
1415+ resultElements,
1416+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty mutable (ignore memory operand)
1417+ " elements_per_thread attribute" , " result" , " " , errs);
1418+ }
1419+
13341420// -----------------------------------------------------------------------------
13351421// RegisterOp
13361422// -----------------------------------------------------------------------------
@@ -1456,6 +1542,34 @@ llvm::LogicalResult wave::WriteOp::setIndexFromLattices(
14561542 return llvm::success ();
14571543}
14581544
1545+ llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadForward (
1546+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1547+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1548+ llvm::raw_ostream &errs) {
1549+ // WriteOp has no results, so forward propagation is NoChange
1550+ return mlir::ChangeResult::NoChange;
1551+ }
1552+
1553+ llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadBackward (
1554+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1555+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1556+ llvm::raw_ostream &errs) {
1557+ // WriteOp: use AttrBasedElementsPerThreadOpTrait logic for register operand
1558+ // but ignore memory operand (operandElements[1])
1559+ std::optional<int64_t > elementsPerThread = getElementsPerThread ();
1560+ if (!elementsPerThread)
1561+ return mlir::ChangeResult::NoChange;
1562+
1563+ // Only propagate to operands[0] (register), not operands[1] (memory)
1564+ llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> registerOperand =
1565+ operandElements.take_front (1 ); // Only operands[0]
1566+
1567+ return wave::detail::checkAndPropagateElementsPerThreadFromConstant (
1568+ wave::ElementsPerThreadLatticeValue (*elementsPerThread),
1569+ llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), // empty immutable (no results)
1570+ registerOperand, " elements_per_thread attribute" , " " , " register operand" , errs);
1571+ }
1572+
14591573// -----------------------------------------------------------------------------
14601574// YieldOp
14611575// -----------------------------------------------------------------------------
0 commit comments