@@ -163,18 +163,18 @@ module {
163163// -----
164164
165165module attributes {wave.normal_form = #wave.normal_form <full_types >} {
166- func.func @mma_uninitialized_lhs (%mem1: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
167- // LHS without elements_per_thread - this will remain uninitialized .
166+ func.func @mma_compute_lhs_from_rhs (%mem1: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
167+ // LHS without elements_per_thread - will be computed from RHS + MMA constraints .
168168 %lhs_init = arith.constant 0.0 : f16
169169 %lhs = wave.register %lhs_init : !wave.tensor <[@M , @K ] of f16 , <register >>
170170
171171 // RHS properly initialized through read operation.
172- %rhs = wave.read %mem1 {elements_per_thread = 4 } : (!wave.tensor <[@N , @K ] of f16 , <global >>) -> !wave.tensor <[@N , @K ] of f16 , <register >>
172+ %rhs = wave.read %mem1 {elements_per_thread = 8 } : (!wave.tensor <[@N , @K ] of f16 , <global >>) -> !wave.tensor <[@N , @K ] of f16 , <register >>
173173
174174 // ACC properly initialized through read operation.
175- %acc = wave.read %mem2 {elements_per_thread = 4 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
175+ %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
176176
177- // expected-error @below {{failed to propagate elements per thread backward: MMA operand #0 (LHS) has uninitialized elements_per_thread}}
177+ // LHS elements_per_thread computed via MMA backward propagation
178178 %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
179179 return
180180}
@@ -183,19 +183,61 @@ func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>,
183183// -----
184184
185185module attributes {wave.normal_form = #wave.normal_form <full_types >} {
186- func.func @mma_uninitialized_rhs (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
186+ func.func @mma_compute_rhs_from_lhs (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
187187 // LHS properly initialized through read operation.
188- %lhs = wave.read %mem1 {elements_per_thread = 4 } : (!wave.tensor <[@M , @K ] of f16 , <global >>) -> !wave.tensor <[@M , @K ] of f16 , <register >>
188+ %lhs = wave.read %mem1 {elements_per_thread = 8 } : (!wave.tensor <[@M , @K ] of f16 , <global >>) -> !wave.tensor <[@M , @K ] of f16 , <register >>
189189
190- // RHS without elements_per_thread - this will remain uninitialized .
190+ // RHS without elements_per_thread - will be computed from LHS + MMA constraints .
191191 %rhs_init = arith.constant 0.0 : f16
192192 %rhs = wave.register %rhs_init : !wave.tensor <[@N , @K ] of f16 , <register >>
193193
194194 // ACC properly initialized through read operation.
195- %acc = wave.read %mem2 {elements_per_thread = 4 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
195+ %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
196196
197- // expected-error @below {{failed to propagate elements per thread backward: MMA operand #1 (RHS) has uninitialized elements_per_thread}}
197+ // RHS elements_per_thread computed via MMA backward propagation
198198 %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
199199 return
200200}
201201}
202+
203+ // -----
204+
205+ // Test MMA can compute both LHS and RHS when both are uninitialized
206+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
207+ func.func @mma_compute_both_lhs_rhs (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem3: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
208+ // Both LHS and RHS without elements_per_thread - can compute from MMA formulas
209+ %lhs_init = arith.constant 0.0 : f16
210+ %lhs = wave.register %lhs_init : !wave.tensor <[@M , @K ] of f16 , <register >>
211+ %rhs_init = arith.constant 0.0 : f16
212+ %rhs = wave.register %rhs_init : !wave.tensor <[@N , @K ] of f16 , <register >>
213+
214+ // ACC properly initialized through read operation.
215+ %acc = wave.read %mem3 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
216+
217+ // With proper MMA formulas, we can now compute both LHS and RHS from constraints,
218+ // so this should succeed instead of failing
219+ %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
220+ return
221+ }
222+ }
223+
224+ // -----
225+
226+ // Test MMA error when operand has wrong elements_per_thread
227+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
228+ func.func @mma_operand_mismatch (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
229+ // LHS with wrong elements_per_thread (should be 8, not 4)
230+ %lhs = wave.read %mem1 {elements_per_thread = 4 } : (!wave.tensor <[@M , @K ] of f16 , <global >>) -> !wave.tensor <[@M , @K ] of f16 , <register >>
231+
232+ // RHS without elements_per_thread - will be computed from MMA constraints.
233+ %rhs_init = arith.constant 0.0 : f16
234+ %rhs = wave.register %rhs_init : !wave.tensor <[@N , @K ] of f16 , <register >>
235+
236+ // ACC properly initialized
237+ %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
238+
239+ // expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}
240+ %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
241+ return
242+ }
243+ }
0 commit comments