Skip to content

Commit 1462694

Browse files
committed
Fix
Signed-off-by: tyb0807 <[email protected]>
1 parent 408b0da commit 1462694

File tree

3 files changed

+143
-55
lines changed

3 files changed

+143
-55
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@ def MmaOp : WaveOp<"mma",
133133
let hasVerifier = 1;
134134

135135
let extraClassDeclaration = [{
136-
/// Compute the expected elements per thread for this MMA operation.
137-
unsigned computeElementsPerThread();
136+
/// Compute the expected elements per thread for a specific operand of this MMA operation.
137+
/// Returns failure if the kind attribute is missing, operandIndex is not 0-2, or no hardware constraints are found.
138+
llvm::FailureOr<unsigned> computeElementsPerThreadForOperand(unsigned operandIndex);
138139
}];
139140
}
140141

water/lib/Dialect/Wave/IR/WaveOps.cpp

Lines changed: 88 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,15 +1078,15 @@ LogicalResult MmaOp::verify() {
10781078
accumulatorType.getElementType());
10791079
}
10801080

1081-
/// Compute the expected elements per thread for this MMA operation.
1082-
/// Extracts threadsPerWave from ancestor operations with hardware constraints.
1083-
/// Returns 0 if no constraints are found.
1084-
unsigned wave::MmaOp::computeElementsPerThread() {
1085-
wave::WaveMmaKind kind = getKind();
1086-
if (!kind) {
1087-
return 0;
1081+
/// Compute the expected elements per thread for a specific MMA operand.
1082+
/// operandIndex: 0=LHS, 1=RHS, 2=Accumulator/Result
1083+
/// Returns failure if the kind attribute is missing, operandIndex is not 0-2, or no hardware constraints are found.
1084+
llvm::FailureOr<unsigned> wave::MmaOp::computeElementsPerThreadForOperand(unsigned operandIndex) {
1085+
if (!getKindAttr()) {
1086+
return mlir::failure();
10881087
}
1089-
wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(getContext(), *kind);
1088+
wave::WaveMmaKind kind = getKind();
1089+
wave::WaveMmaSpec spec = wave::WaveMmaKindAttr::getSpec(getContext(), kind);
10901090

10911091
// Extract threads per wave from hardware constraint by walking up the
10921092
// ancestry.
@@ -1097,28 +1097,42 @@ unsigned wave::MmaOp::computeElementsPerThread() {
10971097
for (mlir::Attribute constraint : constraints) {
10981098
if (auto hardwareConstraint =
10991099
llvm::dyn_cast<wave::HardwareConstraintAttr>(constraint)) {
1100-
unsigned totalElements = spec.m * spec.n;
1100+
unsigned totalElements;
1101+
switch (operandIndex) {
1102+
case 0: // LHS: M x K
1103+
totalElements = spec.m * spec.k;
1104+
break;
1105+
case 1: // RHS: N x K
1106+
totalElements = spec.n * spec.k;
1107+
break;
1108+
case 2: // Accumulator/Result: M x N
1109+
totalElements = spec.m * spec.n;
1110+
break;
1111+
default:
1112+
return mlir::failure();
1113+
}
11011114
return totalElements / hardwareConstraint.getThreadsPerWave();
11021115
}
11031116
}
11041117
}
11051118
op = op->getParentOp();
11061119
}
11071120

1108-
// Return 0 to indicate failure if no constraints found.
1109-
return 0;
1121+
// Return failure if no constraints found.
1122+
return mlir::failure();
11101123
}
11111124

11121125
llvm::FailureOr<mlir::ChangeResult>
11131126
wave::MmaOp::propagateElementsPerThreadForward(
11141127
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
11151128
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
11161129
llvm::raw_ostream &errs) {
1117-
unsigned expectedElementsPerThread = computeElementsPerThread();
1118-
if (expectedElementsPerThread == 0) {
1130+
llvm::FailureOr<unsigned> expectedElementsPerThreadResult = computeElementsPerThreadForOperand(getAccumulatorMutable().getOperandNumber());
1131+
if (llvm::failed(expectedElementsPerThreadResult)) {
11191132
errs << "MMA operation has no hardware constraints available";
11201133
return mlir::failure();
11211134
}
1135+
unsigned expectedElementsPerThread = *expectedElementsPerThreadResult;
11221136
wave::ElementsPerThreadLatticeValue expectedResult(expectedElementsPerThread);
11231137
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
11241138
expectedResult, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
@@ -1133,48 +1147,78 @@ wave::MmaOp::propagateElementsPerThreadBackward(
11331147
// For MMA, the accumulator should have the same elements per thread as the
11341148
// result. The LHS and RHS operands may have different constraints based on
11351149
// their dimensions.
1136-
unsigned expectedElementsPerThread = computeElementsPerThread();
1137-
if (expectedElementsPerThread == 0) {
1150+
// MMA operation always has exactly 3 operands: LHS, RHS, Accumulator
1151+
assert(operandElements.size() == 3 && "MMA operation must have exactly 3 operands");
1152+
1153+
unsigned lhsOperandNumber = getLhsMutable().getOperandNumber();
1154+
unsigned rhsOperandNumber = getRhsMutable().getOperandNumber();
1155+
unsigned accumulatorOperandNumber = getAccumulatorMutable().getOperandNumber();
1156+
1157+
// Compute expected elements per thread for each operand
1158+
llvm::FailureOr<unsigned> expectedLhsElementsPerThreadResult = computeElementsPerThreadForOperand(lhsOperandNumber);
1159+
llvm::FailureOr<unsigned> expectedRhsElementsPerThreadResult = computeElementsPerThreadForOperand(rhsOperandNumber);
1160+
llvm::FailureOr<unsigned> expectedAccumulatorElementsPerThreadResult = computeElementsPerThreadForOperand(accumulatorOperandNumber);
1161+
1162+
if (llvm::failed(expectedLhsElementsPerThreadResult) || llvm::failed(expectedRhsElementsPerThreadResult) || llvm::failed(expectedAccumulatorElementsPerThreadResult)) {
11381163
errs << "MMA operation has no hardware constraints available";
11391164
return mlir::failure();
11401165
}
1141-
wave::ElementsPerThreadLatticeValue expectedAccumulator(
1142-
expectedElementsPerThread);
11431166

1144-
unsigned accumulatorOperandNumber =
1145-
getAccumulatorMutable().getOperandNumber();
1167+
unsigned expectedLhsElementsPerThread = *expectedLhsElementsPerThreadResult;
1168+
unsigned expectedRhsElementsPerThread = *expectedRhsElementsPerThreadResult;
1169+
unsigned expectedAccumulatorElementsPerThread = *expectedAccumulatorElementsPerThreadResult;
11461170

1147-
// Validate that LHS and RHS operands have concrete elements_per_thread
1148-
// values. We don't propagate to them, but we check they've been properly
1149-
// initialized. During analysis initialization, bottom values are acceptable -
1150-
// we return NoChange to let the analysis continue rather than failing.
1151-
// LHS (0) and RHS (1) operands.
1152-
bool allLhsRhsInitialized = true;
1153-
for (unsigned i = 0; i < 2 && i < operandElements.size(); ++i) {
1154-
if (operandElements[i].isBottom()) {
1155-
allLhsRhsInitialized = false;
1156-
break;
1157-
}
1171+
wave::ElementsPerThreadLatticeValue expectedLhs(expectedLhsElementsPerThread);
1172+
wave::ElementsPerThreadLatticeValue expectedRhs(expectedRhsElementsPerThread);
1173+
wave::ElementsPerThreadLatticeValue expectedAccumulator(expectedAccumulatorElementsPerThread);
1174+
1175+
// Propagate elements_per_thread to LHS operand using the helper function
1176+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> lhsOnly =
1177+
operandElements.slice(lhsOperandNumber, 1);
1178+
1179+
llvm::FailureOr<mlir::ChangeResult> lhsResult =
1180+
wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1181+
expectedLhs,
1182+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), lhsOnly,
1183+
"computed from MMA kind", "", "LHS operand", errs);
1184+
1185+
if (llvm::failed(lhsResult)) {
1186+
return llvm::failure();
11581187
}
11591188

1160-
// If LHS/RHS operands are still at bottom, return NoChange to allow
1161-
// the analysis to continue. Forward propagation will initialize them.
1162-
if (!allLhsRhsInitialized) {
1163-
return mlir::ChangeResult::NoChange;
1189+
// Propagate elements_per_thread to RHS operand using the helper function
1190+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> rhsOnly =
1191+
operandElements.slice(rhsOperandNumber, 1);
1192+
1193+
llvm::FailureOr<mlir::ChangeResult> rhsResult =
1194+
wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1195+
expectedRhs,
1196+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), rhsOnly,
1197+
"computed from MMA kind", "", "RHS operand", errs);
1198+
1199+
if (llvm::failed(rhsResult)) {
1200+
return mlir::failure();
11641201
}
11651202

11661203
// Propagate to the accumulator operand.
1167-
if (operandElements.size() > accumulatorOperandNumber) {
1168-
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> accumulatorOnly =
1169-
operandElements.slice(accumulatorOperandNumber, 1);
1170-
1171-
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1172-
expectedAccumulator,
1173-
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), accumulatorOnly,
1174-
"computed from MMA kind", "", "accumulator operand", errs);
1204+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> accumulatorOnly =
1205+
operandElements.slice(accumulatorOperandNumber, 1);
1206+
1207+
llvm::FailureOr<mlir::ChangeResult> accumulatorResult =
1208+
wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1209+
expectedAccumulator,
1210+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), accumulatorOnly,
1211+
"computed from MMA kind", "", "accumulator operand", errs);
1212+
1213+
if (llvm::failed(accumulatorResult)) {
1214+
return mlir::failure();
11751215
}
11761216

1177-
return mlir::ChangeResult::NoChange;
1217+
// Return Change if any operand changed
1218+
return (*lhsResult == mlir::ChangeResult::Change ||
1219+
*rhsResult == mlir::ChangeResult::Change ||
1220+
*accumulatorResult == mlir::ChangeResult::Change) ?
1221+
mlir::ChangeResult::Change : mlir::ChangeResult::NoChange;
11781222
}
11791223

11801224
//-----------------------------------------------------------------------------
@@ -1355,6 +1399,7 @@ mlir::LogicalResult wave::RegisterOp::verify() {
13551399
return mlir::success();
13561400
}
13571401

1402+
13581403
//-----------------------------------------------------------------------------
13591404
// ExtractSliceOp
13601405
//-----------------------------------------------------------------------------

water/test/Dialect/Wave/propagate-elements-per-thread.mlir

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,18 @@ module {
163163
// -----
164164

165165
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
166-
func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
167-
// LHS without elements_per_thread - this will remain uninitialized.
166+
func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
167+
// LHS without elements_per_thread - will be computed from the MMA kind and hardware constraints.
168168
%lhs_init = arith.constant 0.0 : f16
169169
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
170170

171171
// RHS properly initialized through read operation.
172-
%rhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@N, @K] of f16, <global>>) -> !wave.tensor<[@N, @K] of f16, <register>>
172+
%rhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@N, @K] of f16, <global>>) -> !wave.tensor<[@N, @K] of f16, <register>>
173173

174174
// ACC properly initialized through read operation.
175-
%acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
175+
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
176176

177-
// expected-error @below {{failed to propagate elements per thread backward: MMA operand #0 (LHS) has uninitialized elements_per_thread}}
177+
// LHS elements_per_thread computed via MMA backward propagation
178178
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
179179
return
180180
}
@@ -183,19 +183,61 @@ func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>,
183183
// -----
184184

185185
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
186-
func.func @mma_uninitialized_rhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
186+
func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
187187
// LHS properly initialized through read operation.
188-
%lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
188+
%lhs = wave.read %mem1 {elements_per_thread = 8} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
189189

190-
// RHS without elements_per_thread - this will remain uninitialized.
190+
// RHS without elements_per_thread - will be computed from the MMA kind and hardware constraints.
191191
%rhs_init = arith.constant 0.0 : f16
192192
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
193193

194194
// ACC properly initialized through read operation.
195-
%acc = wave.read %mem2 {elements_per_thread = 4} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
195+
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
196196

197-
// expected-error @below {{failed to propagate elements per thread backward: MMA operand #1 (RHS) has uninitialized elements_per_thread}}
197+
// RHS elements_per_thread computed via MMA backward propagation
198198
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
199199
return
200200
}
201201
}
202+
203+
// -----
204+
205+
// Test MMA can compute both LHS and RHS when both are uninitialized
206+
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
207+
func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@N, @K] of f16, <global>>, %mem3: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
208+
// Both LHS and RHS without elements_per_thread - can compute from MMA formulas
209+
%lhs_init = arith.constant 0.0 : f16
210+
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
211+
%rhs_init = arith.constant 0.0 : f16
212+
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
213+
214+
// ACC properly initialized through read operation.
215+
%acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
216+
217+
// With proper MMA formulas, we can now compute both LHS and RHS from constraints,
218+
// so this should succeed instead of failing
219+
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
220+
return
221+
}
222+
}
223+
224+
// -----
225+
226+
// Test MMA error when operand has wrong elements_per_thread
227+
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
228+
func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
229+
// LHS with wrong elements_per_thread (should be 8, not 4)
230+
%lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
231+
232+
// RHS without elements_per_thread - will be computed from MMA constraints.
233+
%rhs_init = arith.constant 0.0 : f16
234+
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
235+
236+
// ACC properly initialized
237+
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
238+
239+
// expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}
240+
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
241+
return
242+
}
243+
}

0 commit comments

Comments
 (0)