From 561433334a2e079e7732a25d7d854b491cbd3659 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Sat, 20 Dec 2025 23:14:14 +0100 Subject: [PATCH 1/5] Add WaveElementsPerThreadOpInterface to MmaOp Implements elements per thread propagation for MMA operations. Fixes #608. Signed-off-by: tyb0807 --- .../include/water/Dialect/Wave/IR/WaveOps.td | 6 + water/lib/Dialect/Wave/IR/CMakeLists.txt | 1 + water/lib/Dialect/Wave/IR/WaveOps.cpp | 86 ++++++++++ water/test/Dialect/Wave/ops.mlir | 156 ++++++++++++++++++ .../Wave/propagate-elements-per-thread.mlir | 41 +++++ 5 files changed, 290 insertions(+) diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index bfa94db8c..734528e4b 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -110,6 +110,7 @@ def Exp2Op : UnaryWaveOp<"exp2"> { def MmaOp : WaveOp<"mma", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, WaveArithmeticOpDoc { @@ -130,6 +131,11 @@ def MmaOp : WaveOp<"mma", "$lhs `,` $rhs `,` $accumulator " # commonArgumentsSyntax # "attr-dict `:`" "functional-type(operands, results)"; let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Compute the expected elements per thread for this MMA operation. + unsigned computeElementsPerThread(); + }]; } //----------------------------------------------------------------------------- diff --git a/water/lib/Dialect/Wave/IR/CMakeLists.txt b/water/lib/Dialect/Wave/IR/CMakeLists.txt index 586578028..7cd7377be 100644 --- a/water/lib/Dialect/Wave/IR/CMakeLists.txt +++ b/water/lib/Dialect/Wave/IR/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRWaveDialect MLIRIR MLIRControlFlowInterfaces MLIRFunctionInterfaces + MLIRFuncDialect ) # Install the Wave dialect library so Python can find it at runtime. diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index e547bd9ca..22086a9b4 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -6,6 +6,7 @@ #include "water/Dialect/Wave/IR/WaveOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -1138,6 +1139,91 @@ LogicalResult MmaOp::verify() { accumulatorType.getElementType()); } +/// Compute the expected elements per thread for this MMA operation. +/// Extracts threadsPerWave from the parent function's hardware constraint. +unsigned wave::MmaOp::computeElementsPerThread() { + auto kind = getKind(); + if (!kind) { + return 0; + } + wave::WaveMmaSpec spec = + wave::WaveMmaKindAttr::getSpec(getContext(), *kind); + + // Extract threads per wave from hardware constraint. + // Default fallback. + unsigned threadsPerWave = 64; + if (auto parentFunc = getOperation()->getParentOfType()) { + if (auto constraints = + parentFunc->getAttrOfType("wave.constraints")) { + for (mlir::Attribute constraint : constraints) { + if (auto hardwareConstraint = + llvm::dyn_cast(constraint)) { + threadsPerWave = hardwareConstraint.getThreadsPerWave(); + break; + } + } + } + } + + unsigned totalElements = spec.m * spec.n; + return totalElements / threadsPerWave; +} + +llvm::FailureOr +wave::MmaOp::propagateElementsPerThreadForward( + llvm::ArrayRef operandElements, + llvm::MutableArrayRef resultElements, + llvm::raw_ostream &errs) { + unsigned expectedElementsPerThread = computeElementsPerThread(); + wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread); + return wave::detail::checkAndPropagateElementsPerThreadFromConstant( + expectedResult, llvm::ArrayRef(), + resultElements, "computed from MMA kind", "", "result", errs); +} + +llvm::FailureOr +wave::MmaOp::propagateElementsPerThreadBackward( + llvm::MutableArrayRef operandElements, + llvm::ArrayRef, + llvm::raw_ostream &errs) { + // For MMA, the accumulator should have the same elements per thread as the + // result. The LHS and RHS operands may have different constraints based on + // their dimensions. + unsigned expectedElementsPerThread = computeElementsPerThread(); + wave::ElementsPerThreadLatticeValue expectedAccumulator( + expectedElementsPerThread); + + unsigned accumulatorOperandNumber = + getAccumulatorMutable().getOperandNumber(); + + // First, validate that LHS and RHS operands have concrete elements_per_thread + // values. We don't propagate to them, but we ensure they've been properly + // initialized. + if (operandElements.size() >= 3) { + for (unsigned i = 0; i < 2; ++i) { // LHS (0) and RHS (1) operands + if (operandElements[i].isBottom()) { + errs << "MMA operand #" << i << " ("; + errs << (i == 0 ? "LHS" : "RHS"); + errs << ") has uninitialized elements_per_thread"; + return mlir::failure(); + } + } + } + + // Then propagate to the accumulator operand. + if (operandElements.size() > accumulatorOperandNumber) { + llvm::MutableArrayRef accumulatorOnly = + operandElements.slice(accumulatorOperandNumber, 1); + + return wave::detail::checkAndPropagateElementsPerThreadFromConstant( + expectedAccumulator, + llvm::ArrayRef(), accumulatorOnly, + "computed from MMA kind", "", "accumulator operand", errs); + } + + return mlir::ChangeResult::NoChange; +} + //----------------------------------------------------------------------------- // ReadOp //----------------------------------------------------------------------------- diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index 4022b6b82..fa5f8da3c 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -135,6 +135,162 @@ func.func @register_with_hyperparameter() attributes {hyperparameters = #wave.hy return } +#hw_constraint = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 8}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_interface +func.func @mma_elements_per_thread_interface() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, + wave.constraints = [#hw_constraint] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values - elements_per_thread determined by MMA backward propagation. + %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +#hw_constraint_32_threads = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 16}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_32_threads +func.func @mma_elements_per_thread_32_threads() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, + wave.constraints = [#hw_constraint_32_threads] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values - elements_per_thread determined by MMA backward propagation. + %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +#hw_constraint_128_threads = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 8}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_128_threads +func.func @mma_elements_per_thread_128_threads() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, + wave.constraints = [#hw_constraint_128_threads] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values - elements_per_thread determined by MMA backward propagation. + %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +#hw_constraint = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 8}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_interface +func.func @mma_elements_per_thread_interface() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, + wave.constraints = [#hw_constraint] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values with explicit elements_per_thread. + %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init {elements_per_thread = 16} : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +#hw_constraint_32_threads = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 16}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_32_threads +func.func @mma_elements_per_thread_32_threads() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, + wave.constraints = [#hw_constraint_32_threads] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values with explicit elements_per_thread. + %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +#hw_constraint_128_threads = #wave.hardware_constraint, + vector_shapes = {M = 1, N = 1, K = 8}, + max_bits_per_load = 128> + +// CHECK-LABEL: @mma_elements_per_thread_128_threads +func.func @mma_elements_per_thread_128_threads() attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, + wave.constraints = [#hw_constraint_128_threads] +} { + %lhs_init = arith.constant 1.0 : f16 + %rhs_init = arith.constant 2.0 : f16 + %acc_init = arith.constant 0.0 : f32 + + // Create register values with explicit elements_per_thread. + %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, > + %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, > + %acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind} + // f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread. + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + // CHECK-LABEL: @allocate func.func @allocate() -> !wave.tensor<[@M, @N] of bf16, > { // CHECK: wave.allocate diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index d64db241f..392fbb45d 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -147,6 +147,7 @@ func.func @unsupported_op() attributes {wave.hyperparameters = #wave.hyperparame } } + // ----- // CHECK: #wave.normal_form @@ -158,3 +159,43 @@ module { return } } + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { +func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>} { + // LHS without elements_per_thread - this will remain uninitialized. + %lhs_init = arith.constant 0.0 : f16 + %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > + + // RHS properly initialized through read operation. + %rhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@N, @K] of f16, >) -> !wave.tensor<[@N, @K] of f16, > + + // ACC properly initialized through read operation. + %acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // expected-error @below {{failed to propagate elements per thread backward: MMA operand #0 (LHS) has uninitialized elements_per_thread}} + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { +func.func @mma_uninitialized_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>} { + // LHS properly initialized through read operation. + %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > + + // RHS without elements_per_thread - this will remain uninitialized. + %rhs_init = arith.constant 0.0 : f16 + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + + // ACC properly initialized through read operation. + %acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // expected-error @below {{failed to propagate elements per thread backward: MMA operand #1 (RHS) has uninitialized elements_per_thread}} + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} +} From 18f684921ca0aad63708f1d951e15ea926b4c52b Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Tue, 23 Dec 2025 01:04:58 +0100 Subject: [PATCH 2/5] Address reviews Signed-off-by: tyb0807 --- water/lib/Dialect/Wave/IR/WaveOps.cpp | 47 ++++++++++++--------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 22086a9b4..fba236600 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1140,7 +1140,8 @@ LogicalResult MmaOp::verify() { } /// Compute the expected elements per thread for this MMA operation. -/// Extracts threadsPerWave from the parent function's hardware constraint. +/// Extracts threadsPerWave from ancestor operations with hardware constraints. +/// Returns 0 if no constraints are found. unsigned wave::MmaOp::computeElementsPerThread() { auto kind = getKind(); if (!kind) { @@ -1149,24 +1150,24 @@ unsigned wave::MmaOp::computeElementsPerThread() { wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(getContext(), *kind); - // Extract threads per wave from hardware constraint. - // Default fallback. - unsigned threadsPerWave = 64; - if (auto parentFunc = getOperation()->getParentOfType()) { - if (auto constraints = - parentFunc->getAttrOfType("wave.constraints")) { + // Extract threads per wave from hardware constraint by walking up the ancestry. + mlir::Operation *op = getOperation(); + while (op) { + if (auto constraints = op->getAttrOfType( + wave::WaveDialect::kWaveConstraintsAttrName)) { for (mlir::Attribute constraint : constraints) { if (auto hardwareConstraint = llvm::dyn_cast(constraint)) { - threadsPerWave = hardwareConstraint.getThreadsPerWave(); - break; + unsigned totalElements = spec.m * spec.n; + return totalElements / hardwareConstraint.getThreadsPerWave(); } } } + op = op->getParentOp(); } - unsigned totalElements = spec.m * spec.n; - return totalElements / threadsPerWave; + // Return 0 to indicate failure if no constraints found. + return 0; } llvm::FailureOr @@ -1175,6 +1176,10 @@ wave::MmaOp::propagateElementsPerThreadForward( llvm::MutableArrayRef resultElements, llvm::raw_ostream &errs) { unsigned expectedElementsPerThread = computeElementsPerThread(); + if (expectedElementsPerThread == 0) { + errs << "MMA operation has no hardware constraints available"; + return mlir::failure(); + } wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread); return wave::detail::checkAndPropagateElementsPerThreadFromConstant( expectedResult, llvm::ArrayRef(), @@ -1190,27 +1195,17 @@ wave::MmaOp::propagateElementsPerThreadBackward( // result. The LHS and RHS operands may have different constraints based on // their dimensions. unsigned expectedElementsPerThread = computeElementsPerThread(); + if (expectedElementsPerThread == 0) { + errs << "MMA operation has no hardware constraints available"; + return mlir::failure(); + } wave::ElementsPerThreadLatticeValue expectedAccumulator( expectedElementsPerThread); unsigned accumulatorOperandNumber = getAccumulatorMutable().getOperandNumber(); - // First, validate that LHS and RHS operands have concrete elements_per_thread - // values. We don't propagate to them, but we ensure they've been properly - // initialized. - if (operandElements.size() >= 3) { - for (unsigned i = 0; i < 2; ++i) { // LHS (0) and RHS (1) operands - if (operandElements[i].isBottom()) { - errs << "MMA operand #" << i << " ("; - errs << (i == 0 ? "LHS" : "RHS"); - errs << ") has uninitialized elements_per_thread"; - return mlir::failure(); - } - } - } - - // Then propagate to the accumulator operand. + // Propagate to the accumulator operand. if (operandElements.size() > accumulatorOperandNumber) { llvm::MutableArrayRef accumulatorOnly = operandElements.slice(accumulatorOperandNumber, 1); From d0b0c60622cdd84336b481d46125bce8dc216d68 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Tue, 23 Dec 2025 01:08:22 +0100 Subject: [PATCH 3/5] Fix tests Signed-off-by: tyb0807 --- water/lib/Dialect/Wave/IR/WaveOps.cpp | 19 ++++++++++++--- water/test/Dialect/Wave/ops.mlir | 24 +++++++++---------- .../Wave/propagate-elements-per-thread.mlir | 4 ++-- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index fba236600..fc4d5ad2b 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1147,10 +1147,10 @@ unsigned wave::MmaOp::computeElementsPerThread() { if (!kind) { return 0; } - wave::WaveMmaSpec spec = - wave::WaveMmaKindAttr::getSpec(getContext(), *kind); + wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(getContext(), *kind); - // Extract threads per wave from hardware constraint by walking up the ancestry. + // Extract threads per wave from hardware constraint by walking up the + // ancestry. mlir::Operation *op = getOperation(); while (op) { if (auto constraints = op->getAttrOfType( @@ -1205,6 +1205,19 @@ wave::MmaOp::propagateElementsPerThreadBackward( unsigned accumulatorOperandNumber = getAccumulatorMutable().getOperandNumber(); + // Validate that LHS and RHS operands have concrete elements_per_thread + // values. We don't propagate to them, but we check they've been properly + // initialized. + // LHS (0) and RHS (1) operands. + for (unsigned i = 0; i < 2 && i < operandElements.size(); ++i) { + if (operandElements[i].isBottom()) { + errs << "MMA operand #" << i << " ("; + errs << (i == 0 ? "LHS" : "RHS"); + errs << ") has uninitialized elements_per_thread"; + return mlir::failure(); + } + } + // Propagate to the accumulator operand. if (operandElements.size() > accumulatorOperandNumber) { llvm::MutableArrayRef accumulatorOnly = diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index fa5f8da3c..f13900608 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -213,16 +213,16 @@ func.func @mma_elements_per_thread_128_threads() attributes { return } -#hw_constraint = #wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 8}, max_bits_per_load = 128> -// CHECK-LABEL: @mma_elements_per_thread_interface -func.func @mma_elements_per_thread_interface() attributes { +// CHECK-LABEL: @mma_elements_per_thread_interface_explicit +func.func @mma_elements_per_thread_interface_explicit() attributes { wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, - wave.constraints = [#hw_constraint] + wave.constraints = [#hw_constraint_interface_alt] } { %lhs_init = arith.constant 1.0 : f16 %rhs_init = arith.constant 2.0 : f16 @@ -239,16 +239,16 @@ func.func @mma_elements_per_thread_interface() attributes { return } -#hw_constraint_32_threads = #wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128> -// CHECK-LABEL: @mma_elements_per_thread_32_threads -func.func @mma_elements_per_thread_32_threads() attributes { +// CHECK-LABEL: @mma_elements_per_thread_32_threads_explicit +func.func @mma_elements_per_thread_32_threads_explicit() attributes { wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, - wave.constraints = [#hw_constraint_32_threads] + wave.constraints = [#hw_constraint_32_threads_explicit] } { %lhs_init = arith.constant 1.0 : f16 %rhs_init = arith.constant 2.0 : f16 @@ -265,16 +265,16 @@ func.func @mma_elements_per_thread_32_threads() attributes { return } -#hw_constraint_128_threads = #wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 8}, max_bits_per_load = 128> -// CHECK-LABEL: @mma_elements_per_thread_128_threads -func.func @mma_elements_per_thread_128_threads() attributes { +// CHECK-LABEL: @mma_elements_per_thread_128_threads_explicit +func.func @mma_elements_per_thread_128_threads_explicit() attributes { wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>, - wave.constraints = [#hw_constraint_128_threads] + wave.constraints = [#hw_constraint_128_threads_explicit] } { %lhs_init = arith.constant 1.0 : f16 %rhs_init = arith.constant 2.0 : f16 diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index 392fbb45d..cecda85a2 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -163,7 +163,7 @@ module { // ----- module attributes {wave.normal_form = #wave.normal_form} { -func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>} { +func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS without elements_per_thread - this will remain uninitialized. %lhs_init = arith.constant 0.0 : f16 %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > @@ -183,7 +183,7 @@ func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, // ----- module attributes {wave.normal_form = #wave.normal_form} { -func.func @mma_uninitialized_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>} { +func.func @mma_uninitialized_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS properly initialized through read operation. %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > From a87e7325da0833ae0f55a80fa626d15249c40251 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Tue, 23 Dec 2025 11:11:39 +0100 Subject: [PATCH 4/5] More comments Signed-off-by: tyb0807 --- water/lib/Dialect/Wave/IR/CMakeLists.txt | 1 - water/lib/Dialect/Wave/IR/WaveOps.cpp | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/water/lib/Dialect/Wave/IR/CMakeLists.txt b/water/lib/Dialect/Wave/IR/CMakeLists.txt index 7cd7377be..586578028 100644 --- a/water/lib/Dialect/Wave/IR/CMakeLists.txt +++ b/water/lib/Dialect/Wave/IR/CMakeLists.txt @@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRWaveDialect MLIRIR MLIRControlFlowInterfaces MLIRFunctionInterfaces - MLIRFuncDialect ) # Install the Wave dialect library so Python can find it at runtime. diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index fc4d5ad2b..d83180436 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -6,7 +6,6 @@ #include "water/Dialect/Wave/IR/WaveOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -1143,7 +1142,7 @@ LogicalResult MmaOp::verify() { /// Extracts threadsPerWave from ancestor operations with hardware constraints. /// Returns 0 if no constraints are found. unsigned wave::MmaOp::computeElementsPerThread() { - auto kind = getKind(); + wave::WaveMmaKind kind = getKind(); if (!kind) { return 0; } @@ -1207,17 +1206,23 @@ wave::MmaOp::propagateElementsPerThreadBackward( // Validate that LHS and RHS operands have concrete elements_per_thread // values. We don't propagate to them, but we check they've been properly - // initialized. + // initialized. During analysis initialization, bottom values are acceptable - + // we return NoChange to let the analysis continue rather than failing. // LHS (0) and RHS (1) operands. + bool allLhsRhsInitialized = true; for (unsigned i = 0; i < 2 && i < operandElements.size(); ++i) { if (operandElements[i].isBottom()) { - errs << "MMA operand #" << i << " ("; - errs << (i == 0 ? "LHS" : "RHS"); - errs << ") has uninitialized elements_per_thread"; - return mlir::failure(); + allLhsRhsInitialized = false; + break; } } + // If LHS/RHS operands are still at bottom, return NoChange to allow + // the analysis to continue. Forward propagation will initialize them. + if (!allLhsRhsInitialized) { + return mlir::ChangeResult::NoChange; + } + // Propagate to the accumulator operand. if (operandElements.size() > accumulatorOperandNumber) { llvm::MutableArrayRef accumulatorOnly = From 2034cc617050b1684f844e7084cbded23041070a Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Tue, 23 Dec 2025 11:19:32 +0100 Subject: [PATCH 5/5] Fix Signed-off-by: tyb0807 --- .../include/water/Dialect/Wave/IR/WaveOps.td | 5 +- water/lib/Dialect/Wave/IR/WaveOps.cpp | 142 ++++++++++++------ .../Wave/propagate-elements-per-thread.mlir | 62 ++++++-- 3 files changed, 154 insertions(+), 55 deletions(-) diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 734528e4b..145dcc861 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -133,8 +133,9 @@ def MmaOp : WaveOp<"mma", let hasVerifier = 1; let extraClassDeclaration = [{ - /// Compute the expected elements per thread for this MMA operation. - unsigned computeElementsPerThread(); + /// Compute the expected elements per thread for a specific operand of this MMA operation. + /// Returns failure if no hardware constraints are available. + llvm::FailureOr computeElementsPerThreadForOperand(unsigned operandIndex); }]; } diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index d83180436..22f743b99 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1138,15 +1138,16 @@ LogicalResult MmaOp::verify() { accumulatorType.getElementType()); } -/// Compute the expected elements per thread for this MMA operation. -/// Extracts threadsPerWave from ancestor operations with hardware constraints. -/// Returns 0 if no constraints are found. -unsigned wave::MmaOp::computeElementsPerThread() { - wave::WaveMmaKind kind = getKind(); - if (!kind) { - return 0; - } - wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(getContext(), *kind); +/// Compute the expected elements per thread for a specific MMA operand. +/// operandIndex: 0=LHS, 1=RHS, 2=Accumulator/Result +/// Returns failure if no constraints are found. +llvm::FailureOr +wave::MmaOp::computeElementsPerThreadForOperand(unsigned operandIndex) { + std::optional mmaKind = getKind(); + if (!mmaKind) + return mlir::failure(); + wave::WaveMmaSpec spec = + wave::WaveMmaKindAttr::getSpec(getContext(), *mmaKind); // Extract threads per wave from hardware constraint by walking up the // ancestry. @@ -1157,7 +1158,20 @@ unsigned wave::MmaOp::computeElementsPerThread() { for (mlir::Attribute constraint : constraints) { if (auto hardwareConstraint = llvm::dyn_cast(constraint)) { - unsigned totalElements = spec.m * spec.n; + unsigned totalElements; + switch (operandIndex) { + case 0: // LHS: M x K + totalElements = spec.m * spec.k; + break; + case 1: // RHS: N x K + totalElements = spec.n * spec.k; + break; + case 2: // Accumulator/Result: M x N + totalElements = spec.m * spec.n; + break; + default: + return mlir::failure(); + } return totalElements / hardwareConstraint.getThreadsPerWave(); } } @@ -1165,8 +1179,8 @@ unsigned wave::MmaOp::computeElementsPerThread() { op = op->getParentOp(); } - // Return 0 to indicate failure if no constraints found. - return 0; + // Return failure if no constraints found. + return mlir::failure(); } llvm::FailureOr @@ -1174,11 +1188,14 @@ wave::MmaOp::propagateElementsPerThreadForward( llvm::ArrayRef operandElements, llvm::MutableArrayRef resultElements, llvm::raw_ostream &errs) { - unsigned expectedElementsPerThread = computeElementsPerThread(); - if (expectedElementsPerThread == 0) { + llvm::FailureOr expectedElementsPerThreadResult = + computeElementsPerThreadForOperand( + getAccumulatorMutable().getOperandNumber()); + if (llvm::failed(expectedElementsPerThreadResult)) { errs << "MMA operation has no hardware constraints available"; return mlir::failure(); } + unsigned expectedElementsPerThread = *expectedElementsPerThreadResult; wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread); return wave::detail::checkAndPropagateElementsPerThreadFromConstant( expectedResult, llvm::ArrayRef(), @@ -1193,48 +1210,87 @@ wave::MmaOp::propagateElementsPerThreadBackward( // For MMA, the accumulator should have the same elements per thread as the // result. The LHS and RHS operands may have different constraints based on // their dimensions. - unsigned expectedElementsPerThread = computeElementsPerThread(); - if (expectedElementsPerThread == 0) { + // MMA operation always has exactly 3 operands: LHS, RHS, Accumulator + assert(operandElements.size() == 3 && + "MMA operation must have exactly 3 operands"); + + unsigned lhsOperandNumber = getLhsMutable().getOperandNumber(); + unsigned rhsOperandNumber = getRhsMutable().getOperandNumber(); + unsigned accumulatorOperandNumber = + getAccumulatorMutable().getOperandNumber(); + + // Compute expected elements per thread for each operand + llvm::FailureOr expectedLhsElementsPerThreadResult = + computeElementsPerThreadForOperand(lhsOperandNumber); + llvm::FailureOr expectedRhsElementsPerThreadResult = + computeElementsPerThreadForOperand(rhsOperandNumber); + llvm::FailureOr expectedAccumulatorElementsPerThreadResult = + computeElementsPerThreadForOperand(accumulatorOperandNumber); + + if (llvm::failed(expectedLhsElementsPerThreadResult) || + llvm::failed(expectedRhsElementsPerThreadResult) || + llvm::failed(expectedAccumulatorElementsPerThreadResult)) { errs << "MMA operation has no hardware constraints available"; return mlir::failure(); } + + unsigned expectedLhsElementsPerThread = *expectedLhsElementsPerThreadResult; + unsigned expectedRhsElementsPerThread = *expectedRhsElementsPerThreadResult; + unsigned expectedAccumulatorElementsPerThread = + *expectedAccumulatorElementsPerThreadResult; + + wave::ElementsPerThreadLatticeValue expectedLhs(expectedLhsElementsPerThread); + wave::ElementsPerThreadLatticeValue expectedRhs(expectedRhsElementsPerThread); wave::ElementsPerThreadLatticeValue expectedAccumulator( - expectedElementsPerThread); + expectedAccumulatorElementsPerThread); - unsigned accumulatorOperandNumber = - getAccumulatorMutable().getOperandNumber(); + // Propagate elements_per_thread to LHS operand using the helper function + llvm::MutableArrayRef lhsOnly = + operandElements.slice(lhsOperandNumber, 1); - // Validate that LHS and RHS operands have concrete elements_per_thread - // values. We don't propagate to them, but we check they've been properly - // initialized. During analysis initialization, bottom values are acceptable - - // we return NoChange to let the analysis continue rather than failing. - // LHS (0) and RHS (1) operands. - bool allLhsRhsInitialized = true; - for (unsigned i = 0; i < 2 && i < operandElements.size(); ++i) { - if (operandElements[i].isBottom()) { - allLhsRhsInitialized = false; - break; - } + llvm::FailureOr lhsResult = + wave::detail::checkAndPropagateElementsPerThreadFromConstant( + expectedLhs, llvm::ArrayRef(), + lhsOnly, "computed from MMA kind", "", "LHS operand", errs); + + if (llvm::failed(lhsResult)) { + return llvm::failure(); } - // If LHS/RHS operands are still at bottom, return NoChange to allow - // the analysis to continue. Forward propagation will initialize them. - if (!allLhsRhsInitialized) { - return mlir::ChangeResult::NoChange; + // Propagate elements_per_thread to RHS operand using the helper function + llvm::MutableArrayRef rhsOnly = + operandElements.slice(rhsOperandNumber, 1); + + llvm::FailureOr rhsResult = + wave::detail::checkAndPropagateElementsPerThreadFromConstant( + expectedRhs, llvm::ArrayRef(), + rhsOnly, "computed from MMA kind", "", "RHS operand", errs); + + if (llvm::failed(rhsResult)) { + return mlir::failure(); } // Propagate to the accumulator operand. - if (operandElements.size() > accumulatorOperandNumber) { - llvm::MutableArrayRef accumulatorOnly = - operandElements.slice(accumulatorOperandNumber, 1); - - return wave::detail::checkAndPropagateElementsPerThreadFromConstant( - expectedAccumulator, - llvm::ArrayRef(), accumulatorOnly, - "computed from MMA kind", "", "accumulator operand", errs); + llvm::MutableArrayRef accumulatorOnly = + operandElements.slice(accumulatorOperandNumber, 1); + + llvm::FailureOr accumulatorResult = + wave::detail::checkAndPropagateElementsPerThreadFromConstant( + expectedAccumulator, + llvm::ArrayRef(), + accumulatorOnly, "computed from MMA kind", "", "accumulator operand", + errs); + + if (llvm::failed(accumulatorResult)) { + return mlir::failure(); } - return mlir::ChangeResult::NoChange; + // Return Change if any operand changed + return (*lhsResult == mlir::ChangeResult::Change || + *rhsResult == mlir::ChangeResult::Change || + *accumulatorResult == mlir::ChangeResult::Change) + ? mlir::ChangeResult::Change + : mlir::ChangeResult::NoChange; } //----------------------------------------------------------------------------- diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index cecda85a2..d9ef3382b 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -163,18 +163,18 @@ module { // ----- module attributes {wave.normal_form = #wave.normal_form} { -func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { - // LHS without elements_per_thread - this will remain uninitialized. +func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { + // LHS without elements_per_thread - will be computed from RHS + MMA constraints. %lhs_init = arith.constant 0.0 : f16 %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > // RHS properly initialized through read operation. - %rhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@N, @K] of f16, >) -> !wave.tensor<[@N, @K] of f16, > + %rhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@N, @K] of f16, >) -> !wave.tensor<[@N, @K] of f16, > // ACC properly initialized through read operation. - %acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > - // expected-error @below {{failed to propagate elements per thread backward: MMA operand #0 (LHS) has uninitialized elements_per_thread}} + // LHS elements_per_thread computed via MMA backward propagation %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > return } @@ -183,19 +183,61 @@ func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, >, // ----- module attributes {wave.normal_form = #wave.normal_form} { -func.func @mma_uninitialized_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { +func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { // LHS properly initialized through read operation. - %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > + %lhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > - // RHS without elements_per_thread - this will remain uninitialized. + // RHS without elements_per_thread - will be computed from LHS + MMA constraints. %rhs_init = arith.constant 0.0 : f16 %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > // ACC properly initialized through read operation. - %acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > - // expected-error @below {{failed to propagate elements per thread backward: MMA operand #1 (RHS) has uninitialized elements_per_thread}} + // RHS elements_per_thread computed via MMA backward propagation %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > return } } + +// ----- + +// Test MMA can compute both LHS and RHS when both are uninitialized +module attributes {wave.normal_form = #wave.normal_form} { + func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@N, @K] of f16, >, %mem3: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { + // Both LHS and RHS without elements_per_thread - can compute from MMA formulas + %lhs_init = arith.constant 0.0 : f16 + %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, > + %rhs_init = arith.constant 0.0 : f16 + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + + // ACC properly initialized through read operation. + %acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // With proper MMA formulas, we can now compute both LHS and RHS from constraints, + // so this should succeed instead of failing + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return + } +} + +// ----- + +// Test MMA error when operand has wrong elements_per_thread +module attributes {wave.normal_form = #wave.normal_form} { + func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} { + // LHS with wrong elements_per_thread (should be 8, not 4) + %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, > + + // RHS without elements_per_thread - will be computed from MMA constraints. + %rhs_init = arith.constant 0.0 : f16 + %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, > + + // ACC properly initialized + %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}} + %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return + } +}