Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def Exp2Op : UnaryWaveOp<"exp2"> {

def MmaOp : WaveOp<"mma",
[DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface,
["initializeIndexExprsForward", "initializeIndexExprsBackward"]>]>,
WaveArithmeticOpDoc {
Expand All @@ -130,6 +131,12 @@ def MmaOp : WaveOp<"mma",
"$lhs `,` $rhs `,` $accumulator " # commonArgumentsSyntax # "attr-dict `:`"
"functional-type(operands, results)";
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Compute the expected elements per thread for a specific operand of this MMA operation.
/// Returns failure if no hardware constraints are available.
llvm::FailureOr<unsigned> computeElementsPerThreadForOperand(unsigned operandIndex);
}];
}

//-----------------------------------------------------------------------------
Expand Down
155 changes: 155 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,161 @@ LogicalResult MmaOp::verify() {
accumulatorType.getElementType());
}

/// Compute the expected elements per thread for a specific MMA operand.
/// operandIndex: 0=LHS, 1=RHS, 2=Accumulator/Result
/// Returns failure if the MMA kind is unset, the operand index is invalid,
/// or no hardware constraint is found on any ancestor operation.
llvm::FailureOr<unsigned>
wave::MmaOp::computeElementsPerThreadForOperand(unsigned operandIndex) {
  std::optional<wave::WaveMmaKind> mmaKind = getKind();
  if (!mmaKind)
    return mlir::failure();
  wave::WaveMmaSpec spec =
      wave::WaveMmaKindAttr::getSpec(getContext(), *mmaKind);

  // Determine the tile footprint of the requested operand once, up front, so
  // an invalid index is rejected regardless of whether a hardware constraint
  // is present (previously the switch ran inside the ancestor walk).
  unsigned totalElements;
  switch (operandIndex) {
  case 0: // LHS: M x K
    totalElements = spec.m * spec.k;
    break;
  case 1: // RHS: N x K
    totalElements = spec.n * spec.k;
    break;
  case 2: // Accumulator/Result: M x N
    totalElements = spec.m * spec.n;
    break;
  default:
    return mlir::failure();
  }

  // Extract threads per wave from the first hardware constraint found by
  // walking up the ancestry.
  for (mlir::Operation *op = getOperation(); op; op = op->getParentOp()) {
    auto constraints = op->getAttrOfType<mlir::ArrayAttr>(
        wave::WaveDialect::kWaveConstraintsAttrName);
    if (!constraints)
      continue;
    for (mlir::Attribute constraint : constraints) {
      auto hardwareConstraint =
          llvm::dyn_cast<wave::HardwareConstraintAttr>(constraint);
      if (!hardwareConstraint)
        continue;
      unsigned threadsPerWave = hardwareConstraint.getThreadsPerWave();
      // Guard against a malformed constraint: dividing by zero is UB.
      if (threadsPerWave == 0)
        return mlir::failure();
      // NOTE(review): integer division silently truncates when the tile
      // footprint is not a multiple of threadsPerWave -- confirm callers
      // expect truncation rather than a diagnostic.
      return totalElements / threadsPerWave;
    }
  }

  // Return failure if no constraints found.
  return mlir::failure();
}

/// Forward propagation: the result's elements-per-thread is fully determined
/// by the MMA kind and the enclosing hardware constraints (it matches the
/// accumulator tile), so push that computed constant into the result lattice.
llvm::FailureOr<mlir::ChangeResult>
wave::MmaOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs) {
  unsigned accumulatorIndex = getAccumulatorMutable().getOperandNumber();
  llvm::FailureOr<unsigned> expected =
      computeElementsPerThreadForOperand(accumulatorIndex);
  if (llvm::failed(expected)) {
    errs << "MMA operation has no hardware constraints available";
    return mlir::failure();
  }
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      wave::ElementsPerThreadLatticeValue(*expected),
      llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), resultElements,
      "computed from MMA kind", "", "result", errs);
}

/// Backward propagation: each MMA operand's elements-per-thread is fully
/// determined by the MMA kind and hardware constraints. Compute the expected
/// constant for the LHS, RHS, and accumulator operands and propagate it into
/// the corresponding operand lattice slot, in that order.
llvm::FailureOr<mlir::ChangeResult>
wave::MmaOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  // MMA operation always has exactly 3 operands: LHS, RHS, Accumulator.
  assert(operandElements.size() == 3 &&
         "MMA operation must have exactly 3 operands");

  // Operand number plus the human-readable name used in diagnostics. Kept as
  // a table so the per-operand handling below is written exactly once
  // (replaces the previous triplicated LHS/RHS/accumulator copies, which also
  // mixed llvm::failure() and mlir::failure()).
  struct OperandInfo {
    unsigned number;
    const char *description;
  };
  const OperandInfo operandInfos[] = {
      {getLhsMutable().getOperandNumber(), "LHS operand"},
      {getRhsMutable().getOperandNumber(), "RHS operand"},
      {getAccumulatorMutable().getOperandNumber(), "accumulator operand"}};

  mlir::ChangeResult combined = mlir::ChangeResult::NoChange;
  for (const OperandInfo &info : operandInfos) {
    // All three computations share the same failure condition (missing kind
    // or missing hardware constraint), so failing at the first one reported
    // the same diagnostic the previous up-front check did.
    llvm::FailureOr<unsigned> expected =
        computeElementsPerThreadForOperand(info.number);
    if (llvm::failed(expected)) {
      errs << "MMA operation has no hardware constraints available";
      return mlir::failure();
    }
    llvm::FailureOr<mlir::ChangeResult> result =
        wave::detail::checkAndPropagateElementsPerThreadFromConstant(
            wave::ElementsPerThreadLatticeValue(*expected),
            llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
            operandElements.slice(info.number, 1), "computed from MMA kind",
            "", info.description, errs);
    if (llvm::failed(result))
      return mlir::failure();
    // Report Change if any operand lattice changed.
    combined = combined | *result;
  }
  return combined;
}

//-----------------------------------------------------------------------------
// ReadOp
//-----------------------------------------------------------------------------
Expand Down
156 changes: 156 additions & 0 deletions water/test/Dialect/Wave/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,162 @@ func.func @register_with_hyperparameter() attributes {hyperparameters = #wave.hy
return
}

// 64-thread wave with the f32_32x32x8_f16 MMA intrinsic.
#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_32x32x8_f16>,
vector_shapes = {M = 1, N = 1, K = 8},
max_bits_per_load = 128>

// Round-trip test: wave.mma under a hardware constraint where the registers
// carry no explicit elements_per_thread, leaving it to be derived from the
// MMA kind by the propagation interface.
// CHECK-LABEL: @mma_elements_per_thread_interface
func.func @mma_elements_per_thread_interface() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
wave.constraints = [#hw_constraint]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
// f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// 32-thread wave with the f32_16x16x16_f16 MMA intrinsic.
#hw_constraint_32_threads = #wave.hardware_constraint<threads_per_wave = 32,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
vector_shapes = {M = 1, N = 1, K = 16},
max_bits_per_load = 128>

// Same shape of test as above, but with a smaller wave (32 threads) and the
// 16x16x16 intrinsic; elements_per_thread is again left implicit.
// CHECK-LABEL: @mma_elements_per_thread_32_threads
func.func @mma_elements_per_thread_32_threads() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>,
wave.constraints = [#hw_constraint_32_threads]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_16x16x16_f16>}
// f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// 128-thread wave with the f32_32x32x8_f16 MMA intrinsic.
#hw_constraint_128_threads = #wave.hardware_constraint<threads_per_wave = 128,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_32x32x8_f16>,
vector_shapes = {M = 1, N = 1, K = 8},
max_bits_per_load = 128>

// Larger wave (128 threads): the same accumulator tile yields fewer elements
// per thread; elements_per_thread is again left implicit on the registers.
// CHECK-LABEL: @mma_elements_per_thread_128_threads
func.func @mma_elements_per_thread_128_threads() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
wave.constraints = [#hw_constraint_128_threads]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values - elements_per_thread determined by MMA backward propagation.
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
// f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// Same constraint as #hw_constraint; separate alias for the explicit variant.
#hw_constraint_interface_alt = #wave.hardware_constraint<threads_per_wave = 64,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_32x32x8_f16>,
vector_shapes = {M = 1, N = 1, K = 8},
max_bits_per_load = 128>

// Variant of @mma_elements_per_thread_interface with explicit
// elements_per_thread attributes on the registers.
// CHECK-LABEL: @mma_elements_per_thread_interface_explicit
func.func @mma_elements_per_thread_interface_explicit() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
wave.constraints = [#hw_constraint_interface_alt]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values with explicit elements_per_thread.
// NOTE(review): for f32_32x32x8_f16 with 64 threads the M*K/threads formula
// used by computeElementsPerThreadForOperand gives 32*8/64 = 4 for LHS/RHS,
// not the explicit 8 below -- confirm this test only exercises
// parsing/printing and not the propagation analysis.
%lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init {elements_per_thread = 16} : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
// f32_32x32x8_f16: 32*32/64 threads = 16 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// 32-thread wave, f32_16x16x16_f16; explicit elements_per_thread variant.
#hw_constraint_32_threads_explicit = #wave.hardware_constraint<threads_per_wave = 32,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
vector_shapes = {M = 1, N = 1, K = 16},
max_bits_per_load = 128>

// Variant of @mma_elements_per_thread_32_threads with explicit
// elements_per_thread; here the explicit values match the M*K/threads
// formula (16*16/32 = 8 for all three operands).
// CHECK-LABEL: @mma_elements_per_thread_32_threads_explicit
func.func @mma_elements_per_thread_32_threads_explicit() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>,
wave.constraints = [#hw_constraint_32_threads_explicit]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values with explicit elements_per_thread.
%lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_16x16x16_f16>}
// f32_16x16x16_f16: 16*16/32 threads = 8 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// 128-thread wave, f32_32x32x8_f16; explicit elements_per_thread variant.
#hw_constraint_128_threads_explicit = #wave.hardware_constraint<threads_per_wave = 128,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_32x32x8_f16>,
vector_shapes = {M = 1, N = 1, K = 8},
max_bits_per_load = 128>

// Variant of @mma_elements_per_thread_128_threads with explicit
// elements_per_thread attributes on the registers.
// CHECK-LABEL: @mma_elements_per_thread_128_threads_explicit
func.func @mma_elements_per_thread_128_threads_explicit() attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 32, N = 32, K = 8}>,
wave.constraints = [#hw_constraint_128_threads_explicit]
} {
%lhs_init = arith.constant 1.0 : f16
%rhs_init = arith.constant 2.0 : f16
%acc_init = arith.constant 0.0 : f32

// Create register values with explicit elements_per_thread.
// NOTE(review): for f32_32x32x8_f16 with 128 threads the M*K/threads formula
// gives 32*8/128 = 2 for LHS/RHS, not the explicit 8 below -- confirm this
// test only exercises parsing/printing and not the propagation analysis.
%lhs = wave.register %lhs_init {elements_per_thread = 8} : !wave.tensor<[@M, @K] of f16, <register>>
%rhs = wave.register %rhs_init {elements_per_thread = 8} : !wave.tensor<[@N, @K] of f16, <register>>
%acc = wave.register %acc_init {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, <register>>

// CHECK: wave.mma {{.*}} {kind = #wave.mma_kind<f32_32x32x8_f16>}
// f32_32x32x8_f16: 32*32/128 threads = 8 elements per thread.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_32x32x8_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// CHECK-LABEL: @allocate
func.func @allocate() -> !wave.tensor<[@M, @N] of bf16, <shared>> {
// CHECK: wave.allocate
Expand Down
Loading
Loading