Draft
Changes from all commits (114 commits)
c9246e3
pass stub
Hardcode84 Dec 6, 2025
fbb84ab
read/write lowering
Hardcode84 Dec 6, 2025
2f9856c
funcs
Hardcode84 Dec 6, 2025
df6e2e2
refac
Hardcode84 Dec 6, 2025
35f563b
cleanup
Hardcode84 Dec 7, 2025
86de819
indices support
Hardcode84 Dec 7, 2025
4f8b7cc
seq
Hardcode84 Dec 7, 2025
971c2f1
waitcnt pass wip
Hardcode84 Dec 7, 2025
015d87f
reqs
Hardcode84 Dec 7, 2025
1fdeec2
reqs
Hardcode84 Dec 7, 2025
83b50f5
state
Hardcode84 Dec 7, 2025
cecc76c
fix test
Hardcode84 Dec 7, 2025
bd348a5
fixes
Hardcode84 Dec 8, 2025
b7a0442
fix
Hardcode84 Dec 8, 2025
0b08cf3
rename cnt
Hardcode84 Dec 8, 2025
a6bd9eb
fix
Hardcode84 Dec 8, 2025
debaf1b
addr space
Hardcode84 Dec 8, 2025
3a9ea1d
mem deps
Hardcode84 Dec 8, 2025
eb89f51
AliasAnalysis
Hardcode84 Dec 8, 2025
1d879bb
test
Hardcode84 Dec 8, 2025
6c07612
update print
Hardcode84 Dec 8, 2025
2e6fd0a
ops lists
Hardcode84 Dec 8, 2025
ac85549
fixes
Hardcode84 Dec 8, 2025
b13b83a
refac
Hardcode84 Dec 8, 2025
3a94af6
LDS tests
Hardcode84 Dec 8, 2025
2460659
branch tests
Hardcode84 Dec 8, 2025
10cb6c3
lists bookkeeping
Hardcode84 Dec 8, 2025
079d444
pending ops set
Hardcode84 Dec 8, 2025
7d23460
explicit waits
Hardcode84 Dec 8, 2025
5c82620
tests
Hardcode84 Dec 8, 2025
39a2f8e
memref.load/store
Hardcode84 Dec 8, 2025
caa9980
copy WIP
Hardcode84 Dec 8, 2025
8a63357
copy handling
Hardcode84 Dec 8, 2025
fa7e489
memref.copy tests
Hardcode84 Dec 8, 2025
83f6ef1
gather-to-lds
Hardcode84 Dec 8, 2025
de069b9
double buffering test WIP
Hardcode84 Dec 9, 2025
b089085
tokens WIP
Hardcode84 Dec 9, 2025
84cb14b
printing
Hardcode84 Dec 9, 2025
351c89f
control flow WIP
Hardcode84 Dec 9, 2025
96ee333
tokens bookkeeping
Hardcode84 Dec 10, 2025
ebd3035
token stuff
Hardcode84 Dec 10, 2025
e4eb580
-aliasAnalysis
Hardcode84 Dec 10, 2025
9e09741
double-buffering test
Hardcode84 Dec 10, 2025
186b4b6
fix reqs and trips buffering
Hardcode84 Dec 10, 2025
0c42d8e
token cache
Hardcode84 Dec 10, 2025
3f9e8b6
fix witcount ops
Hardcode84 Dec 11, 2025
3f23651
more global ops tests
Hardcode84 Dec 11, 2025
35bc805
buffer ops
Hardcode84 Dec 11, 2025
662c8c3
buffer offsets
Hardcode84 Dec 11, 2025
f56a94d
ds ops
Hardcode84 Dec 11, 2025
39ecaa0
nicer
Hardcode84 Dec 11, 2025
351f78f
memref.load/store
Hardcode84 Dec 12, 2025
2bc4e39
8 and 16 bits
Hardcode84 Dec 12, 2025
0bbccc7
small bvalues fixes
Hardcode84 Dec 12, 2025
2b19267
revert 32 bit
Hardcode84 Dec 12, 2025
b8d5b8d
chipset options
Hardcode84 Dec 12, 2025
cbb6bdb
buffer ops fixes
Hardcode84 Dec 12, 2025
f7a3ab0
fix ds
Hardcode84 Dec 12, 2025
7d8738f
buffer stuff
Hardcode84 Dec 12, 2025
0944a7a
buffer fixes
Hardcode84 Dec 12, 2025
db89a40
doc
Hardcode84 Dec 13, 2025
7083919
reg2mem pass init
Hardcode84 Dec 13, 2025
f6bb568
doc
Hardcode84 Dec 13, 2025
17a1bf3
skip existing
Hardcode84 Dec 13, 2025
c5c09ff
block tests
Hardcode84 Dec 13, 2025
ef0abf9
vector op support
Hardcode84 Dec 13, 2025
bef1c6f
propagate views
Hardcode84 Dec 13, 2025
5423479
update desc
Hardcode84 Dec 13, 2025
d47022b
loop stuf
Hardcode84 Dec 14, 2025
4a4d13b
loop update
Hardcode84 Dec 14, 2025
97809d1
test
Hardcode84 Dec 14, 2025
b6bf5d3
refac
Hardcode84 Dec 14, 2025
af64522
add reg check
Hardcode84 Dec 14, 2025
b1eaf1f
prevent list grow
Hardcode84 Dec 14, 2025
a602979
register space handling in waitcnt insertion
Hardcode84 Dec 14, 2025
c54e97c
number registers
Hardcode84 Dec 14, 2025
a5a0c7e
make func pass
Hardcode84 Dec 14, 2025
3715f26
pass pipline fixes
Hardcode84 Dec 14, 2025
a3df2b8
rename regs
Hardcode84 Dec 14, 2025
89c683d
update water-opt
Hardcode84 Dec 14, 2025
28f3ee1
some lowering
Hardcode84 Dec 14, 2025
6bcb1fa
cleanup alloca
Hardcode84 Dec 14, 2025
c6b1316
copy to reg space lowering
Hardcode84 Dec 14, 2025
cf72eb2
reg lowering fixes
Hardcode84 Dec 14, 2025
744935f
bufer fixes
Hardcode84 Dec 15, 2025
ee528b8
fix shapes
Hardcode84 Dec 15, 2025
e315e7c
types fixes
Hardcode84 Dec 15, 2025
03916de
set amdgpu-num-vgpr
Hardcode84 Dec 15, 2025
b4e3582
reg lowering tests
Hardcode84 Dec 15, 2025
2cd0889
chipset
Hardcode84 Dec 15, 2025
0dedccd
rdna vs cdna tests
Hardcode84 Dec 15, 2025
9aa0cb6
lowering pipeline
Hardcode84 Dec 15, 2025
2f58cdb
fix < 32 stores
Hardcode84 Dec 16, 2025
cc8ca84
fix shared mem addresses
Hardcode84 Dec 16, 2025
2ce470d
refac
Hardcode84 Dec 16, 2025
e98f510
refac
Hardcode84 Dec 16, 2025
8345c86
fixes
Hardcode84 Dec 16, 2025
5d6d64a
update reg count
Hardcode84 Dec 16, 2025
047db50
barriers
Hardcode84 Dec 16, 2025
d2cfbe5
more barriers
Hardcode84 Dec 16, 2025
02e26f3
nicer code
Hardcode84 Dec 16, 2025
09d6dc1
code cleanup
Hardcode84 Dec 16, 2025
cafa81e
more code cleanup
Hardcode84 Dec 17, 2025
740ec70
fixes
Hardcode84 Dec 17, 2025
c45fc96
code clenaup
Hardcode84 Dec 17, 2025
666e989
refac
Hardcode84 Dec 18, 2025
eef3d64
include vgp range in comment
Hardcode84 Dec 18, 2025
7811c1c
fix waitcnt queue
Hardcode84 Dec 18, 2025
1aeb0ac
dumb mfma hazard mitigation
Hardcode84 Dec 18, 2025
82455a5
mxfp test
Hardcode84 Dec 19, 2025
f645249
test
Hardcode84 Dec 19, 2025
33dbe44
check
Hardcode84 Dec 19, 2025
219fbeb
test check
Hardcode84 Dec 19, 2025
9fb159f
align registers
Hardcode84 Dec 19, 2025
63 changes: 63 additions & 0 deletions tests/kernel/wave/common/utils.py
@@ -7,6 +7,7 @@
import pytest
from wave_lang.kernel.wave.utils.run_utils import get_default_arch
from pathlib import Path
from dataclasses import dataclass, field

require_e2e = pytest.mark.require_e2e
expensive_test = pytest.mark.expensive_test
@@ -91,3 +92,65 @@ def use_water_backend_bool(name: str):

def glob_asm_files(path: Path) -> list[Path]:
    return list(filter(lambda x: x.suffix in [".s", ".rocmasm"], path.glob("*")))


@dataclass
class KernelMetadata:
    """Metadata extracted from kernel assembly."""

    vgpr_count: int | None = None
    vgpr_spill_count: int | None = None
    sgpr_count: int | None = None
    sgpr_spill_count: int | None = None
    waitcnt_ops: list[str] = field(default_factory=list)


def extract_kernel_metadata(asm_text: str) -> KernelMetadata:
    """
    Extract kernel metadata from ROCm assembly text.

    Args:
        asm_text: Assembly text content (e.g., from .rocmasm file)

    Returns:
        KernelMetadata containing:
        - vgpr_count: Number of VGPRs allocated
        - vgpr_spill_count: Number of VGPRs spilled
        - sgpr_count: Number of SGPRs allocated
        - sgpr_spill_count: Number of SGPRs spilled
        - waitcnt_ops: List of all waitcnt operations found in the assembly
    """
    import re

    metadata = KernelMetadata()

    # Extract from YAML metadata section (more reliable).
    # Look for patterns like:
    #   .vgpr_count: 3
    #   .vgpr_spill_count: 0
    #   .sgpr_count: 8
    #   .sgpr_spill_count: 0

    vgpr_count_match = re.search(r"\.vgpr_count:\s+(\d+)", asm_text)
    if vgpr_count_match:
        metadata.vgpr_count = int(vgpr_count_match.group(1))

    vgpr_spill_match = re.search(r"\.vgpr_spill_count:\s+(\d+)", asm_text)
    if vgpr_spill_match:
        metadata.vgpr_spill_count = int(vgpr_spill_match.group(1))

    sgpr_count_match = re.search(r"\.sgpr_count:\s+(\d+)", asm_text)
    if sgpr_count_match:
        metadata.sgpr_count = int(sgpr_count_match.group(1))

    sgpr_spill_match = re.search(r"\.sgpr_spill_count:\s+(\d+)", asm_text)
    if sgpr_spill_match:
        metadata.sgpr_spill_count = int(sgpr_spill_match.group(1))

    # Extract all waitcnt operations.
    # Pattern: s_waitcnt followed by any arguments.
    # Examples: s_waitcnt lgkmcnt(0), s_waitcnt vmcnt(0), etc.
    waitcnt_pattern = re.compile(r"s_waitcnt\s+[^\n]+")
    metadata.waitcnt_ops = waitcnt_pattern.findall(asm_text)

    return metadata
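
For context, a minimal usage sketch of the helper above (not part of this PR): the assembly snippet and expected values below are invented for illustration, and it assumes extract_kernel_metadata is in scope.

# Hypothetical example, not from the PR: exercises the regexes above on a
# hand-written snippet mimicking the .rocmasm YAML metadata and waitcnts.
sample_asm = """
s_waitcnt lgkmcnt(0)
s_waitcnt vmcnt(3)
.vgpr_count: 12
.vgpr_spill_count: 0
.sgpr_count: 8
.sgpr_spill_count: 0
"""
meta = extract_kernel_metadata(sample_asm)
assert meta.vgpr_count == 12 and meta.sgpr_count == 8
assert meta.waitcnt_ops == ["s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(3)"]
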
71 changes: 71 additions & 0 deletions tests/kernel/wave/e2e/test_copy.py
@@ -18,6 +18,7 @@

from ..common.utils import param_bool, require_e2e, use_water_backend_bool
from ._test_util import get_test_shapes
from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE


def get_copy_template(
@@ -132,3 +133,73 @@ def test_dynamic_copy(
    b = device_zeros(shape, dtype=torch.float16)
    test(a, b)
    assert_close(a, b)


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
@param_bool("use_buffer_ops", "buf_ops")
@use_water_backend_bool("use_water_backend")
@check_leaks
def test_copy_shared_memory(
    shape: tuple[int, int],
    use_buffer_ops: bool,
    run_bench: bool,
    use_water_backend: bool,
) -> None:
    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    # Each workgroup works on a single row of the input data, and rows are
    # further split into blocks of size up to 256. We have a single wave per
    # WG, and with the default wave size of 64, each thread operates on up to
    # 4 elements.
    wave_size = 64
    BLOCK_M = 1
    # Tile size cannot be dynamic, so we use a fixed value here.
    BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)

    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=wave_size,
            vector_shapes={M: BLOCK_M, N: BLOCK_N},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    @tkw.wave(constraints)
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        shared = tkw.allocate((M, N), (BLOCK_M, BLOCK_N), tkl.f16, SHARED_ADDRESS_SPACE)
        res = tkw.read(a)
        tkw.write(res, shared)
        tkw.shared_memory_barrier()
        res_shared = tkw.read(shared)
        tkw.write(res_shared, b)

    subs = {
        M: shape[0],
        N: shape[1],
        ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
    }

    options = WaveCompileOptions(
        subs=subs,
        canonicalize=True,
        run_bench=run_bench,
        use_buffer_ops=use_buffer_ops,
        use_water_backend=use_water_backend,
        minimize_shared_allocs=False,  # TODO: minimize_shared_allocs=True is broken
    )
    options = set_default_run_config(options)
    test = wave_compile(options, test)

    a = device_randn(shape, dtype=torch.float16)
    b = device_zeros(shape, dtype=torch.float16)
    test(a, b)
    assert_close(a, b)
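
As a quick sanity check on the tiling comment in test_copy_shared_memory above, an illustrative (non-PR) evaluation of the BLOCK_N formula:

# Illustrative only: Min clamps a row block to at most 256 columns, and Max
# guarantees at least one full wave (64 threads) of work.
import sympy

wave_size = 64
for n in (32, 200, 1000):
    print(n, sympy.Max(sympy.Min(n, 256), wave_size))  # -> 64, 200, 256
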
192 changes: 167 additions & 25 deletions tests/kernel/wave/wave_gemm_mxfp_test.py
@@ -1,5 +1,6 @@
import torch
import pytest
from pathlib import Path

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
@@ -25,7 +26,14 @@
    ScaledMMAType,
)

from .common.utils import param_bool, require_e2e, require_cdna4
from .common.utils import (
    extract_kernel_metadata,
    glob_asm_files,
    param_bool,
    require_cdna4,
    require_e2e,
    use_water_backend_bool,
)

# Note this is specified by the HW and cannot be changed.
SCALE_GROUP_SIZE = 32
@@ -230,32 +238,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:


# BMK @ NK -> BMN represents Linear Layer style BMM.
@require_e2e
@require_cdna4
@pytest.mark.parametrize("batch", [4, 8])
@pytest.mark.parametrize(
    "shape",
    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
)
@pytest.mark.parametrize(
    "mfma_variant",
    [
        ScaledMMAType.F32_16x16x128_F8F6F4,
    ],
)
@pytest.mark.parametrize(
    "enable_scheduling",
    [
        SchedulingType.PREFETCH,
        SchedulingType.FOUR_STAGE,
    ],
)
def testScaledBatchedGemmMXFP4(
    batch: int,
    shape: tuple[int],
def get_scaled_gemm_template(
    shape: tuple[int, int, int],
    mfma_variant: ScaledMMAType,
    enable_scheduling: SchedulingType,
):
) -> tuple[WaveCompileOptions, "LaunchableWave"]:
    # Input sizes
    B = tkl.sym.B
    M = tkl.sym.M
@@ -332,7 +319,42 @@ def repeat(
        linearize_shared_access=True,
        dynamic_symbols=dynamic_symbols,
    )
    return options, batched_gemm


@require_e2e
@require_cdna4
@pytest.mark.parametrize("batch", [4, 8])
@pytest.mark.parametrize(
    "shape",
    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
)
@pytest.mark.parametrize(
    "mfma_variant",
    [
        ScaledMMAType.F32_16x16x128_F8F6F4,
    ],
)
@pytest.mark.parametrize(
    "enable_scheduling",
    [
        SchedulingType.PREFETCH,
        SchedulingType.FOUR_STAGE,
    ],
)
@use_water_backend_bool("use_water_backend")
def testScaledBatchedGemmMXFP4(
    batch: int,
    shape: tuple[int, int, int],
    mfma_variant: ScaledMMAType,
    enable_scheduling: SchedulingType,
    use_water_backend: bool,
):
    options, batched_gemm = get_scaled_gemm_template(
        shape, mfma_variant, enable_scheduling
    )
    options = set_default_run_config(options)
    options.use_water_backend = use_water_backend
    batched_gemm = wave_compile(options, batched_gemm)

    linearized_shape = (batch * shape[0], shape[1], shape[2])
@@ -351,6 +373,126 @@ def repeat(
    torch.testing.assert_close(torch_out, out)


@use_water_backend_bool("use_water_backend")
def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
    shape = (16384, 16384, 16384)
    mfma_variant = ScaledMMAType.F32_16x16x128_F8F6F4
    enable_scheduling = SchedulingType.PREFETCH
    options, batched_gemm = get_scaled_gemm_template(
        shape, mfma_variant, enable_scheduling
    )
    options.target = "gfx950"
    options.minimize_shared_allocs = False
    # options.use_global_to_shared = True
    options.dump_intermediates = tmp_path
    options.use_water_backend = use_water_backend
    batched_gemm = wave_compile(options, batched_gemm)
    asm_files = glob_asm_files(tmp_path)

    assert len(asm_files) == 1, "Expected 1 ASM file"
    text = asm_files[0].read_text()

    metadata = extract_kernel_metadata(text)

    # We encode the exact register and waitcnt counts, as we want to know if
    # they suddenly change due to backend or upstream MLIR changes.
    if use_water_backend:
        vgpr_count = 156
        vgpr_spill_count = 0
        sgpr_count = 45
        sgpr_spill_count = 0
        waitcounts = [
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt vmcnt(7)",
            "s_waitcnt vmcnt(6)",
            "s_waitcnt vmcnt(5)",
            "s_waitcnt vmcnt(4)",
            "s_waitcnt vmcnt(3)",
            "s_waitcnt vmcnt(2)",
            "s_waitcnt vmcnt(1)",
            "s_waitcnt vmcnt(0)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt lgkmcnt(13)",
            "s_waitcnt lgkmcnt(10)",
            "s_waitcnt lgkmcnt(8)",
            "s_waitcnt lgkmcnt(5)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt vmcnt(6)",
            "s_waitcnt vmcnt(3)",
            "s_waitcnt vmcnt(2)",
            "s_waitcnt vmcnt(1)",
            "s_waitcnt vmcnt(0)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt lgkmcnt(7)",
            "s_waitcnt lgkmcnt(4)",
            "s_waitcnt lgkmcnt(3)",
            "s_waitcnt lgkmcnt(2)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(0)",
        ]
    else:
        vgpr_count = 160
        vgpr_spill_count = 0
        sgpr_count = 46
        sgpr_spill_count = 0
        waitcounts = [
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt vmcnt(7)",
            "s_waitcnt vmcnt(6)",
            "s_waitcnt vmcnt(5)",
            "s_waitcnt vmcnt(4)",
            "s_waitcnt vmcnt(3)",
            "s_waitcnt vmcnt(2)",
            "s_waitcnt vmcnt(1)",
            "s_waitcnt vmcnt(0)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt lgkmcnt(6)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(3)",
            "s_waitcnt lgkmcnt(2)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt vmcnt(6)",
            "s_waitcnt vmcnt(3)",
            "s_waitcnt vmcnt(2)",
            "s_waitcnt vmcnt(1)",
            "s_waitcnt vmcnt(0)",
            "s_waitcnt lgkmcnt(0)",
            "s_waitcnt lgkmcnt(7)",
            "s_waitcnt lgkmcnt(5)",
            "s_waitcnt lgkmcnt(4)",
            "s_waitcnt lgkmcnt(3)",
            "s_waitcnt lgkmcnt(2)",
            "s_waitcnt lgkmcnt(1)",
            "s_waitcnt lgkmcnt(0)",
        ]

    assert (
        metadata.vgpr_count == vgpr_count
    ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}"
    assert (
        metadata.vgpr_spill_count == vgpr_spill_count
    ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}"
    assert (
        metadata.sgpr_count == sgpr_count
    ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}"
    assert (
        metadata.sgpr_spill_count == sgpr_spill_count
    ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}"
    assert (
        metadata.waitcnt_ops == waitcounts
    ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}"


@require_e2e
@require_cdna4
@pytest.mark.parametrize("shape", [(1024, 1024, 1024), (8192, 8192, 8192)])