diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td
index bfa94db8c..145dcc861 100644
--- a/water/include/water/Dialect/Wave/IR/WaveOps.td
+++ b/water/include/water/Dialect/Wave/IR/WaveOps.td
@@ -110,6 +110,7 @@ def Exp2Op : UnaryWaveOp<"exp2"> {
 def MmaOp : WaveOp<"mma",
     [DeclareOpInterfaceMethods,
+     DeclareOpInterfaceMethods,
      DeclareOpInterfaceMethods]>,
     WaveArithmeticOpDoc {
@@ -130,6 +131,12 @@ def MmaOp : WaveOp<"mma",
     "$lhs `,` $rhs `,` $accumulator " # commonArgumentsSyntax # "attr-dict `:`"
     "functional-type(operands, results)";
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Compute the expected elements per thread for a specific operand of this MMA operation.
+    /// Returns failure if no hardware constraints are available.
+    llvm::FailureOr<unsigned> computeElementsPerThreadForOperand(unsigned operandIndex);
+  }];
 }
 
 //-----------------------------------------------------------------------------
diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp
index e547bd9ca..22f743b99 100644
--- a/water/lib/Dialect/Wave/IR/WaveOps.cpp
+++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp
@@ -1138,6 +1138,161 @@ LogicalResult MmaOp::verify() {
                                accumulatorType.getElementType());
 }
 
+/// Compute the expected elements per thread for a specific MMA operand.
+/// operandIndex: 0 = LHS, 1 = RHS, 2 = Accumulator/Result.
+/// Returns failure if no constraints are found.
+llvm::FailureOr<unsigned>
+wave::MmaOp::computeElementsPerThreadForOperand(unsigned operandIndex) {
+  std::optional<wave::WaveMmaKind> mmaKind = getKind();
+  if (!mmaKind)
+    return mlir::failure();
+  wave::WaveMmaSpec spec =
+      wave::WaveMmaKindAttr::getSpec(getContext(), *mmaKind);
+
+  // Extract threads per wave from the hardware constraint by walking up the
+  // ancestry.
+  mlir::Operation *op = getOperation();
+  while (op) {
+    if (auto constraints = op->getAttrOfType<mlir::ArrayAttr>(
+            wave::WaveDialect::kWaveConstraintsAttrName)) {
+      for (mlir::Attribute constraint : constraints) {
+        if (auto hardwareConstraint =
+                llvm::dyn_cast<wave::HardwareConstraintAttr>(constraint)) {
+          unsigned totalElements;
+          switch (operandIndex) {
+          case 0: // LHS: M x K
+            totalElements = spec.m * spec.k;
+            break;
+          case 1: // RHS: N x K
+            totalElements = spec.n * spec.k;
+            break;
+          case 2: // Accumulator/Result: M x N
+            totalElements = spec.m * spec.n;
+            break;
+          default:
+            return mlir::failure();
+          }
+          return totalElements / hardwareConstraint.getThreadsPerWave();
+        }
+      }
+    }
+    op = op->getParentOp();
+  }
+
+  // Return failure if no constraints were found.
+  return mlir::failure();
+}
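+
+// For example, with kind f32_16x16x16_f16 and threads_per_wave = 32, each of
+// the LHS (16x16), RHS (16x16), and accumulator (16x16) tiles covers 256
+// elements, so every operand is expected to hold 256 / 32 = 8 elements per
+// thread.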
+
+llvm::FailureOr<mlir::ChangeResult>
+wave::MmaOp::propagateElementsPerThreadForward(
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
+    llvm::raw_ostream &errs) {
+  llvm::FailureOr<unsigned> expectedElementsPerThreadResult =
+      computeElementsPerThreadForOperand(
+          getAccumulatorMutable().getOperandNumber());
+  if (llvm::failed(expectedElementsPerThreadResult)) {
+    errs << "MMA operation has no hardware constraints available";
+    return mlir::failure();
+  }
+  unsigned expectedElementsPerThread = *expectedElementsPerThreadResult;
+  wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread);
+  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+      expectedResult, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+      resultElements, "computed from MMA kind", "", "result", errs);
+}
+
+llvm::FailureOr<mlir::ChangeResult>
+wave::MmaOp::propagateElementsPerThreadBackward(
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
+    llvm::raw_ostream &errs) {
+  // For MMA, the accumulator must have the same elements per thread as the
+  // result, while the LHS and RHS operands may have different values derived
+  // from their own dimensions.
+  // An MMA operation always has exactly 3 operands: LHS, RHS, accumulator.
+  assert(operandElements.size() == 3 &&
+         "MMA operation must have exactly 3 operands");
+
+  unsigned lhsOperandNumber = getLhsMutable().getOperandNumber();
+  unsigned rhsOperandNumber = getRhsMutable().getOperandNumber();
+  unsigned accumulatorOperandNumber =
+      getAccumulatorMutable().getOperandNumber();
+
+  // Compute the expected elements per thread for each operand.
+  llvm::FailureOr<unsigned> expectedLhsElementsPerThreadResult =
+      computeElementsPerThreadForOperand(lhsOperandNumber);
+  llvm::FailureOr<unsigned> expectedRhsElementsPerThreadResult =
+      computeElementsPerThreadForOperand(rhsOperandNumber);
+  llvm::FailureOr<unsigned> expectedAccumulatorElementsPerThreadResult =
+      computeElementsPerThreadForOperand(accumulatorOperandNumber);
+
+  if (llvm::failed(expectedLhsElementsPerThreadResult) ||
+      llvm::failed(expectedRhsElementsPerThreadResult) ||
+      llvm::failed(expectedAccumulatorElementsPerThreadResult)) {
+    errs << "MMA operation has no hardware constraints available";
+    return mlir::failure();
+  }
+
+  unsigned expectedLhsElementsPerThread = *expectedLhsElementsPerThreadResult;
+  unsigned expectedRhsElementsPerThread = *expectedRhsElementsPerThreadResult;
+  unsigned expectedAccumulatorElementsPerThread =
+      *expectedAccumulatorElementsPerThreadResult;
+
+  wave::ElementsPerThreadLatticeValue expectedLhs(expectedLhsElementsPerThread);
+  wave::ElementsPerThreadLatticeValue expectedRhs(expectedRhsElementsPerThread);
+  wave::ElementsPerThreadLatticeValue expectedAccumulator(
+      expectedAccumulatorElementsPerThread);
+
+  // Propagate elements_per_thread to the LHS operand using the helper.
+  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> lhsOnly =
+      operandElements.slice(lhsOperandNumber, 1);
+
+  llvm::FailureOr<mlir::ChangeResult> lhsResult =
+      wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+          expectedLhs, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+          lhsOnly, "computed from MMA kind", "", "LHS operand", errs);
+
+  if (llvm::failed(lhsResult))
+    return mlir::failure();
+
+  // Propagate elements_per_thread to the RHS operand using the helper.
+  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> rhsOnly =
+      operandElements.slice(rhsOperandNumber, 1);
+
+  llvm::FailureOr<mlir::ChangeResult> rhsResult =
+      wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+          expectedRhs, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+          rhsOnly, "computed from MMA kind", "", "RHS operand", errs);
+
+  if (llvm::failed(rhsResult))
+    return mlir::failure();
+
+  // Propagate to the accumulator operand.
+  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> accumulatorOnly =
+      operandElements.slice(accumulatorOperandNumber, 1);
+
+  llvm::FailureOr<mlir::ChangeResult> accumulatorResult =
+      wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+          expectedAccumulator,
+          llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+          accumulatorOnly, "computed from MMA kind", "", "accumulator operand",
+          errs);
+
+  if (llvm::failed(accumulatorResult))
+    return mlir::failure();
+
+  // Report Change if any operand changed.
+  return (*lhsResult == mlir::ChangeResult::Change ||
+          *rhsResult == mlir::ChangeResult::Change ||
+          *accumulatorResult == mlir::ChangeResult::Change)
+             ? mlir::ChangeResult::Change
+             : mlir::ChangeResult::NoChange;
+}
+
 //-----------------------------------------------------------------------------
 // ReadOp
 //-----------------------------------------------------------------------------
diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir
index 4022b6b82..f13900608 100644
--- a/water/test/Dialect/Wave/ops.mlir
+++ b/water/test/Dialect/Wave/ops.mlir
@@ -135,6 +135,162 @@ func.func @register_with_hyperparameter() attributes {hyperparameters = #wave.hy
   return
 }
 
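+// The tests below exercise wave.mma under hardware constraints with 64, 32,
+// and 128 threads per wave; the comments spell out the expected
+// elements-per-thread arithmetic.
+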
+#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64,
+  vector_shapes = {M = 1, N = 1, K = 8},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_interface
+func.func @mma_elements_per_thread_interface() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
+  wave.constraints = [#hw_constraint]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values; elements_per_thread is determined by MMA backward propagation.
+  %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
+  // f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+#hw_constraint_32_threads = #wave.hardware_constraint<threads_per_wave = 32,
+  vector_shapes = {M = 1, N = 1, K = 16},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_32_threads
+func.func @mma_elements_per_thread_32_threads() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>,
+  wave.constraints = [#hw_constraint_32_threads]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values; elements_per_thread is determined by MMA backward propagation.
+  %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_16x16x16_f16>}
+  // f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+#hw_constraint_128_threads = #wave.hardware_constraint<threads_per_wave = 128,
+  vector_shapes = {M = 1, N = 1, K = 8},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_128_threads
+func.func @mma_elements_per_thread_128_threads() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
+  wave.constraints = [#hw_constraint_128_threads]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values; elements_per_thread is determined by MMA backward propagation.
+  %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
+  // f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+#hw_constraint_interface_alt = #wave.hardware_constraint<threads_per_wave = 64,
+  vector_shapes = {M = 1, N = 1, K = 8},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_interface_explicit
+func.func @mma_elements_per_thread_interface_explicit() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
+  wave.constraints = [#hw_constraint_interface_alt]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values with explicit elements_per_thread.
+  %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init {elements_per_thread = 16} : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
+  // f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+#hw_constraint_32_threads_explicit = #wave.hardware_constraint<threads_per_wave = 32,
+  vector_shapes = {M = 1, N = 1, K = 16},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_32_threads_explicit
+func.func @mma_elements_per_thread_32_threads_explicit() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>,
+  wave.constraints = [#hw_constraint_32_threads_explicit]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values with explicit elements_per_thread.
+  %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_16x16x16_f16>}
+  // f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+#hw_constraint_128_threads_explicit = #wave.hardware_constraint<threads_per_wave = 128,
+  vector_shapes = {M = 1, N = 1, K = 8},
+  max_bits_per_load = 128>
+
+// CHECK-LABEL: @mma_elements_per_thread_128_threads_explicit
+func.func @mma_elements_per_thread_128_threads_explicit() attributes {
+  wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
+  wave.constraints = [#hw_constraint_128_threads_explicit]
+} {
+  %lhs_init = arith.constant 1.0 : f16
+  %rhs_init = arith.constant 2.0 : f16
+  %acc_init = arith.constant 0.0 : f32
+
+  // Create register values with explicit elements_per_thread.
+  %lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, >
+  %rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, >
+  %acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >
+
+  // CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
+  // f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
 // CHECK-LABEL: @allocate
 func.func @allocate() -> !wave.tensor<[@M, @N] of bf16, > {
   // CHECK: wave.allocate
diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
index d64db241f..d9ef3382b 100644
--- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
+++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
@@ -147,6 +147,7 @@ func.func @unsupported_op() attributes {wave.hyperparameters = #wave.hyperparame
   }
 }
 
+
 // -----
 
 // CHECK: #wave.normal_form
@@ -158,3 +159,85 @@ module {
     return
   }
 }
+
+// -----
+
+module attributes {wave.normal_form = #wave.normal_form} {
+func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
+  // LHS without elements_per_thread; it will be computed from the RHS and the MMA constraints.
+  %lhs_init = arith.constant 0.0 : f16
+  %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, >
+
+  // RHS properly initialized through a read operation.
+  %rhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@N, @K] of f16, >) -> !wave.tensor<[@N, @K] of f16, >
+
+  // ACC properly initialized through a read operation.
+  %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+
+  // LHS elements_per_thread is computed via MMA backward propagation.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+}
+
+// -----
+
+module attributes {wave.normal_form = #wave.normal_form} {
+func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
+  // LHS properly initialized through a read operation.
+  %lhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, >
+
+  // RHS without elements_per_thread; it will be computed from the LHS and the MMA constraints.
+  %rhs_init = arith.constant 0.0 : f16
+  %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+
+  // ACC properly initialized through a read operation.
+  %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+
+  // RHS elements_per_thread is computed via MMA backward propagation.
+  %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+  return
+}
+}
+
+// -----
+
+// Test that MMA can compute both LHS and RHS when both are uninitialized.
+module attributes {wave.normal_form = #wave.normal_form} {
+  func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@N, @K] of f16, >, %mem3: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
+    // Both LHS and RHS lack elements_per_thread; both can be computed from the MMA formulas.
+    %lhs_init = arith.constant 0.0 : f16
+    %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, >
+    %rhs_init = arith.constant 0.0 : f16
+    %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+
+    // ACC properly initialized through a read operation.
+    %acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+
+    // With the MMA formulas, both LHS and RHS can be computed from the
+    // constraints, so this succeeds.
+    %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+    return
+  }
+}
+
+// -----
+
+// Test the MMA error when an operand has the wrong elements_per_thread.
+module attributes {wave.normal_form = #wave.normal_form} {
+  func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, >, %mem2: !wave.tensor<[@M, @N] of f32, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
+    // LHS with the wrong elements_per_thread (should be 8, not 4).
+    %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, >) -> !wave.tensor<[@M, @K] of f16, >
+
+    // RHS without elements_per_thread; it will be computed from the MMA constraints.
+    %rhs_init = arith.constant 0.0 : f16
+    %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, >
+
+    // ACC properly initialized through a read operation.
+    %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+
+    // expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}
+    %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, >
+    return
+  }
+}