Skip to content

Commit 2015d7d

Browse files
committed
Enable Water-based index inference as optional secondary
This starts the transition to use Water-based index expression inference by enabling it to run in parallel with the original pyWave version of the same, optionally under a flag. Only the pyWave analyses or both may run. When both analyses run, the results of the two are compared with exceptions raised on mismatches. In any case, the results of the pyWave analyses are taken. After stabilization, appropraite extension and fixes, the following steps will be taken: 1. Turn the flag running the Water-based analysis to be on by default. 2. Make the Water-based analysis primary and keep the pyWave analysis as optional backup for comparison. 3. Deprecate and remove the pyWave analysis. Each of these steps will come in a separate commit after a certain period of time. The plumbing requires conversion from MLIR attributes to Wave constructs, mostly sympy expressions that is added to the extent required here. The logic in `emit_wave_dialect` is extended to collect attribute values for all operations in the IR based on a unique ID derived from Python hash of the object attached as attribute. Some portion of Wave code interacting with sympy is shifted to `wave_lang/support` to avoid the unconditional loading of IREE libraries due to `wave_lang.kernel` including `wave_lang.kernel.wave` which, in one or multiple places, transitively imports IREE, which in turn clashes with Water or any other MLIR-based project. Signed-off-by: Alex Zinenko <[email protected]>
1 parent 5256591 commit 2015d7d

21 files changed

+1600
-223
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# REQUIRES: water
2+
# RUN: python %s
3+
# The point of this test is to avoid crashing or asserting, so just run it under lit.
4+
5+
# Copyright 2025 The Wave Authors
6+
#
7+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
8+
# See https://llvm.org/LICENSE.txt for license information.
9+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
10+
11+
import wave_lang.kernel.lang as tkl
12+
import wave_lang.kernel.wave as tkw
13+
from wave_lang.kernel.wave.wave import LaunchableWave
14+
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
15+
16+
from wave_lang.kernel.lang.global_symbols import *
17+
from wave_lang.kernel.wave.constraints import MMAType
18+
from wave_lang.kernel.wave.utils.general_utils import torch_dtype_to_wave
19+
20+
import torch
21+
22+
23+
# TODO: use the generic template, currently blocked by water not handling wave constraints.
24+
def _get_gemm_kernel(
25+
shape: tuple[int, int, int],
26+
mfma_variant: MMAType,
27+
dtype: torch.dtype = torch.float16,
28+
block_shape: tuple[int, int, int] | None = None,
29+
waves_per_block: tuple[int, int] | None = None,
30+
) -> tuple[LaunchableWave, dict[tkl.IndexSymbol, tkl.IndexExpr]]:
31+
if not block_shape:
32+
# BLOCK_M, BLOCK_N, BLOCK_K
33+
block_shape = (64, 64, 32)
34+
35+
if not waves_per_block:
36+
# WAVE_M, WAVE_N
37+
waves_per_block = (2, 2)
38+
39+
assert len(block_shape) == 3, "block_shape needs to be rank 3 for M, N, K."
40+
assert len(waves_per_block) == 2, "waves_per_block needs to be rank 2 for M, N."
41+
42+
# Input sizes
43+
M = tkl.sym.M
44+
N = tkl.sym.N
45+
K = tkl.sym.K
46+
# Workgroup tile sizes
47+
BLOCK_M = tkl.sym.BLOCK_M
48+
BLOCK_N = tkl.sym.BLOCK_N
49+
BLOCK_K = tkl.sym.BLOCK_K
50+
# Address space (for GPU, shared(1) or global(0))
51+
ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE
52+
dtype = torch_dtype_to_wave(dtype)
53+
# Expose user-constraints
54+
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
55+
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
56+
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
57+
58+
# TODO: dialect expects waves_per_block to be rank 3, so we append a 1 to the end.
59+
constraints += [
60+
tkw.HardwareConstraint(
61+
threads_per_wave=64,
62+
mma_type=mfma_variant,
63+
waves_per_block=waves_per_block + (1,),
64+
)
65+
]
66+
67+
# Wave-level micro-kernel.
68+
# Since warps are not directly addressable, there is no
69+
# explicit notion of a warp id (like a workgroup or thread id).
70+
# This kernel uses the input sizes M, N, K throughout, as the tiling
71+
# and data movement strategy is determined during the compilation process.
72+
# These can be influenced by introducing constraints.
73+
@tkw.wave(constraints)
74+
def gemm(
75+
a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype],
76+
b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype],
77+
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
78+
):
79+
c_reg = tkl.Register[M, N, tkl.f32](0.0)
80+
81+
# This microkernel encodes the fact that if the iterate
82+
# dimension were tiled, then we would need to materialize a loop.
83+
@tkw.iterate(K, init_args=[c_reg])
84+
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
85+
# a_reg: tkw.Register[M, K, dtype]
86+
a_reg = tkw.read(a)
87+
# b_reg: tkw.Register[N, K, dtype]
88+
b_reg = tkw.read(b)
89+
# acc: tkw.Register[M, N, tkl.f32]
90+
acc = tkw.mma(a_reg, b_reg, acc)
91+
return acc
92+
93+
# repeat represents the results of the loop
94+
tkw.write(repeat, c)
95+
96+
hyperparams = {
97+
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
98+
BLOCK_M: block_shape[0],
99+
BLOCK_N: block_shape[1],
100+
BLOCK_K: block_shape[2],
101+
M: shape[0],
102+
N: shape[1],
103+
K: shape[2],
104+
}
105+
return gemm, hyperparams
106+
107+
108+
def testGemm():
109+
gemm, hyperparams = _get_gemm_kernel(
110+
shape=(1024, 1024, 1024), mfma_variant=MMAType.F32_16x16x16_F16
111+
)
112+
options = WaveCompileOptions(
113+
subs=hyperparams,
114+
run_bench=False,
115+
check_water_analysis=True,
116+
)
117+
compiled_gemm = wave_compile(options, gemm)
118+
assert compiled_gemm is not None
119+
120+
121+
if __name__ == "__main__":
122+
testGemm()

lit_tests/kernel/wave/mlir_converter.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def failure_to_parse_override_mlir():
8080

8181
# Override the MLIR module after `wave_compile` so it doesn't attempt to parse it.
8282
options.override_mlir = "module {"
83-
_, diagnostics = emit_wave_dialect(trace, constraints, options)
83+
_, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
8484

8585
assert len(diagnostics) == 1
8686
# CHECK: Unable to parse module assembly
@@ -91,7 +91,9 @@ def failure_to_parse_override_mlir():
9191
@run_test
9292
def failure_to_parse_pipeline():
9393
trace, options, constraints = _get_dummy_trace_options_and_constraints()
94-
_, diagnostics = emit_wave_dialect(trace, constraints, options, pipeline="module {")
94+
_, diagnostics, _ = emit_wave_dialect(
95+
trace, constraints, options, pipeline="module {"
96+
)
9597

9698
assert len(diagnostics) == 1
9799
# CHECK: Failed to apply transform script: Unable to parse module assembly
@@ -102,7 +104,7 @@ def failure_to_parse_pipeline():
102104
@run_test
103105
def pipeline_is_empty():
104106
trace, options, constraints = _get_dummy_trace_options_and_constraints()
105-
_, diagnostics = emit_wave_dialect(
107+
_, diagnostics, _ = emit_wave_dialect(
106108
trace, constraints, options, pipeline="module {}"
107109
)
108110

@@ -115,7 +117,7 @@ def pipeline_is_empty():
115117
@run_test
116118
def pipeline_is_not_a_named_sequence():
117119
trace, options, constraints = _get_dummy_trace_options_and_constraints()
118-
_, diagnostics = emit_wave_dialect(
120+
_, diagnostics, _ = emit_wave_dialect(
119121
trace, constraints, options, pipeline="module { module {}}"
120122
)
121123

@@ -141,7 +143,7 @@ def pipeline_is_not_a_named_sequence():
141143
def failure_in_pipeline():
142144
trace, options, constraints = _get_dummy_trace_options_and_constraints()
143145
options.override_mlir = "module {}"
144-
_, diagnostics = emit_wave_dialect(
146+
_, diagnostics, _ = emit_wave_dialect(
145147
trace, constraints, options, pipeline=GUARANTEED_FAIL_TRANSFORM_SCRIPT
146148
)
147149
assert len(diagnostics) == 1
@@ -158,7 +160,7 @@ def override_mlir():
158160
module {
159161
func.func private @overridden_mlir()
160162
}"""
161-
emitted, diagnostics = emit_wave_dialect(trace, constraints, options)
163+
emitted, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
162164
assert len(diagnostics) == 0, "Did not expect errors in overridden IR."
163165

164166
# CHECK: func.func private @overridden_mlir()
@@ -218,7 +220,7 @@ def mlir_converter_matrix_add():
218220
constraints = matrix_add.constraints
219221

220222
# Use the mlir_converter to emit wave MLIR dialect
221-
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
223+
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
222224

223225
if diagnostics:
224226
for diagnostic in diagnostics:
@@ -374,7 +376,7 @@ def pipeline(root: OpHandle):
374376

375377
# Use the mlir_converter to emit wave MLIR dialect and apply the empty
376378
# pipeline.
377-
mlir_output, diagnostics = emit_wave_dialect(
379+
mlir_output, diagnostics, _ = emit_wave_dialect(
378380
trace, constraints, options, pipeline=pipeline_asm
379381
)
380382

@@ -528,7 +530,7 @@ def mixed_memory_kernel(
528530
constraints = mixed_memory_kernel.constraints
529531

530532
with Context(), Location.unknown():
531-
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
533+
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
532534

533535
assert len(diagnostics) == 0, f"Should have no diagnostics, got: {diagnostics}"
534536

@@ -582,7 +584,7 @@ def invalid_hyperparameter_kernel(
582584
# This should raise a RuntimeError due to invalid non-int hyperparameter
583585
try:
584586
with Context(), Location.unknown():
585-
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
587+
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
586588
assert False, "Expected RuntimeError for invalid non-int hyperparameter"
587589
except RuntimeError as e:
588590
# Verify the error message is what we expect

lit_tests/kernel/wave/mlir_converter_debug_locations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def mlir_converter_location():
9595
constraints = matrix_add.constraints
9696

9797
# Use the mlir_converter to emit wave MLIR dialect
98-
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
98+
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
9999

100100
if diagnostics:
101101
print(diagnostics)
@@ -210,7 +210,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
210210
constraints = matmul.constraints
211211

212212
# Use the mlir_converter to emit wave MLIR dialect
213-
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
213+
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
214214

215215
if diagnostics:
216216
print(diagnostics)

lit_tests/kernel/wave/mlir_converter_diagnostics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def mlir_converter_diagnostics_emission():
8585
constraints = matrix_add.constraints
8686

8787
# Use the mlir_converter to emit wave MLIR dialect
88-
_, diagnostics = emit_wave_dialect(
88+
_, diagnostics, _ = emit_wave_dialect(
8989
trace, constraints, options, test_diagnostic_emission=True
9090
)
9191

tests/kernel/wave/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def param_bool(name, shortname=None, values=None):
7070

7171

7272
def _is_water_and_ee_available() -> bool:
73-
from wave_lang.kernel.wave.water import is_water_available
73+
from wave_lang.support.detect_water import is_water_available
7474
from wave_lang.kernel.wave.execution_engine import is_execution_engine_available
7575

7676
return is_water_available() and is_execution_engine_available()

0 commit comments

Comments
 (0)