diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py index 916e25344..744eaa518 100644 --- a/tests/kernel/wave/common/utils.py +++ b/tests/kernel/wave/common/utils.py @@ -7,6 +7,7 @@ import pytest from wave_lang.kernel.wave.utils.run_utils import get_default_arch from pathlib import Path +from dataclasses import dataclass, field require_e2e = pytest.mark.require_e2e expensive_test = pytest.mark.expensive_test @@ -91,3 +92,65 @@ def use_water_backend_bool(name: str): def glob_asm_files(path: Path) -> list[Path]: return list(filter(lambda x: x.suffix in [".s", ".rocmasm"], path.glob("*"))) + + +@dataclass +class KernelMetadata: + """Metadata extracted from kernel assembly.""" + + vgpr_count: int | None = None + vgpr_spill_count: int | None = None + sgpr_count: int | None = None + sgpr_spill_count: int | None = None + waitcnt_ops: list[str] = field(default_factory=list) + + +def extract_kernel_metadata(asm_text: str) -> KernelMetadata: + """ + Extract kernel metadata from ROCm assembly text. + + Args: + asm_text: Assembly text content (e.g., from .rocmasm file) + + Returns: + KernelMetadata containing: + - vgpr_count: Number of VGPRs allocated + - vgpr_spill_count: Number of VGPRs spilled + - sgpr_count: Number of SGPRs allocated + - sgpr_spill_count: Number of SGPRs spilled + - waitcnt_ops: List of all waitcnt operations found in the assembly + """ + import re + + metadata = KernelMetadata() + + # Extract from YAML metadata section (more reliable) + # Look for patterns like: + # .vgpr_count: 3 + # .vgpr_spill_count: 0 + # .sgpr_count: 8 + # .sgpr_spill_count: 0 + + vgpr_count_match = re.search(r"\.vgpr_count:\s+(\d+)", asm_text) + if vgpr_count_match: + metadata.vgpr_count = int(vgpr_count_match.group(1)) + + vgpr_spill_match = re.search(r"\.vgpr_spill_count:\s+(\d+)", asm_text) + if vgpr_spill_match: + metadata.vgpr_spill_count = int(vgpr_spill_match.group(1)) + + sgpr_count_match = re.search(r"\.sgpr_count:\s+(\d+)", asm_text) + if sgpr_count_match: + metadata.sgpr_count = int(sgpr_count_match.group(1)) + + sgpr_spill_match = re.search(r"\.sgpr_spill_count:\s+(\d+)", asm_text) + if sgpr_spill_match: + metadata.sgpr_spill_count = int(sgpr_spill_match.group(1)) + + # Extract all waitcnt operations + # Pattern: s_waitcnt followed by any arguments + # Examples: s_waitcnt lgkmcnt(0), s_waitcnt vmcnt(0), etc. 
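As a quick, self-contained illustration of what these patterns match, here is a sketch run against a synthetic assembly excerpt (the sample text below is made up to mimic the amdhsa metadata block and a couple of waits; it is not taken from a real kernel dump):

```python
import re

# Synthetic excerpt mimicking the amdhsa YAML metadata block and a few waits.
sample_asm = """
    .sgpr_count:       8
    .sgpr_spill_count: 0
    .vgpr_count:       3
    .vgpr_spill_count: 0
  s_waitcnt lgkmcnt(0)
  s_waitcnt vmcnt(0)
"""

# Same patterns as used by extract_kernel_metadata.
print(re.search(r"\.vgpr_count:\s+(\d+)", sample_asm).group(1))
# -> "3"
print(re.findall(r"s_waitcnt\s+[^\n]+", sample_asm))
# -> ['s_waitcnt lgkmcnt(0)', 's_waitcnt vmcnt(0)']
```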
+    waitcnt_pattern = re.compile(r"s_waitcnt\s+[^\n]+")
+    metadata.waitcnt_ops = waitcnt_pattern.findall(asm_text)
+
+    return metadata
diff --git a/tests/kernel/wave/e2e/test_copy.py b/tests/kernel/wave/e2e/test_copy.py
index 90b48caae..a3a4d4852 100644
--- a/tests/kernel/wave/e2e/test_copy.py
+++ b/tests/kernel/wave/e2e/test_copy.py
@@ -18,6 +18,7 @@ from ..common.utils import param_bool, require_e2e, use_water_backend_bool
 from ._test_util import get_test_shapes
+from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE
 
 
 def get_copy_template(
@@ -132,3 +133,73 @@ def test_dynamic_copy(
     b = device_zeros(shape, dtype=torch.float16)
     test(a, b)
     assert_close(a, b)
+
+
+@require_e2e
+@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
+@param_bool("use_buffer_ops", "buf_ops")
+@use_water_backend_bool("use_water_backend")
+@check_leaks
+def test_copy_shared_memory(
+    shape: tuple[int, int],
+    use_buffer_ops: bool,
+    run_bench: bool,
+    use_water_backend: bool,
+) -> None:
+    M = tkl.sym.M
+    N = tkl.sym.N
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    # Each workgroup works on a single row of input data, and rows are further
+    # split into blocks of size up to 256. We have a single wave per WG, and
+    # with the default wave size of 64, each thread operates on up to 4
+    # elements.
+    wave_size = 64
+    BLOCK_M = 1
+    # Tile size cannot be dynamic, so we use a fixed value here.
+    BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=wave_size,
+            vector_shapes={M: BLOCK_M, N: BLOCK_N},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        shared = tkw.allocate((M, N), (BLOCK_M, BLOCK_N), tkl.f16, SHARED_ADDRESS_SPACE)
+        res = tkw.read(a)
+        tkw.write(res, shared)
+        tkw.shared_memory_barrier()
+        res_shared = tkw.read(shared)
+        tkw.write(res_shared, b)
+
+    subs = {
+        M: shape[0],
+        N: shape[1],
+        ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+    }
+
+    options = WaveCompileOptions(
+        subs=subs,
+        canonicalize=True,
+        run_bench=run_bench,
+        use_buffer_ops=use_buffer_ops,
+        use_water_backend=use_water_backend,
+        minimize_shared_allocs=False,  # TODO: minimize_shared_allocs=True is broken
+    )
+    options = set_default_run_config(options)
+    test = wave_compile(options, test)
+
+    a = device_randn(shape, dtype=torch.float16)
+    b = device_zeros(shape, dtype=torch.float16)
+    test(a, b)
+    assert_close(a, b)
diff --git a/tests/kernel/wave/wave_gemm_mxfp_test.py b/tests/kernel/wave/wave_gemm_mxfp_test.py
index a06a79016..b2c4f477c 100644
--- a/tests/kernel/wave/wave_gemm_mxfp_test.py
+++ b/tests/kernel/wave/wave_gemm_mxfp_test.py
@@ -1,5 +1,6 @@
 import torch
 import pytest
+from pathlib import Path
 
 import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
@@ -25,7 +26,14 @@
     ScaledMMAType,
 )
 
-from .common.utils import param_bool, require_e2e, require_cdna4
+from .common.utils import (
+    extract_kernel_metadata,
+    glob_asm_files,
+    param_bool,
+    require_cdna4,
+    require_e2e,
+    use_water_backend_bool,
+)
 
 # Note this is specified by the HW and cannot be changed.
 SCALE_GROUP_SIZE = 32
@@ -230,32 +238,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
 # BMK @ NK -> BMN represents Linear Layer style BMM.
-@require_e2e
-@require_cdna4
-@pytest.mark.parametrize("batch", [4, 8])
-@pytest.mark.parametrize(
-    "shape",
-    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
-)
-@pytest.mark.parametrize(
-    "mfma_variant",
-    [
-        ScaledMMAType.F32_16x16x128_F8F6F4,
-    ],
-)
-@pytest.mark.parametrize(
-    "enable_scheduling",
-    [
-        SchedulingType.PREFETCH,
-        SchedulingType.FOUR_STAGE,
-    ],
-)
-def testScaledBatchedGemmMXFP4(
-    batch: int,
-    shape: tuple[int],
+def get_scaled_gemm_template(
+    shape: tuple[int, int, int],
     mfma_variant: ScaledMMAType,
     enable_scheduling: SchedulingType,
-):
+) -> tuple[WaveCompileOptions, "LaunchableWave"]:
     # Input sizes
     B = tkl.sym.B
     M = tkl.sym.M
@@ -332,7 +319,42 @@ def repeat(
         linearize_shared_access=True,
         dynamic_symbols=dynamic_symbols,
     )
+    return options, batched_gemm
+
+
+@require_e2e
+@require_cdna4
+@pytest.mark.parametrize("batch", [4, 8])
+@pytest.mark.parametrize(
+    "shape",
+    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
+)
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        ScaledMMAType.F32_16x16x128_F8F6F4,
+    ],
+)
+@pytest.mark.parametrize(
+    "enable_scheduling",
+    [
+        SchedulingType.PREFETCH,
+        SchedulingType.FOUR_STAGE,
+    ],
+)
+@use_water_backend_bool("use_water_backend")
+def testScaledBatchedGemmMXFP4(
+    batch: int,
+    shape: tuple[int, int, int],
+    mfma_variant: ScaledMMAType,
+    enable_scheduling: SchedulingType,
+    use_water_backend: bool,
+):
+    options, batched_gemm = get_scaled_gemm_template(
+        shape, mfma_variant, enable_scheduling
+    )
     options = set_default_run_config(options)
+    options.use_water_backend = use_water_backend
     batched_gemm = wave_compile(options, batched_gemm)
 
     linearized_shape = (batch * shape[0], shape[1], shape[2])
@@ -351,6 +373,126 @@ def repeat(
     torch.testing.assert_close(torch_out, out)
 
 
+@use_water_backend_bool("use_water_backend")
+def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
+    shape = (16384, 16384, 16384)
+    mfma_variant = ScaledMMAType.F32_16x16x128_F8F6F4
+    enable_scheduling = SchedulingType.PREFETCH
+    options, batched_gemm = get_scaled_gemm_template(
+        shape, mfma_variant, enable_scheduling
+    )
+    options.target = "gfx950"
+    options.minimize_shared_allocs = False
+    # options.use_global_to_shared = True
+    options.dump_intermediates = tmp_path
+    options.use_water_backend = use_water_backend
+    batched_gemm = wave_compile(options, batched_gemm)
+    asm_files = glob_asm_files(tmp_path)
+
+    assert len(asm_files) == 1, "Expected 1 ASM file"
+    text = asm_files[0].read_text()
+
+    metadata = extract_kernel_metadata(text)
+
+    # We encode the exact register and wait counts, as we want to know if
+    # they suddenly change due to backend or upstream MLIR changes.
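When these values do drift intentionally, one convenient way to obtain the new expected values is to dump the freshly extracted metadata before updating the pinned lists below. A small sketch using only the `KernelMetadata` fields defined in `common/utils.py` (the print statements are an illustrative debugging aid, not part of the test):

```python
# Hypothetical debugging aid: print the observed values so the pinned
# expectations below can be refreshed after an intentional change.
print(metadata.vgpr_count, metadata.vgpr_spill_count)
print(metadata.sgpr_count, metadata.sgpr_spill_count)
print("\n".join(metadata.waitcnt_ops))
```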
+ if use_water_backend: + vgpr_count = 156 + vgpr_spill_count = 0 + sgpr_count = 45 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(13)", + "s_waitcnt lgkmcnt(10)", + "s_waitcnt lgkmcnt(8)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + else: + vgpr_count = 160 + vgpr_spill_count = 0 + sgpr_count = 46 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + + assert ( + metadata.vgpr_count == vgpr_count + ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" + assert ( + metadata.vgpr_spill_count == vgpr_spill_count + ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" + assert ( + metadata.sgpr_count == sgpr_count + ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" + assert ( + metadata.sgpr_spill_count == sgpr_spill_count + ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" + assert ( + metadata.waitcnt_ops == waitcounts + ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" + + @require_e2e @require_cdna4 @pytest.mark.parametrize("shape", [(1024, 1024, 1024), (8192, 8192, 8192)]) diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index a241b464c..0c43078b5 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -153,4 +153,80 @@ def WaterDropTransformOpsPass : Pass<"water-drop-transform-ops"> { }]; } +def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> { + let summary = "Insert wait instructions for asynchronous memory operations"; + let description = [{ + This pass analyzes asynchronous memory operations and inserts appropriate + wait/synchronization instructions to ensure memory operations complete + before their results are used. + + The pass tracks dependencies between memory operations and register uses, + maintaining scoreboards to determine when waits are necessary. 
It handles: + - Read-after-write (RAW) dependencies + - Write-after-write (WAW) dependencies + - Write-after-read (WAR) dependencies + + This is analogous to LLVM's SIInsertWaitcnts pass but operates at the + MLIR level for AMDGPU dialect operations. + }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + ]; +} + +def WaterLowerMemoryOps : InterfacePass<"water-lower-memory-ops", "::mlir::FunctionOpInterface"> { + let summary = "Lower high-level memory operations to AMDGPU dialect"; + let description = [{ + This pass lowers high-level memory operations (vector.load, vector.store, + memref operations) to AMDGPU-specific memory operations (buffer loads/stores, + LDS operations, etc.). + + This lowering prepares the IR for subsequent waitcnt insertion and + final code generation. + }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + "::mlir::gpu::GPUDialect", + "::mlir::LLVM::LLVMDialect", + "::mlir::memref::MemRefDialect", + "::mlir::ROCDL::ROCDLDialect", + "::mlir::vector::VectorDialect", + ]; + let options = [ + Option<"chipset", "chipset", "std::string", [{""}], + "Target chipset (e.g., gfx942, gfx1100)"> + ]; +} + +def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> { + let summary = "Materialize register copies for loads"; + let description = [{ + This pass materializes explicit register copies by transforming load + operations to route through a temporary buffer in the virtual register + memory space (memspace 128). For each load: + 1. Creates a subview of the source memref at the load indices + 2. Allocates a temporary buffer in memory space 128 (virtual register space) + 3. Copies from the subview to the temporary register buffer + 4. Loads from the temporary register buffer + + This transformation makes register traffic explicit in the IR, enabling + better analysis and optimization of register usage patterns. + }]; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::memref::MemRefDialect", + ]; +} + +def WaterNumberRegisters : InterfacePass<"water-number-registers", "::mlir::FunctionOpInterface"> { + let summary = "Assign physical registers to register space allocas"; + let description = [{ + This pass performs register allocation by assigning physical register numbers + to memref.alloca operations in memory space 128 (virtual register space). + }]; + let dependentDialects = [ + "::mlir::memref::MemRefDialect", + ]; +} + #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index baef65d8f..89583aa30 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -7,6 +7,10 @@ add_mlir_dialect_library(MLIRWaterTransforms GPUModuleToBinary.cpp GPUToGPURuntime.cpp SLPVectorizer.cpp + WaterInsertWaitcnt.cpp + WaterLowerMemoryOps.cpp + WaterMaterializeRegCopy.cpp + WaterNumberRegisters.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/water @@ -15,6 +19,7 @@ add_mlir_dialect_library(MLIRWaterTransforms MLIRWaterPassesIncGen LINK_LIBS PUBLIC + MLIRAMDGPUDialect MLIRAnalysis MLIRArithDialect MLIRControlFlowDialect diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp new file mode 100644 index 000000000..8e9eb5fbb --- /dev/null +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -0,0 +1,870 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DenseAnalysis.h" +#include "mlir/Analysis/DataFlow/Utils.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "water-insert-waitcnt" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERINSERTWAITCNT +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { +static bool isBarrier(Operation *op) { + return isa(op) || isa(op); +} + +static bool isRegisterAddressSpace(MemRefType type) { + auto attr = dyn_cast_or_null(type.getMemorySpace()); + return attr && attr.getInt() == 128; +} + +static bool isWorkgroupAddressSpace(MemRefType type) { + auto attr = dyn_cast_or_null(type.getMemorySpace()); + return attr && attr.getValue() == gpu::AddressSpace::Workgroup; +} + +static bool isWorkgroupAddressSpace(std::optional value) { + if (!value) + return false; + + auto memrefType = cast(value->getType()); + return isWorkgroupAddressSpace(memrefType); +} + +static bool isGlobalAddressSpace(std::optional value) { + if (!value) + return false; + + auto memrefType = cast(value->getType()); + return !isWorkgroupAddressSpace(memrefType) && + !isRegisterAddressSpace(memrefType); +} + +/// Try to propagate view operations to the base memref. +static std::optional propagateViewOps(Value value) { + while (auto view = value.getDefiningOp()) + value = view.getViewSource(); + + return value; +} + +/// Check if the operation is a load operation and return the base memref. +static std::optional isLoadOp(Operation *op) { + // TODO: replace with the interface when available. + if (auto load = dyn_cast(op)) + return propagateViewOps(load.getBase()); + if (auto load = dyn_cast(op)) + return propagateViewOps(load.getMemRef()); + if (auto copy = dyn_cast(op)) + return propagateViewOps(copy.getSource()); + if (auto gather = dyn_cast(op)) + return propagateViewOps(gather.getSrc()); + + return std::nullopt; +} + +/// Check if the operation is a store operation and return the base memref. +static std::optional isStoreOp(Operation *op) { + // TODO: replace with the interface when available. + if (auto store = dyn_cast(op)) + return propagateViewOps(store.getBase()); + if (auto store = dyn_cast(op)) + return propagateViewOps(store.getMemRef()); + if (auto copy = dyn_cast(op)) + return propagateViewOps(copy.getTarget()); + if (auto gather = dyn_cast(op)) + return propagateViewOps(gather.getDst()); + + return std::nullopt; +} + +template +static raw_ostream &print_range(raw_ostream &os, T &&range) { + llvm::interleaveComma(range, os, [&](const auto &item) { os << item; }); + return os; +} + +/// Shared pending operations list for structural sharing +struct PendingOperations { + using TokenContainer = SmallVector; + + PendingOperations() = default; + PendingOperations(SmallVector &&ops, + SmallVector &&opsTokens) + : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {} + + TokenContainer &addOp(Operation *op) { + // Failsafe to prevent infinite list growth. 
+ if (size() >= 256) + llvm::report_fatal_error("Pending operations list is too long"); + + if (!ops.empty() && isBarrier(op) && isBarrier(ops.back())) + return opsTokens.back(); + + ops.push_back(op); + auto &back = opsTokens.emplace_back(); + if (auto memref = isStoreOp(op)) + back.push_back(*memref); + + if (auto memref = isLoadOp(op)) + back.push_back(*memref); + + return back; + } + + size_t size() const { return ops.size(); } + bool empty() const { return ops.empty(); } + + auto opsAndTokens() const { + assert(ops.size() == opsTokens.size() && + "ops and opsTokens must have the same size"); + return llvm::zip(ops, opsTokens); + } + + auto opsAndTokensReverse() const { + assert(ops.size() == opsTokens.size() && + "ops and opsTokens must have the same size"); + return llvm::zip(llvm::reverse(ops), llvm::reverse(opsTokens)); + } + + bool hasSameTail(const PendingOperations &other) const { + for (const auto &[op1, op2, tok1, tok2] : + llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops), + llvm::reverse(opsTokens), llvm::reverse(other.opsTokens))) { + if (op1 != op2) + return false; + if (tok1 != tok2) + return false; + } + return true; + } + + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (TokenContainer &tokens : opsTokens) { + TokenContainer newTok; + for (Value tok : tokens) + updateFunc(tok, newTok); + + tokens = std::move(newTok); + } + } + + void print(raw_ostream &os) const { + os << "PendingOperations: ops=["; + llvm::interleaveComma(opsAndTokens(), os, [&](const auto &opAndTok) { + os << *std::get<0>(opAndTok) << "|"; + print_range(os, std::get<1>(opAndTok)); + }); + os << "]"; + } + + bool operator==(const PendingOperations &other) const { + return ops == other.ops && opsTokens == other.opsTokens; + } + + bool operator!=(const PendingOperations &other) const { + return !(*this == other); + } + + SmallVector ops; + SmallVector opsTokens; +}; + +/// Waitcnt requirement for synchronization +struct WaitcntRequirement { + std::optional load_cnt; + std::optional ds_cnt; + + WaitcntRequirement() = default; + + WaitcntRequirement(amdgpu::MemoryCounterWaitOp waitOp) { + if (auto loadCnt = waitOp.getLoadAttr()) + load_cnt = loadCnt.getInt(); + if (auto dsCnt = waitOp.getDsAttr()) + ds_cnt = dsCnt.getInt(); + } + + bool hasRequirement() const { + return load_cnt.has_value() || ds_cnt.has_value(); + } + + /// Merge with another requirement (take minimum for conservative join) + /// Returns true if this requirement changed + bool merge(const WaitcntRequirement &other) { + bool changed = false; + + // Take minimum of each counter (lower value = more restrictive) + if (other.load_cnt.has_value()) { + if (!load_cnt.has_value() || *other.load_cnt < *load_cnt) { + load_cnt = other.load_cnt; + changed = true; + } + } + if (other.ds_cnt.has_value()) { + if (!ds_cnt.has_value() || *other.ds_cnt < *ds_cnt) { + ds_cnt = other.ds_cnt; + changed = true; + } + } + + return changed; + } + + std::optional getLoadCnt() const { return load_cnt; } + std::optional getStoreCnt() const { return std::nullopt; } + std::optional getDsCnt() const { return ds_cnt; } + + bool isSameCounterType(const WaitcntRequirement &other) const { + return load_cnt.has_value() == other.load_cnt.has_value() || + ds_cnt.has_value() == other.ds_cnt.has_value(); + } + + static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) { + WaitcntRequirement req; + std::optional loadBase = isLoadOp(op); + std::optional storeBase = isStoreOp(op); + if (isWorkgroupAddressSpace(loadBase) || + 
isWorkgroupAddressSpace(storeBase)) { + req.ds_cnt = zero ? 0 : 1; + } else if (isGlobalAddressSpace(loadBase) || + isGlobalAddressSpace(storeBase)) { + req.load_cnt = zero ? 0 : 1; + } + return req; + } + + WaitcntRequirement operator+(const WaitcntRequirement &other) const { + WaitcntRequirement result; + if (load_cnt || other.load_cnt) + result.load_cnt = load_cnt.value_or(0) + other.load_cnt.value_or(0); + if (ds_cnt || other.ds_cnt) + result.ds_cnt = ds_cnt.value_or(0) + other.ds_cnt.value_or(0); + return result; + } + + bool operator>(const WaitcntRequirement &other) const { + if (load_cnt && other.load_cnt && *load_cnt > *other.load_cnt) + return true; + if (ds_cnt && other.ds_cnt && *ds_cnt > *other.ds_cnt) + return true; + return false; + } + operator bool() const { return hasRequirement(); } + + void print(raw_ostream &os) const { + os << "WaitcntRequirement: load_cnt=" << load_cnt << " ds_cnt=" << ds_cnt; + } +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const WaitcntRequirement &result) { + result.print(os); + return os; +} + +static bool mayAlias(Value lhs, Value rhs, ArrayRef tokens) { + if (isWorkgroupAddressSpace(cast(lhs.getType())) != + isWorkgroupAddressSpace(cast(rhs.getType()))) + return false; + + return llvm::is_contained(tokens, lhs); +} + +/// Lattice state tracking pending asynchronous operations +class WaitcntState : public AbstractDenseLattice { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WaitcntState) + + using AbstractDenseLattice::AbstractDenseLattice; + + ChangeResult join(const AbstractDenseLattice &rhs) override { + const auto &rhsState = static_cast(rhs); + bool changed = false; + + SmallVector, 4> toAppend; + // Check if any pending operations has the same subset of operations as the + // rhs and take the longer one. + for (auto &rhsPendingOps : rhsState.pendingOpsLists) { + bool found = false; + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->hasSameTail(*rhsPendingOps)) { + if (rhsPendingOps->size() > pendingOps->size()) { + pendingOps = rhsPendingOps; + changed = true; + } + found = true; + break; + } + } + if (!found) + toAppend.push_back(rhsPendingOps); + } + + // If there are any pending operations that don't have the same subset of + // operations as the rhs, append them to the pending operations lists. + if (!toAppend.empty()) { + pendingOpsLists.append(toAppend); + changed = true; + } + + if (changed) + resetPendingOpsSet(); + + // Merge requirements (take minimum for conservative join) + if (requirement.merge(rhsState.requirement)) + changed = true; + + return changed ? ChangeResult::Change : ChangeResult::NoChange; + } + + ChangeResult merge(const WaitcntState &rhs) { + bool changed = false; + + if (pendingOpsLists.size() != rhs.pendingOpsLists.size()) { + changed = true; + } else { + for (auto [listSrc, listDst] : + llvm::zip(pendingOpsLists, rhs.pendingOpsLists)) { + if (*listSrc != *listDst) { + changed = true; + break; + } + } + } + + if (changed) { + pendingOpsLists = rhs.pendingOpsLists; + resetPendingOpsSet(); + } + + if (requirement.merge(rhs.requirement)) + changed = true; + return changed ? 
ChangeResult::Change : ChangeResult::NoChange; + } + + void print(raw_ostream &os) const override { + os << "WaitcntState: pending ops ["; + for (auto &pendingOps : pendingOpsLists) { + os << "\n ["; + pendingOps->print(os); + os << "]"; + } + os << "\n ], requirement: " << requirement; + } + + void addPendingOp(Operation *op) { + if (pendingOpsLists.empty()) { + pendingOpsLists.push_back(std::make_shared()); + } else { + cow(); + } + for (auto &pendingOps : pendingOpsLists) { + auto &tokens = pendingOps->addOp(op); + for (Value token : tokens) + pendingOpsTokens.insert(token); + } + + pendingOpsSet.insert(op); + } + + /// Initialize to empty state + ChangeResult reset() { + if (pendingOpsLists.empty() && !requirement.hasRequirement()) + return ChangeResult::NoChange; + + pendingOpsLists.clear(); + requirement = {}; + resetPendingOpsSet(); + return ChangeResult::Change; + } + + /// Set the required waitcnt values + void setRequirement(const WaitcntRequirement &req) { + requirement = req; + for (auto &pendingOps : pendingOpsLists) { + SmallVector newPending; + SmallVector newPendingTokens; + WaitcntRequirement runningRequirement; + for (const auto &[op, tok] : llvm::reverse(pendingOps->opsAndTokens())) { + WaitcntRequirement opReq = + WaitcntRequirement::getOperationRequirement(op, false); + runningRequirement = runningRequirement + opReq; + if (runningRequirement > requirement) + continue; + + newPending.push_back(op); + newPendingTokens.push_back(tok); + } + if (newPending.size() == pendingOps->size()) + continue; + + std::reverse(newPending.begin(), newPending.end()); + std::reverse(newPendingTokens.begin(), newPendingTokens.end()); + pendingOps = std::make_shared( + std::move(newPending), std::move(newPendingTokens)); + } + + // Remove empty lists + pendingOpsLists.erase(std::remove_if(pendingOpsLists.begin(), + pendingOpsLists.end(), + [](const auto &pendingOps) { + return pendingOps->empty(); + }), + pendingOpsLists.end()); + + // Merge lists with the same tail (keep the longer one) + for (size_t i = 0; i < pendingOpsLists.size(); ++i) { + for (size_t j = i + 1; j < pendingOpsLists.size();) { + if (pendingOpsLists[i]->hasSameTail(*pendingOpsLists[j])) { + if (pendingOpsLists[j]->size() > pendingOpsLists[i]->size()) { + pendingOpsLists[i] = pendingOpsLists[j]; + } + pendingOpsLists.erase(pendingOpsLists.begin() + j); + } else { + ++j; + } + } + } + + resetPendingOpsSet(); + } + + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (auto &pendingOps : pendingOpsLists) + pendingOps->updateTokens(updateFunc); + } + + void resetRequirement() { requirement = {}; } + + /// Get the required waitcnt values + const WaitcntRequirement &getRequirement() const { return requirement; } + + /// Check if there's a waitcnt requirement + bool hasRequirement() const { return requirement.hasRequirement(); } + + /// Check if a value depends on pending operations and compute required wait + WaitcntRequirement + checkSSADependency(Value val, + llvm::SmallSetVector &barriers) const { + // Check if val is produced by any pending operation + Operation *defOp = val.getDefiningOp(); + if (!defOp) + return {}; + + if (!isPendingOp(defOp)) + return {}; + + WaitcntRequirement result; + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->empty()) + continue; + + Operation *barrier = nullptr; + + // Search from the back to find the most recent dependency + bool found = false; + auto req = WaitcntRequirement::getOperationRequirement(defOp, true); + for (Operation *op : 
llvm::reverse(pendingOps->ops)) { + if (op == defOp) { + found = true; + break; + } + + if (!barrier && isBarrier(op)) + barrier = op; + + auto opReq = WaitcntRequirement::getOperationRequirement(op, false); + if (!req.isSameCounterType(opReq)) + continue; + + req = req + opReq; + } + + if (found) { + result.merge(req); + if (barrier) + barriers.insert(barrier); + } + } + + return result; + } + + /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait + WaitcntRequirement + checkMemoryDependency(Operation *op, + llvm::SmallSetVector &barriers) const { + auto checkMemref = [&](Value memref, bool isCurrentLoad, + bool isCurrentStore) -> WaitcntRequirement { + WaitcntRequirement result; + if (!isPendingOp(memref)) + return result; + + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->empty()) + continue; + + Operation *barrier = nullptr; + + // Search from the back to find the most recent dependency + for (const auto &[pendingOpVar, pendingTokensVar] : + pendingOps->opsAndTokensReverse()) { + + if (!barrier && isBarrier(pendingOpVar)) + barrier = pendingOpVar; + + // We canot capture structured bindings into lambda, thanks C++. + auto &pendingTokens = pendingTokensVar; + auto &pendingOp = pendingOpVar; + auto checkPendingMemref = + [&](Value pendingMemref, bool isPendingLoad, + bool isPendingStore) -> WaitcntRequirement { + WaitcntRequirement pendingResult; + if (!mayAlias(memref, pendingMemref, pendingTokens)) + return pendingResult; + + // Check for dependencies: + // RAW: current load after pending store + // WAR: current store after pending load + // WAW: current store after pending store + bool hasRAW = isCurrentLoad && isPendingStore; + bool hasWAR = isCurrentStore && isPendingLoad; + bool hasWAW = isCurrentStore && isPendingStore; + + if (hasRAW || hasWAR || hasWAW) { + // Found dependency - compute requirement by counting forward from + // here + auto it = llvm::find(pendingOps->ops, pendingOp); + auto req = + WaitcntRequirement::getOperationRequirement(pendingOp, true); + for (Operation *countOp : + llvm::make_range(std::next(it), pendingOps->ops.end())) { + auto opReq = + WaitcntRequirement::getOperationRequirement(countOp, false); + if (!req.isSameCounterType(opReq)) + continue; + req = req + opReq; + } + pendingResult.merge(req); + } + if (pendingResult.hasRequirement() && barrier) + barriers.insert(barrier); + + return pendingResult; + }; + if (auto loadBase = isLoadOp(pendingOp)) + result.merge(checkPendingMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(pendingOp)) + result.merge(checkPendingMemref(*storeBase, false, true)); + } + } + + return result; + }; + // TODO: atomics will have both load and store flags set + WaitcntRequirement result; + if (auto loadBase = isLoadOp(op)) + result.merge(checkMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(op)) + result.merge(checkMemref(*storeBase, false, true)); + return result; + } + +private: + /// Pending asynchronous operations + SmallVector, 4> pendingOpsLists; + + /// Required waitcnt after this state + WaitcntRequirement requirement; + + mutable llvm::SmallDenseSet pendingOpsSet; + mutable llvm::SmallDenseSet pendingOpsTokens; + + void cow() { + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps.use_count() > 1) { + auto newPending = std::make_shared(); + if (pendingOps) + *newPending = *pendingOps; + pendingOps = std::move(newPending); + } + } + } + + bool isPendingOp(llvm::PointerUnion opOrVal) const { + if (pendingOpsLists.empty()) + return false; + + // 
Build the set of pending operations lazily + bool found = false; + if (pendingOpsSet.empty()) { + assert(pendingOpsTokens.empty() && "pendingOpsTokens must be empty"); + Operation *op = dyn_cast(opOrVal); + Value val = dyn_cast(opOrVal); + for (const auto &pendingOps : pendingOpsLists) { + for (const auto &[pendingOp, pendingTokens] : + pendingOps->opsAndTokens()) { + if (pendingOp == op) + found = true; + + pendingOpsSet.insert(pendingOp); + for (Value token : pendingTokens) { + if (token == val) + found = true; + + pendingOpsTokens.insert(token); + } + } + } + } + + if (found) + return true; + + return isa(opOrVal) + ? pendingOpsSet.contains(cast(opOrVal)) + : pendingOpsTokens.contains(cast(opOrVal)); + } + + void resetPendingOpsSet() { + pendingOpsSet.clear(); + pendingOpsTokens.clear(); + } +}; + +static RegionSuccessor getRegionResults(ArrayRef successors, + Region *region) { + for (const auto &successor : successors) { + if (successor.getSuccessor() == region) + return successor; + } + llvm_unreachable("Region not found, malformed SCF op?"); +} + +/// Dense forward dataflow analysis for waitcnt insertion +class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { +public: + explicit WaitcntAnalysis(DataFlowSolver &solver) + : DenseForwardDataFlowAnalysis(solver) {} + + void setToEntryState(WaitcntState *lattice) override { + propagateIfChanged(lattice, lattice->reset()); + } + + LogicalResult visitOperation(Operation *op, const WaitcntState &before, + WaitcntState *after) override { + LDBG() << "Visiting: " << *op; + LDBG() << " Before: " << before; + + // Start with the state before this operation + WaitcntState newState = before; + + if (isBarrier(op)) { + LDBG() << " Barrier: " << *op; + newState.addPendingOp(op); + LDBG() << " New state: " << newState; + propagateIfChanged(after, after->join(newState)); + return success(); + } + + llvm::SmallSetVector barriers; + + // Check if any operands depend on pending operations (value dependency) + WaitcntRequirement opRequirement = after->getRequirement(); + for (Value operand : op->getOperands()) { + if (auto req = before.checkSSADependency(operand, barriers)) { + // Merge this requirement (take minimum for conservative wait) + opRequirement.merge(req); + } + } + + // Check for memory dependencies (RAW, WAR, WAW) + if (auto memReq = before.checkMemoryDependency(op, barriers)) { + LDBG() << " Memory dependency: " << memReq; + opRequirement.merge(memReq); + } else { + LDBG() << " No memory dependency"; + } + + if (opRequirement.hasRequirement() && !barriers.empty()) { + // newState.setRequirement(opRequirement); + LDBG() << " Barriers found, requirement: " << opRequirement; + for (Operation *barrier : barriers) { + LDBG() << " " << *barrier; + WaitcntState *beforeState = + getOrCreate(getProgramPointBefore(barrier)); + WaitcntState *afterState = + getOrCreate(getProgramPointAfter(barrier)); + WaitcntState newBarrierState = *beforeState; + newBarrierState.setRequirement(opRequirement); + propagateIfChanged(afterState, afterState->merge(newBarrierState)); + } + return success(); + } + + // Check if this is an existing memory_counter_wait operation + if (auto waitOp = dyn_cast(op)) { + LDBG() << " Existing waitcnt operation: " << *waitOp; + opRequirement.merge(WaitcntRequirement(waitOp)); + } + + // Set the requirement for this operation + if (opRequirement.hasRequirement()) { + newState.setRequirement(opRequirement); + LDBG() << " Operation requirement: " << opRequirement; + } else { + newState.resetRequirement(); + LDBG() << " No 
operation requirement"; + } + + // Check if this is an async memory operation (vector load/store) + if (WaitcntRequirement::getOperationRequirement(op, false) + .hasRequirement()) { + // Add this operation to the pending list + newState.addPendingOp(op); + } + + auto changed = after->merge(newState); + if (changed == ChangeResult::Change) { + LDBG() << " New state: " << newState; + } else { + LDBG() << " No change"; + } + propagateIfChanged(after, changed); + return success(); + } + + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + std::optional regionFrom, + std::optional regionTo, + const WaitcntState &before, + WaitcntState *after) override { + LDBG() << "Visiting region branch control flow transfer: " << *branch; + LDBG() << " Region from: " << regionFrom; + LDBG() << " Region to: " << regionTo; + LDBG() << " Before: " << before; + LDBG() << " After: " << *after; + + SmallVector successors; + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + + auto destSuccessor = [&]() -> RegionSuccessor { + if (regionTo) { + Region ®ion = branch->getRegions()[*regionTo]; + return getRegionResults(successors, ®ion); + } else { + return getRegionResults(successors, nullptr); + } + }(); + // Dest values are either nested block args or branch op results. + ValueRange destValues = destSuccessor.getSuccessorInputs(); + + // Map from input values to dest values. + llvm::SmallDenseMap valuesMapping; + if (regionFrom) { + Region ®ion = branch->getRegions()[*regionFrom]; + for (Block &block : region) { + auto term = + dyn_cast(block.getTerminator()); + if (!term) + continue; + + ValueRange source = + term.getMutableSuccessorOperands(destSuccessor).getAsOperandRange(); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; + } + } else { + ValueRange source = branch.getEntrySuccessorOperands(destSuccessor); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; + } + + DominanceInfo dom; + + WaitcntState newState = before; + auto tokenUpdateFunc = [&](Value value, SmallVectorImpl &newTokens) { + // Keep the token if it dominates current op as user can use it directly. + if (dom.properlyDominates(value, branch)) + newTokens.push_back(value); + + // Add token propagated through region control flow. + if (Value mappedValue = valuesMapping.lookup(value)) + if (newTokens.empty() || newTokens.back() != mappedValue) + newTokens.push_back(mappedValue); + }; + newState.updateTokens(tokenUpdateFunc); + + LDBG() << " New state: " << newState; + + propagateIfChanged(after, after->join(newState)); + } +}; + +/// Pass that inserts wait/synchronization instructions for asynchronous +/// memory operations. This is analogous to LLVM's SIInsertWaitcnts pass. 
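The class below drives the analysis. Before it, a toy model of the counting scheme may help: the wait value attached to a dependency is, roughly, the number of operations on the same hardware counter that were issued after the operation being waited on, since that is how high the counter can still be when the dependent instruction executes. Here is a hand-written Python sketch of that idea under simplifying assumptions (in-order issue, exactly one counter per op); it is a conceptual model, not the analysis code above:

```python
# Toy scoreboard: ops are issued in order; each op ticks exactly one counter.
# waitcnt_for(pending, producer) returns the largest value the producer's
# counter may still have when its result is needed, i.e. the number of
# same-counter ops issued after it.
def waitcnt_for(pending: list[tuple[str, str]], producer: str) -> tuple[str, int]:
    # pending: (name, counter) pairs in issue order; counter is
    # "load_cnt" (global / vmcnt) or "ds_cnt" (LDS / lgkmcnt).
    idx = next(i for i, (name, _) in enumerate(pending) if name == producer)
    counter = pending[idx][1]
    later_same = sum(1 for _, c in pending[idx + 1:] if c == counter)
    return counter, later_same

pending = [("A", "load_cnt"), ("B", "ds_cnt"), ("C", "load_cnt"), ("D", "load_cnt")]
print(waitcnt_for(pending, "A"))  # ('load_cnt', 2): wait until vmcnt <= 2
print(waitcnt_for(pending, "B"))  # ('ds_cnt', 0): wait until lgkmcnt == 0
```

In the pass itself the same counting happens over the per-path `PendingOperations` lists, with `load_cnt` and `ds_cnt` tracked independently and merged conservatively at control-flow joins.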
+class WaterInsertWaitcntPass + : public water::impl::WaterInsertWaitcntBase { +public: + void runOnOperation() override { + LDBG() << "Running WaterInsertWaitcntPass"; + Operation *op = getOperation(); + + DataFlowSolver solver; + loadBaselineAnalyses(solver); + solver.load(); + + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // Insert waitcnt operations based on analysis results + IRRewriter rewriter(&getContext()); + op->walk([&](Operation *operation) { + const WaitcntState *state = solver.lookupState( + solver.getProgramPointAfter(operation)); + if (!state || !state->hasRequirement()) + return; + + const WaitcntRequirement &req = state->getRequirement(); + + auto getAttr = [&](std::optional cnt) -> IntegerAttr { + if (!cnt.has_value()) + return nullptr; + return rewriter.getI32IntegerAttr(*cnt); + }; + + // Insert wait operation before the current operation. + // If the current operation is already a memory_counter_wait operation + // they will be merged later. + rewriter.setInsertionPoint(operation); + amdgpu::MemoryCounterWaitOp::create( + rewriter, operation->getLoc(), getAttr(req.getLoadCnt()), + getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr, + nullptr); + }); + } +}; + +} // namespace diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp new file mode 100644 index 000000000..3fd4e6ad6 --- /dev/null +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -0,0 +1,1121 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERLOWERMEMORYOPS +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { + +static unsigned getBitwidth(ShapedType type) { + assert(type.hasStaticShape() && "Shaped type must have static shape"); + return type.getNumElements() * type.getElementTypeBitWidth(); +} + +static unsigned getBitwidth(Type type) { + if (auto shaped = dyn_cast(type)) + return getBitwidth(shaped); + + return type.getIntOrFloatBitWidth(); +} + +static std::string getVGPRRange(unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount) { + assert(vgprCount > 0 && "VGPR count must be greater than 0"); + unsigned start = vgprOffset + vgprNum; + if (vgprCount == 1) { + return ("v" + llvm::Twine(start)).str(); + } else { + unsigned end = start + vgprCount - 1; + return ("v[" + llvm::Twine(start) + ":" + llvm::Twine(end) + "]").str(); + } +} + +static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, bool isOutput) { + return (llvm::Twine(isOutput ? 
"=" : "") + "{" + + getVGPRRange(vgprOffset, vgprNum, vgprCount) + "}") + .str(); +} + +static FailureOr getLoadSizeSuffixRDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("u8"); + case 16: + return StringRef("u16"); + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } +} + +static FailureOr getStoreSizeSuffixRDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } +} + +static FailureOr getLoadSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("ubyte"); + case 16: + return StringRef("ushort"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr getStoreSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("byte"); + case 16: + return StringRef("short"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr getBufferLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getBufferStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getGlobalLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getGlobalStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr getDSLoadSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getLoadSizeSuffixRDNA(bitWidth); +} + +static FailureOr getDSStoreSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getStoreSizeSuffixRDNA(bitWidth); +} + +/// Create an LLVM inline assembly operation with standard attributes +static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, + TypeRange resultTypes, + ValueRange operands, StringRef asmStr, + StringRef constraints, + bool hasSideEffects) { + return LLVM::InlineAsmOp::create( + rewriter, loc, resultTypes, operands, asmStr, constraints, hasSideEffects, + /*is_align_stack=*/false, + /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, + /*asm_dialect=*/LLVM::AsmDialectAttr{}, + /*operand_attrs=*/ArrayAttr{}); +} + +/// Detect if chipset is RDNA vs CDNA architecture +static bool isRDNA(const amdgpu::Chipset &chipset) { + return chipset.majorVersion != 9; +} + +static Operation *propagateExtract(Operation *op) { + if (auto extract = dyn_cast(op)) + return extract.getSource().getDefiningOp(); + if (auto extract = dyn_cast(op)) + return extract.getSource().getDefiningOp(); + return nullptr; +} + +static unsigned 
checkHazards(Operation *currentOp, Value value) { + Operation *op = value.getDefiningOp(); + if (!op) + return 0; + + while (auto nextOp = propagateExtract(op)) + op = nextOp; + + if (op->getBlock() != currentOp->getBlock()) + return 0; + + if (!isa(op)) + return 0; + + while (op != currentOp) { + if (isa(op) && + cast(op).getIntrin() == "llvm.amdgcn.s.nop") + return 0; + op = op->getNextNode(); + } + + return 5; // HACK for now +} + +static void handleHazards(IRRewriter &rewriter, Location loc, Operation *op, + Value value) { + unsigned hazard = checkHazards(op, value); + if (hazard > 0) { + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + Value nopCount = + arith::ConstantIntOp::create(rewriter, loc, hazard - 1, 16); + StringAttr intrin = rewriter.getStringAttr("llvm.amdgcn.s.nop"); + LLVM::CallIntrinsicOp::create(rewriter, loc, {}, intrin, nopCount); + } +} + +/// Compute byte offset as iX for a memref access with indices +template +static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + // Extract strided metadata to get offset and strides + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value offset = metadataOp.getOffset(); + + // Compute linear index from multidimensional indices + Value linearIndex = offset; + for (auto i : llvm::seq(0, indices.size())) { + Value stride = metadataOp.getStrides()[i]; + Value indexTimesStride = arith::MulIOp::create( + rewriter, loc, indices[i], stride, arith::IntegerOverflowFlags::nsw); + linearIndex = + arith::AddIOp::create(rewriter, loc, linearIndex, indexTimesStride, + arith::IntegerOverflowFlags::nsw); + } + + // Convert linear index to byte offset + unsigned elementBytes = elementBitWidth / 8; + Value elementSize = + arith::ConstantIndexOp::create(rewriter, loc, elementBytes); + Value byteOffset = + arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, + arith::IntegerOverflowFlags::nsw); + + Type indexType = IntegerType::get(rewriter.getContext(), Bits); + return arith::IndexCastOp::create(rewriter, loc, indexType, byteOffset); +} + +/// Compute the final address for a memref access with indices (for global +/// operations) +template +static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), MemSpace); + auto intType = rewriter.getIntegerType(Bits); + + // Extract base pointer + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value basePtr = metadataOp.getBaseBuffer(); + + // Convert base pointer to i64 + Value basePtrInt = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); + basePtrInt = arith::IndexCastOp::create(rewriter, loc, intType, basePtrInt); + + // Compute byte offset + Value byteOffsetI64 = computeMemrefByteOffset(rewriter, loc, memref, + indices, elementBitWidth); + + // Add byte offset to base pointer + Value finalAddr = + arith::AddIOp::create(rewriter, loc, basePtrInt, byteOffsetI64, + arith::IntegerOverflowFlags::nsw); + return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr); +} + +/// Extract buffer descriptor and base offset from a fat_raw_buffer memref +/// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) +/// Returns: {resource descriptor (i128), base offset (i32)} +static std::pair +extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value 
memref) { + // Create proper memref descriptor struct type: {ptr, ptr, offset, + // sizes[rank], strides[rank]} + auto memrefType = cast(memref.getType()); + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto arrayType = LLVM::LLVMArrayType::get(i64Type, memrefType.getRank()); + Type descriptorFields[] = {ptrType, ptrType, i64Type, arrayType, arrayType}; + + auto memrefDescType = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), descriptorFields); + + Value memrefDescVal = + UnrealizedConversionCastOp::create(rewriter, loc, memrefDescType, memref) + .getResult(0); + + MemRefDescriptor memrefDesc(memrefDescVal); + Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc); + + // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32 + // offset} + auto i160Type = IntegerType::get(rewriter.getContext(), 160); + Value fullDesc = LLVM::PtrToIntOp::create(rewriter, loc, i160Type, bufferPtr); + + // Extract lower 32 bits for base offset + Value baseOffset = arith::TruncIOp::create(rewriter, loc, i32Type, fullDesc); + + // Extract upper 128 bits for resource descriptor + auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32); + Value rsrcBits160 = arith::ShRUIOp::create(rewriter, loc, fullDesc, c32); + auto i128Type = IntegerType::get(rewriter.getContext(), 128); + Value rsrcBits = + arith::TruncIOp::create(rewriter, loc, i128Type, rsrcBits160); + + return {rsrcBits, baseOffset}; +} + +/// Helper to get memref, result type, and bit width from load operation +template +static std::tuple getLoadOpInfo(LoadOpTy loadOp) { + if constexpr (std::is_same_v) { + auto vectorType = loadOp.getVectorType(); + unsigned bitWidth = getBitwidth(vectorType); + return {loadOp.getBase(), vectorType, bitWidth}; + } else { + auto elementType = loadOp.getResult().getType(); + unsigned bitWidth = getBitwidth(elementType); + return {loadOp.getMemRef(), elementType, bitWidth}; + } +} + +/// Helper to get memref, value type, and bit width from store operation +template +static std::tuple getStoreOpInfo(StoreOpTy storeOp) { + if constexpr (std::is_same_v) { + auto vectorType = cast(storeOp.getValueToStore().getType()); + unsigned bitWidth = getBitwidth(vectorType); + return {storeOp.getBase(), vectorType, bitWidth}; + } else { + auto elementType = storeOp.getValueToStore().getType(); + unsigned bitWidth = getBitwidth(elementType); + return {storeOp.getMemRef(), elementType, bitWidth}; + } +} + +/// Lower vector/scalar load to AMDGPU buffer load inline assembly +template +static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + // TODO: for bitwidths less than 32, we will need to truncate the value to 32 + // immediately after the load, breaking the calculated dependencies. 
+ // For now, just let llvm handle the loading + if (bitWidth < 32) + return success(); + + FailureOr suffix = getBufferLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); + + // Constraints: "=v" for output (VGPR), "v" for offset (VGPR), "s" for + // descriptor (SGPR[4]) + StringRef constraints = "=v,v,s"; + + // Compute byte offset from indices + unsigned elementBitWidth = + std::is_same_v + ? cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefByteOffset<32>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + // Create inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, + ValueRange{finalOffset, bufferDesc}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +/// Lower vector/scalar load to LLVM inline assembly (global_load_*) +template +static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + if (bitWidth < 32) + return success(); + + FailureOr suffix = getGlobalLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + + // Build the inline assembly string: "global_load_b64 $0, $1, off" + std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); + + // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) + StringRef constraints = "=v,v"; + + rewriter.setInsertionPoint(loadOp); + + // Compute the final address + unsigned elementBitWidth = + std::is_same_v + ? 
cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Create the inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, + asmStr, constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +/// Lower vector/scalar load to AMDGPU DS load inline assembly +template +static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + if (bitWidth < 32) + return success(); + + FailureOr suffix = getDSLoadSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build inline assembly: "ds_read_b32 $0, $1" + std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); + + // Constraints: "=v" for output (VGPR), "v" for address (VGPR) + StringRef constraints = "=v,v"; + + // Compute byte offset as i64 + unsigned elementBitWidth = + std::is_same_v + ? cast(resultType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefAddress<32, 3>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Create inline assembly operation (DS operations use 32-bit addresses) + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, + asmStr, constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + +static Value extendToReg(Value value, IRRewriter &rewriter, Location loc) { + unsigned bitWidth = getBitwidth(value.getType()); + if (bitWidth >= 32) { + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); + return value; + } + + // Sched barrier to prevent moving the expansion before the waitcnt. + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); + + return arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), value); +} + +/// Lower vector/scalar store to AMDGPU buffer store inline assembly +template +static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getBufferStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported buffer store bit width: ") + << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build inline assembly: "buffer_store_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_store_" + *suffix + " $0, $1, $2, 0 offen").str(); + + // Constraints: "v" for data (VGPR), "v" for offset (VGPR), "s" for descriptor + // (SGPR[4]) + StringRef constraints = "v,v,s"; + + // Compute byte offset from indices + unsigned elementBitWidth = + std::is_same_v + ? 
cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefByteOffset<32>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create inline assembly operation (no result for store) + createInlineAsm(rewriter, loc, TypeRange{}, + {valueToStore, finalOffset, bufferDesc}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Lower vector/scalar store to LLVM inline assembly (global_store_*) +template +static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getGlobalStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported store bit width: ") << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build the inline assembly string: "global_store_b64 $0, $1, off" + std::string asmStr = ("global_store_" + *suffix + " $0, $1, off").str(); + + // Constraints: "v" for address (VGPR), "v" for data (VGPR) + StringRef constraints = "v,v"; + + // Compute the final address + unsigned elementBitWidth = + std::is_same_v + ? cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create the inline assembly operation (no result for store) + createInlineAsm(rewriter, loc, {}, {addr, valueToStore}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Lower vector/scalar store to AMDGPU DS store inline assembly +template +static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { + auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + + FailureOr suffix = getDSStoreSuffix(bitWidth, isRDNAArch); + if (failed(suffix)) + return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore()); + + // Build inline assembly: "ds_write_b32 $0, $1" + std::string asmStr = ("ds_write_" + *suffix + " $0, $1").str(); + + // Constraints: "v" for address (VGPR), "v" for data (VGPR) + StringRef constraints = "v,v"; + + // Compute byte offset as i64 + unsigned elementBitWidth = + std::is_same_v + ? 
cast(valueType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefAddress<32, 3>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); + + // Create inline assembly operation (no result for store, DS uses 32-bit + // addresses) + createInlineAsm(rewriter, loc, {}, {offset, valueToStore}, asmStr, + constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +/// Check if a memref uses AMDGPU fat_raw_buffer address space +static bool usesBufferAddressSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (!memorySpace) + return false; + + // Check for #amdgpu.address_space attribute + if (auto enumAttr = dyn_cast(memorySpace)) + return enumAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer; + + return false; +} + +/// Check if a memref uses workgroup (LDS) address space +static bool usesWorkgroupAddressSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (!memorySpace) + return false; + + // Check for #gpu.address_space attribute + if (auto enumAttr = dyn_cast(memorySpace)) + return enumAttr.getValue() == gpu::AddressSpace::Workgroup; + + return false; +} + +/// Check if a memref uses register space (memspace 128) +static bool usesRegisterSpace(Value memref) { + auto memrefType = cast(memref.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return intAttr.getInt() == 128; + + return false; +} + +/// Lower memref.copy when destination is in register space - buffer variant +static LogicalResult lowerCopyToRegBuffer(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { + Value src = copyOp.getSource(); + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getBufferLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute byte offset (no indices for full copy) + Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Extract buffer descriptor and base offset + auto [bufferDesc, baseOffset] = extractBufferDescriptor(rewriter, loc, src); + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v,s"; + + // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" + std::string asmStr = + ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); + + createInlineAsm(rewriter, loc, resultType, + ValueRange{finalOffset, bufferDesc}, asmStr, constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space - DS variant +static LogicalResult lowerCopyToRegDS(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, unsigned totalBits, + Type resultType) { + Value src = copyOp.getSource(); + auto srcType = 
cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getDSLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported DS copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute byte offset + Value offset = computeMemrefAddress<32, 3>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; + + // Build inline assembly: "ds_read_b32 $0, $1" + std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); + + createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space - global variant +static LogicalResult lowerCopyToRegGlobal(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { + Value src = copyOp.getSource(); + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + FailureOr suffix = getGlobalLoadSuffix(totalBits, isRDNAArch); + if (failed(suffix)) + return copyOp.emitError("unsupported copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute source address + Value addr = computeMemrefAddress<64, 0>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); + + // Build constraint with specific VGPR + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; + + // Build inline assembly: "global_load_b128 $0, $1, off" + std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); + + createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space +static LogicalResult lowerCopyToReg(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + // Get destination alloca to find VGPR assignment + auto dstAlloca = dst.getDefiningOp(); + if (!dstAlloca) + return copyOp.emitError("destination must be a memref.alloca"); + + // Get VGPR number from destination alloca + auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + dstAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return copyOp.emitError("destination alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + // Get source type info + auto srcType = cast(src.getType()); + if (!srcType.hasStaticShape()) + return copyOp.emitError("source must have static shape"); + + unsigned totalBits = getBitwidth(srcType); + + // Get result type from destination + auto dstType = cast(dst.getType()); + if (!dstType.hasStaticShape()) + return copyOp.emitError("destination must have static shape"); + + unsigned resultBitWidth = getBitwidth(dstType); + unsigned resultNumElements = (resultBitWidth + 31) / 32; + Type resultType = + VectorType::get(resultNumElements, rewriter.getIntegerType(32)); + + // Dispatch based on source memory space + if 
(usesBufferAddressSpace(src)) + return lowerCopyToRegBuffer(copyOp, rewriter, isRDNAArch, vgprOffset, + vgprNum, vgprCount, totalBits, resultType); + if (usesWorkgroupAddressSpace(src)) + return lowerCopyToRegDS(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); + return lowerCopyToRegGlobal(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); +} + +/// Lower load from register space to inline assembly +template +static LogicalResult lowerLoadFromReg(LoadOpTy loadOp, IRRewriter &rewriter, + unsigned vgprOffset) { + Value memref; + if constexpr (std::is_same_v) + memref = loadOp.getBase(); + else + memref = loadOp.getMemRef(); + + // Get source alloca to find VGPR assignment + auto srcAlloca = memref.getDefiningOp(); + if (!srcAlloca) + return loadOp.emitError("source must be a memref.alloca"); + + // Get VGPR number from source alloca + auto vgprNumAttr = srcAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + srcAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return loadOp.emitError("source alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build constraint for reading from specific VGPR(s) + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true); + + // Simple v_mov to read from VGPR (compiler will optimize this away) + std::string asmStr = + "; reg_load " + getVGPRRange(vgprOffset, vgprNum, vgprCount); + + Type resultType = loadOp.getResult().getType(); + Type asmType = resultType; + unsigned bitWidth = getBitwidth(resultType); + if (bitWidth < 32) + asmType = rewriter.getIntegerType(32); + + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + + Value asmResult = createInlineAsm(rewriter, loc, asmType, {}, asmStr, + constraints, /*hasSideEffects=*/false) + .getResult(0); + + if (bitWidth < 32) { + auto narrowType = rewriter.getIntegerType(bitWidth); + asmResult = arith::TruncIOp::create(rewriter, loc, narrowType, asmResult); + asmResult = LLVM::BitcastOp::create(rewriter, loc, resultType, asmResult); + } + + rewriter.replaceOp(loadOp, asmResult); + return success(); +} + +/// Lower store to register space to inline assembly +template +static LogicalResult lowerStoreToReg(StoreOpTy storeOp, IRRewriter &rewriter, + unsigned vgprOffset) { + Value memref; + if constexpr (std::is_same_v) + memref = storeOp.getBase(); + else + memref = storeOp.getMemRef(); + + // Get destination alloca to find VGPR assignment + auto dstAlloca = memref.getDefiningOp(); + if (!dstAlloca) + return storeOp.emitError("destination must be a memref.alloca"); + + // Get VGPR number from destination alloca + auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + dstAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return storeOp.emitError("destination alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + Location loc = storeOp.getLoc(); + rewriter.setInsertionPoint(storeOp); + + // Build constraint for writing to specific VGPR(s) + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",0"; + + // v_mov to write to VGPR (input constraint 0 ties to output) + std::string asmStr = + "; reg_store " + getVGPRRange(vgprOffset, vgprNum, vgprCount); 
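+  // Illustrative example (hypothetical VGPR assignment v252): the emitted op
+  // looks roughly like
+  //   llvm.inline_asm has_side_effects "; reg_store v252", "={v252},0"
+  // with the value to store as its operand. The asm body is only a comment;
+  // the "={v252},0" constraint ties the input to the fixed output VGPR, so
+  // the register allocator performs the actual move.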
+ + Value valueToStore = storeOp.getValueToStore(); + unsigned bitWidth = getBitwidth(valueToStore.getType()); + if (bitWidth < 32) { + auto intType = rewriter.getIntegerType(bitWidth); + valueToStore = + LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); + auto i32Type = rewriter.getIntegerType(32); + valueToStore = arith::ExtUIOp::create(rewriter, loc, i32Type, valueToStore); + } + + createInlineAsm(rewriter, loc, valueToStore.getType(), valueToStore, asmStr, + constraints, + /*hasSideEffects=*/true); + + rewriter.eraseOp(storeOp); + return success(); +} + +class WaterLowerMemoryOpsPass + : public water::impl::WaterLowerMemoryOpsBase { +public: + using Base::Base; + + void runOnOperation() override { + auto func = getOperation(); + auto chip = amdgpu::Chipset::parse(chipset); + if (failed(chip)) { + func->emitError("invalid chipset: ") << chipset; + return signalPassFailure(); + } + + MLIRContext *ctx = &getContext(); + + unsigned totalVGPRs = + chip->majorVersion >= 12 && chip->minorVersion >= 5 ? 1024 : 256; + + // Check if function has VGPR allocation and insert inline asm directive. + auto vgprAttr = func->getAttrOfType("water.total_vgprs"); + unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0; + unsigned vgprStart = totalVGPRs - vgprCount; + + if (vgprCount > 0) { + // Add amdgpu-num-vgpr to passthrough attribute list + auto vgprStartAttr = StringAttr::get(ctx, std::to_string(vgprStart)); + auto nameAttr = StringAttr::get(ctx, "amdgpu-num-vgpr"); + + Attribute passthroughAttr; + // Get existing passthrough or create new one + if (auto existingPassthrough = + func->getAttrOfType("passthrough")) { + SmallVector attrs(existingPassthrough.begin(), + existingPassthrough.end()); + attrs.push_back(ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})); + passthroughAttr = ArrayAttr::get(ctx, attrs); + } else { + passthroughAttr = ArrayAttr::get( + ctx, {ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})}); + } + func->setAttr("passthrough", passthroughAttr); + } + + // Insert inline assembly at the beginning of the function. + Block &entryBlock = func.getFunctionBody().front(); + IRRewriter rewriter(ctx); + rewriter.setInsertionPointToStart(&entryBlock); + + if (vgprCount > 0) { + std::string asmStr = "; vgprCount = " + std::to_string(vgprCount) + + " vgprStart = " + std::to_string(vgprStart); + + createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{}, + /*operands=*/{}, asmStr, /*constraints=*/"", + /*hasSideEffects=*/true); + } + + // Determine if we're targeting RDNA vs CDNA architecture, CDNA has + // different buffer ops format. 
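+    // (For example, a 128-bit global load is spelled "global_load_dwordx4" on
+    // CDNA/gfx9 but "global_load_b128" on RDNA/gfx12; see the GFX9/GFX12
+    // check prefixes in lower-memory-ops.mlir.)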
+ bool isRDNAArch = isRDNA(*chip); + + // Helper to dispatch to the appropriate lowering function based on address + // space + auto lowerMemoryOp = [&](Value base, auto lowerRegister, auto lowerBuffer, + auto lowerWorkgroup, + auto lowerGlobal) -> LogicalResult { + if (usesRegisterSpace(base)) + return lowerRegister(); + if (usesBufferAddressSpace(base)) + return lowerBuffer(); + if (usesWorkgroupAddressSpace(base)) + return lowerWorkgroup(); + return lowerGlobal(); + }; + + auto walkFn = [&](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + loadOp.getBase(), + [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto storeOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + storeOp.getBase(), + [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto loadOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + loadOp.getMemRef(), + [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto storeOp = dyn_cast(op)) { + LogicalResult result = lowerMemoryOp( + storeOp.getMemRef(), + [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); + if (failed(result)) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto copyOp = dyn_cast(op)) { + // Only lower copy if destination is in register space + if (usesRegisterSpace(copyOp.getTarget())) { + if (failed(lowerCopyToReg(copyOp, rewriter, isRDNAArch, vgprStart))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + } + return WalkResult::advance(); + }; + + if (func.walk(walkFn).wasInterrupted()) + signalPassFailure(); + + // Clean up register space allocas - they should all be lowered by now + WalkResult cleanupResult = func.walk([&](memref::AllocaOp allocaOp) { + if (usesRegisterSpace(allocaOp.getMemref())) { + if (!allocaOp->use_empty()) { + allocaOp->emitError("register space alloca still has uses after " + "lowering - not all operations were lowered"); + return WalkResult::interrupt(); + } + rewriter.eraseOp(allocaOp); + } + return WalkResult::advance(); + }); + + if (cleanupResult.wasInterrupted()) + signalPassFailure(); + } +}; + +} // namespace diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp new file mode 100644 index 000000000..505fd31a0 --- /dev/null +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -0,0 +1,268 @@ +// Copyright 2025 The 
Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERMATERIALIZEREGCOPY
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Check if a memref type is in virtual register space (memspace 128).
+static bool isInRegisterSpace(MemRefType memrefType) {
+  if (auto memSpace =
+          dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace()))
+    return memSpace.getInt() == 128;
+  return false;
+}
+
+static SmallVector<Value> getZeroIndices(IRRewriter &rewriter, Location loc,
+                                         unsigned rank) {
+  return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)};
+}
+
+static void createLoads(IRRewriter &rewriter, Location loc, Value value,
+                        unsigned rank, Value tempAlloca, Operation *op) {
+  // Group uses by block and find the first use in each block
+  DenseMap<Block *, Operation *> blockToFirstUse;
+  for (OpOperand &use : value.getUses()) {
+    Operation *userOp = use.getOwner();
+    Block *userBlock = userOp->getBlock();
+    auto it = blockToFirstUse.find(userBlock);
+    if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second))
+      blockToFirstUse[userBlock] = userOp;
+  }
+
+  SmallVector<Value> zeroIndices = getZeroIndices(rewriter, loc, rank);
+
+  // Create one load per block, right before the first use in that block
+  DenseMap<Block *, Value> blockToLoad;
+  for (auto &[block, firstUse] : blockToFirstUse) {
+    rewriter.setInsertionPoint(firstUse);
+    Value load;
+    if (isa<memref::LoadOp>(op))
+      load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices);
+    else if (auto vecLoadOp = dyn_cast<vector::LoadOp>(op))
+      load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(),
+                                    tempAlloca, zeroIndices);
+    blockToLoad[block] = load;
+  }
+
+  // Replace uses with the appropriate load for their block
+  for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
+    Block *userBlock = use.getOwner()->getBlock();
+    use.set(blockToLoad[userBlock]);
+  }
+}
+
+/// Transform a single load operation to use register space copy.
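+/// A rough sketch of the rewrite (illustrative, scalar case, with
+/// %c0 = arith.constant 0 : index):
+///   %v = memref.load %mem[%i] : memref<1024xf32>
+/// becomes
+///   %sub = memref.subview %mem[%i] [1] [1]
+///       : memref<1024xf32> to memref<1xf32, strided<[1], offset: ?>>
+///   %reg = memref.alloca() : memref<1xf32, 128 : i32>
+///   memref.copy %sub, %reg
+///   %v = memref.load %reg[%c0] : memref<1xf32, 128 : i32>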
+static LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) {
+  Location loc = op->getLoc();
+  rewriter.setInsertionPoint(op);
+
+  // Extract memref, indices, and element type from either load type
+  Value memref, loadResult;
+  ValueRange indices;
+  Type elementType;
+  SmallVector<int64_t> loadShape;
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+    memref = loadOp.getMemRef();
+    indices = loadOp.getIndices();
+    loadResult = loadOp.getResult();
+    elementType = loadOp.getType();
+    loadShape.resize(indices.size(), 1);
+  } else if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+    memref = loadOp.getBase();
+    indices = loadOp.getIndices();
+    loadResult = loadOp.getResult();
+    VectorType vecType = loadOp.getVectorType();
+    elementType = vecType.getElementType();
+    loadShape.resize(indices.size() - vecType.getRank(), 1);
+    llvm::append_range(loadShape, vecType.getShape());
+  } else {
+    return op->emitError("unsupported load operation");
+  }
+
+  auto memrefType = cast<MemRefType>(memref.getType());
+
+  // Create subview parameters
+  Attribute one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets, sizes, strides;
+  for (auto [index, shape] : llvm::zip(indices, loadShape)) {
+    offsets.push_back(index);
+    sizes.push_back(rewriter.getIndexAttr(shape));
+    strides.push_back(one);
+  }
+
+  // Create a subview at the load indices covering the loaded element(s):
+  // unit sizes for a scalar load, the vector shape for a vector load.
+  auto subviewType =
+      memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides);
+  auto subviewMemRefType = cast<MemRefType>(subviewType);
+  Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType,
+                                            memref, offsets, sizes, strides);
+
+  // Create temporary buffer in virtual register space (memspace 128)
+  auto regMemSpace = rewriter.getI32IntegerAttr(128);
+  auto tempType =
+      MemRefType::get(subviewMemRefType.getShape(), elementType,
+                      /*layout=*/MemRefLayoutAttrInterface{}, regMemSpace);
+  Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType,
+                                              /*dynamicSizes=*/ValueRange{},
+                                              /*alignment=*/IntegerAttr());
+
+  // Copy from subview to temp register buffer
+  memref::CopyOp::create(rewriter, loc, subview, tempAlloca);
+
+  createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op);
+
+  // Erase the original load
+  rewriter.eraseOp(op);
+  return success();
+}
+
+/// Hoist allocas from loops when their loads are yielded.
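+/// When an scf.for yields a value loaded (at all-zero indices) from a
+/// memspace-128 alloca defined inside the loop, the alloca is moved above the
+/// loop, the iter_arg init value is stored into it, uses of the iter_arg are
+/// rewritten to loads from the alloca, and the corresponding loop result is
+/// replaced by a load inserted after the loop.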
+static void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) {
+  auto yieldedValues = loop.getYieldedValuesMutable();
+  if (!yieldedValues)
+    return;
+
+  auto loopResults = loop.getLoopResults();
+  if (!loopResults)
+    return;
+
+  auto loopInits = loop.getInitsMutable();
+
+  Block *body = loop.getBody();
+  Location loc = loop.getLoc();
+
+  DominanceInfo dom;
+
+  // Find yielded values that come from loads of memspace 128 allocas
+  for (auto [idx, yieldedValue, iterArg, init, result] : llvm::enumerate(
+           *yieldedValues, loop.getRegionIterArgs(), loopInits, *loopResults)) {
+    // Check if this is a load from memspace 128
+    Operation *defOp = yieldedValue.get().getDefiningOp();
+    if (!defOp)
+      continue;
+
+    Value alloca;
+    ValueRange loadIndices;
+    if (auto loadOp = dyn_cast<memref::LoadOp>(defOp)) {
+      alloca = loadOp.getMemRef();
+      loadIndices = loadOp.getIndices();
+    } else if (auto loadOp = dyn_cast<vector::LoadOp>(defOp)) {
+      alloca = loadOp.getBase();
+      loadIndices = loadOp.getIndices();
+    } else {
+      continue;
+    }
+
+    // Check all indices are zero
+    if (llvm::any_of(loadIndices,
+                     [](Value idx) { return getConstantIntValue(idx) != 0; }))
+      continue;
+
+    // Check if loading from memspace 128 alloca defined in this loop
+    auto allocaOp = alloca.getDefiningOp<memref::AllocaOp>();
+    if (!allocaOp)
+      continue;
+    if (!isInRegisterSpace(cast<MemRefType>(alloca.getType())))
+      continue;
+    if (!body->findAncestorOpInBlock(*allocaOp))
+      continue;
+
+    // If load dominates any use of the iter arg, we can't hoist the alloca
+    // because the load would be invalidated by the store.
+    bool dominates = false;
+    for (Operation *user : iterArg.getUsers()) {
+      if (dom.dominates(defOp, user)) {
+        dominates = true;
+        break;
+      }
+    }
+    if (dominates)
+      continue;
+
+    // Hoist the alloca before the loop
+    allocaOp->moveBefore(loop);
+    rewriter.setInsertionPointAfter(allocaOp);
+
+    SmallVector<Value> zeroIndices =
+        getZeroIndices(rewriter, loc, loadIndices.size());
+
+    // Store the iter arg into the alloca
+    if (isa<memref::LoadOp>(defOp)) {
+      memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices);
+    } else if (auto vectorLoad = dyn_cast<vector::LoadOp>(defOp)) {
+      vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices);
+    }
+
+    // Create iter arg loads
+    createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp);
+
+    // Create a load after the loop
+    rewriter.setInsertionPointAfter(loop);
+    zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size());
+    Value loadAfterLoop;
+    if (isa<memref::LoadOp>(defOp)) {
+      loadAfterLoop =
+          memref::LoadOp::create(rewriter, loc, alloca, zeroIndices);
+    } else if (auto vectorLoad = dyn_cast<vector::LoadOp>(defOp)) {
+      loadAfterLoop = vector::LoadOp::create(
+          rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices);
+    }
+
+    // Replace uses of the loop result with the new load
+    result.replaceAllUsesWith(loadAfterLoop);
+  }
+}
+
+/// Materialize register copies by routing memref.load through temporary
+/// buffers in virtual register space (memspace 128).
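+/// The memspace-128 allocas introduced here are consumed by later passes:
+/// WaterNumberRegisters assigns them water.vgpr_number / water.vgpr_count
+/// attributes, and WaterLowerMemoryOps rewrites accesses to them (and the
+/// copies into them) as VGPR-pinned inline assembly.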
+class WaterMaterializeRegCopyPass
+    : public water::impl::WaterMaterializeRegCopyBase<
+          WaterMaterializeRegCopyPass> {
+public:
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+
+    // Collect all load operations to transform
+    SmallVector<Operation *> loadsToTransform;
+    getOperation()->walk([&](Operation *op) {
+      if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+        if (!isInRegisterSpace(cast<MemRefType>(loadOp.getMemRef().getType())))
+          loadsToTransform.push_back(op);
+      } else if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+        if (!isInRegisterSpace(cast<MemRefType>(loadOp.getBase().getType())))
+          loadsToTransform.push_back(op);
+      }
+    });
+
+    for (Operation *op : loadsToTransform) {
+      if (failed(materializeRegCopy(rewriter, op)))
+        return signalPassFailure();
+    }
+
+    // Hoist allocas out of loops when their loads are yielded
+    getOperation()->walk(
+        [&](scf::ForOp forOp) { hoistAllocasFromLoop(rewriter, forOp); });
+  }
+};
+
+} // namespace
diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp
new file mode 100644
index 000000000..063ed7cf4
--- /dev/null
+++ b/water/lib/Transforms/WaterNumberRegisters.cpp
@@ -0,0 +1,109 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERNUMBERREGISTERS
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Check if a memref type is in virtual register space (memspace 128).
+static bool isInRegisterSpace(MemRefType memrefType) {
+  if (auto memSpace =
+          dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace()))
+    return memSpace.getInt() == 128;
+  return false;
+}
+
+/// Calculate the number of 32-bit registers needed for a memref type.
+static FailureOr<unsigned> getRegisterCount(MemRefType memrefType) {
+  // Calculate total size in bytes
+  unsigned elementSizeBytes = memrefType.getElementTypeBitWidth() / 8;
+  unsigned numElements = 1;
+  for (int64_t dim : memrefType.getShape()) {
+    if (dim == ShapedType::kDynamic)
+      return failure(); // Can't allocate dynamic sizes in registers.
+
+    numElements *= dim;
+  }
+
+  unsigned totalBytes = elementSizeBytes * numElements;
+
+  // Each register is 32 bits = 4 bytes
+  // Round up to next register boundary.
+  return (totalBytes + 3) / 4;
+}
+
+/// Assign physical registers to register space allocas.
+class WaterNumberRegistersPass
+    : public water::impl::WaterNumberRegistersBase<WaterNumberRegistersPass> {
+public:
+  void runOnOperation() override {
+    auto func = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    SmallVector<std::pair<unsigned, Operation *>> regCounts;
+
+    Type i32 = IntegerType::get(ctx, 32);
+    WalkResult result = func->walk([&](memref::AllocaOp allocaOp) {
+      auto memrefType = allocaOp.getType();
+      if (!isInRegisterSpace(memrefType))
+        return WalkResult::advance();
+
+      auto regCount = getRegisterCount(memrefType);
+      if (failed(regCount)) {
+        allocaOp->emitError(
+            "Cannot allocate dynamic-sized memref in register space");
+        return WalkResult::interrupt();
+      }
+
+      regCounts.emplace_back(*regCount, allocaOp);
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return signalPassFailure();
+
+    // Sort by register size to reduce register alignment gaps.
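+    // Worked example (hypothetical counts): allocas needing {2, 1, 1}
+    // registers are processed as 1, 1, 2 and get relative indices 0, 1 and
+    // 2-3, so the two-register alloca stays aligned without padding.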
+ llvm::stable_sort(regCounts, [](const std::pair &a, + const std::pair &b) { + return a.first < b.first; + }); + + // TODO: for now, just assign registers sequentially. In the future, + // we need a liveness analysis to assign registers. + unsigned nextRegister = 0; + + for (auto [regCount, op] : regCounts) { + // Align to regCount boundary. + nextRegister = ((nextRegister + regCount - 1) / regCount) * regCount; + + // Assign starting register number. + op->setAttr("water.vgpr_number", IntegerAttr::get(i32, nextRegister)); + + // Track how many registers this alloca uses. + op->setAttr("water.vgpr_count", IntegerAttr::get(i32, regCount)); + + // Advance to next available register. + nextRegister += regCount; + } + + // Attach metadata to function with total register count. + func->setAttr("water.total_vgprs", IntegerAttr::get(i32, nextRegister)); + } +}; + +} // namespace diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir new file mode 100644 index 000000000..a8f82fce8 --- /dev/null +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -0,0 +1,624 @@ +// RUN: water-opt %s --water-insert-waitcnt | FileCheck %s + +// CHECK-LABEL: func.func @single_load_use +func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @two_loads_use_in_reverse_order +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) +func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] + // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] + %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] + %addA = arith.addf %loadA, %loadA : vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] + %addB = arith.addf %loadB, %addA : vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + + // CHECK: return %[[ADD_B]] + return %addB : vector<4xf32> +} + +// CHECK-LABEL: func.func @lds_barriers +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) +func.func @lds_barriers(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] + // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] + %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]] + amdgpu.lds_barrier + %addA = arith.addf %loadA, %loadA : vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] + amdgpu.lds_barrier + %addB = arith.addf %loadB, %addA : vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + + // CHECK: return %[[ADD_B]] + return 
%addB : vector<4xf32> +} + +// CHECK-LABEL: func.func @raw_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to memory + // CHECK: vector.store %[[DATA]], %[[MEM]] + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load from same memory - RAW dependency, must wait for store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]] + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @raw_dependency_memref +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: f32, %{{.*}}: index) +func.func @raw_dependency_memref(%memref: memref<1024xf32>, %data: f32, %offset: index) -> f32 { + // Store to memory + // CHECK: memref.store %[[DATA]], %[[MEM]] + memref.store %data, %memref[%offset] : memref<1024xf32> + + // Load from same memory - RAW dependency, must wait for store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]] + %result = memref.load %memref[%offset] : memref<1024xf32> + + // CHECK: return %[[LOAD]] + return %result : f32 +} + +// CHECK-LABEL: func.func @war_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Load from memory + // CHECK: %[[LOAD:.*]] = vector.load %[[MEM]] + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to same memory - WAR dependency, must wait for load + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store %[[DATA]], %[[MEM]] + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @waw_dependency +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA1:.*]]: vector<4xf32>, %[[DATA2:.*]]: vector<4xf32>, %{{.*}}: index) +func.func @waw_dependency(%memref: memref<1024xf32>, %data1: vector<4xf32>, %data2: vector<4xf32>, %offset: index) { + // First store + // CHECK: vector.store %[[DATA1]], %[[MEM]] + vector.store %data1, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Second store to same memory - WAW dependency, must wait for first store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store %[[DATA2]], %[[MEM]] + vector.store %data2, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @raw_dependency_non_zero_waitcnt +func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Allocate two distinct memrefs to guarantee no aliasing + // CHECK: %[[MEM_A:.*]] = memref.alloc() + %memrefA = memref.alloc() : memref<1024xf32> + // CHECK: %[[MEM_B:.*]] = memref.alloc() + %memrefB = memref.alloc() : memref<1024xf32> + + // Store to memory A + // CHECK: vector.store %{{.*}}, %[[MEM_A]] + vector.store %data, %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to memory B (intervening operation, different memref) + // CHECK: vector.store %{{.*}}, %[[MEM_B]] + vector.store %data, %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // Load 
from memory A - RAW dependency with store to A at distance 1 + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM_A]] + %result = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @workgroup_memory_raw +func.func @workgroup_memory_raw(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Allocate workgroup (LDS) memory + // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + %lds = memref.alloc() : memref<1024xf32, #gpu.address_space> + + // Store to LDS + // CHECK: vector.store %{{.*}}, %[[LDS]] + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Load from LDS - RAW dependency, should use dsCnt not loadCnt + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[LDS]] + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @mixed_global_and_workgroup +// CHECK-SAME: (%[[GLOBAL:.*]]: memref<1024xf32>, %[[LDS:.*]]: memref<1024xf32, #gpu.address_space>, %{{.*}}: vector<4xf32>, %{{.*}}: index) +func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to global memory + // CHECK: vector.store %{{.*}}, %[[GLOBAL]] + vector.store %data, %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to LDS (different counter, no dependency) + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store %{{.*}}, %[[LDS]] + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Load from global - RAW dependency with global store at distance 0 + // (LDS store doesn't count because it's a different counter type) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[GLOBAL]] + %result = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @existing_waitcnt +func.func @existing_waitcnt(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to memory + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // CHECK: amdgpu.memory_counter_wait load(0) + amdgpu.memory_counter_wait load(0) + + // Another store after the wait + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load requires wait for the second store only (first was already waited on) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @existing_waitcnt_more_strict +func.func @existing_waitcnt_more_strict(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : 
memref<1024xf32> + + // Store to memory + // CHECK: vector.store + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // Normally, the distance will be 1, but explicit amdgpu.memory_counter_wait + // overrides it. + // CHECK-NOT: amdgpu.memory_counter_wait load(1) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NOT: amdgpu.memory_counter_wait load(1) + amdgpu.memory_counter_wait load(0) + + // CHECK: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + + +// CHECK-LABEL: func.func @control_flow_merge +func.func @control_flow_merge(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : memref<1024xf32> + + // Common operation before branching + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // Extra operation in this path + // CHECK: vector.store + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb2: + // No extra operations, just branch to merge point + // CHECK: cf.br + cf.br ^bb3 + +^bb3: + // bb1 branch has distance 1 but bb2 has distance 0, so we need to conservatively + // take 0 + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @control_flow_merge_same_lists +func.func @control_flow_merge_same_lists(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : memref<1024xf32> + + // Common operation before branching + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: vector.store + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb2: + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: cf.br + cf.br ^bb3 + +^bb3: + // both branches has the same distance 1 + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @loop_carried_dependency +func.func @loop_carried_dependency(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Store in each iteration + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load in the same iteration - RAW dependency with store from this iteration + // In 
steady state, the backedge brings pending operations from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses the load result, which is async, so need to wait for it + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @loop_load_before_store +func.func @loop_load_before_store(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Load first - in steady state, has RAW dependency with store from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Store after load - WAR dependency with load in same iteration + // The wait for the load clears it from pending, so this wait is for the load + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses load result - load was already waited on by the store, no additional wait needed + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @memref_copy_raw_source +func.func @memref_copy_raw_source(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy from source - RAW dependency (reads from source that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @memref_copy_waw_target +func.func @memref_copy_waw_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAW dependency (writes to target that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @memref_copy_war_target +func.func @memref_copy_war_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // Load from destination + // CHECK: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAR dependency (writes to target that was just read) + // The copy's wait also synchronizes the load, so return doesn't need another wait + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @memref_copy_both_dependencies +func.func @memref_copy_both_dependencies(%src: 
memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy needs to wait for both stores: + // - RAW on source (copy reads from source) + // - WAW on target (copy writes to destination) + // Both stores alias with their respective memrefs, so we need wait(0) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // Load from destination after copy - RAW dependency with copy + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @gather_to_lds +func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %src_offset: index, %dst_offset: index) -> vector<4xf32> { + // Store to global memory + // CHECK: vector.store + vector.store %data, %global[%src_offset] : memref<1024xf32>, vector<4xf32> + + // Gather from global to LDS - has both RAW (reads from global) and acts as store to LDS + // Should wait for global store using load counter + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.gather_to_lds + amdgpu.gather_to_lds %global[%src_offset], %lds[%dst_offset] : f32, memref<1024xf32>, memref<1024xf32, #gpu.address_space> + + // Load from LDS - RAW dependency with gather writing to LDS + // Should wait for LDS operation using ds counter + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %lds[%dst_offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @double_buffering +func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the second buffer copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Cannot skip unfortunately + // CHECK: amdgpu.memory_counter_wait load(0) ds(0) + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + 
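+  // Net effect: inside the loop only the copy issued in the previous
+  // iteration has to complete before reading %current; the prefetch into
+  // %next issued this iteration can stay in flight, hence ds(1) above.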
+ // CHECK: return + return +} + +// CHECK-LABEL: func.func @triple_buffering +func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // Skip the second buffer copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the prev copy + // CHECK: amdgpu.memory_counter_wait load(0) ds(1) + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + // CHECK: return + return +} + + +// CHECK-LABEL: func.func @triple_buffering_reg_space +func.func @triple_buffering_reg_space(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %c0 = arith.constant 0 : index + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %reg = memref.alloca() : memref<4xf32, 128 : i32> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the the prev copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.subview + %subview = memref.subview %current[%offset] [4] [1] : memref<1024xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + + // This copy only depends 
on buffer 2 iterations ago + // CHECK: amdgpu.memory_counter_wait ds(2) + // CHECK: memref.copy + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @load_store_repeated +func.func @load_store_repeated(%src0: memref<4xf32>, %src1: memref<4xf32>, %offset: index) { + %c0 = arith.constant 0 : index + %buff0 = memref.alloc() : memref<4xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<4xf32, #gpu.address_space> + %reg0 = memref.alloca() : memref<4xf32, 128 : i32> + %reg1 = memref.alloca() : memref<4xf32, 128 : i32> + %reg2 = memref.alloca() : memref<4xf32, 128 : i32> + %reg3 = memref.alloca() : memref<4xf32, 128 : i32> + + // CHECK-COUNT-4: memref.copy + memref.copy %src0, %reg0 : memref<4xf32> to memref<4xf32, 128 : i32> + memref.copy %src1, %reg1 : memref<4xf32> to memref<4xf32, 128 : i32> + + memref.copy %buff0, %reg2 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> + memref.copy %buff1, %reg3 : memref<4xf32, #gpu.address_space> to memref<4xf32, 128 : i32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: vector.load + %data0 = vector.load %reg0[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.load + %data1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK-NEXT: vector.load + %data2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: vector.load + %data3 = vector.load %reg3[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + + return +} diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir new file mode 100644 index 000000000..0458bc81e --- /dev/null +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -0,0 +1,557 @@ +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx1200}))' | FileCheck %s --check-prefixes=CHECK,GFX12 + +// Test lowering of vector memory operations to AMDGPU global_load/store inline assembly + +// CHECK-LABEL: func.func @simple_function +func.func @simple_function(%arg0: f32) -> f32 { + // CHECK: return %arg0 + return %arg0 : f32 +} + +// CHECK-LABEL: func.func @vector_load +func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: memref.extract_aligned_pointer_as_index + // CHECK: arith.index_cast + // CHECK: llvm.inttoptr + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store +func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // CHECK: memref.extract_aligned_pointer_as_index + // CHECK: arith.index_cast + // CHECK: llvm.inttoptr + // GFX9: llvm.inline_asm 
has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return +} + +// CHECK-LABEL: func.func @vector_load_b32 +func.func @vector_load_b32(%memref: memref<1024xf32>, %offset: index) -> vector<1xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @vector_load_b64 +func.func @vector_load_b64(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @vector_load_b96 +func.func @vector_load_b96(%memref: memref<1024xf32>, %offset: index) -> vector<3xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx3 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @vector_load_b128 +func.func @vector_load_b128(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store_b32 +func.func @vector_store_b32(%memref: memref<1024xf32>, %offset: index, %data: vector<1xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<1xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b64 +func.func @vector_store_b64(%memref: memref<1024xf32>, %offset: index, %data: vector<2xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<2xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b96 +func.func @vector_store_b96(%memref: memref<1024xf32>, %offset: index, %data: vector<3xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx3 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<3xf32> + return +} + +// CHECK-LABEL: func.func @vector_store_b128 +func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @load_store_sequence 
+func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { + // Test lowering of load/store sequence + + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> + + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} + +// ----- +// Buffer operations tests + +// CHECK-LABEL: func.func @buffer_load_b32 +func.func @buffer_load_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<1xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b64 +func.func @buffer_load_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<2xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b64 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b96 +func.func @buffer_load_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<3xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b96 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @buffer_load_b128 +func.func @buffer_load_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<4xf32> { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" + %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @buffer_store_b32 +func.func @buffer_store_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<1xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> + return +} + +// CHECK-LABEL: func.func @buffer_store_b64 +func.func @buffer_store_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<2xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b64 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> + return +} + 
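+// Editorial note, not FileCheck directives: the constraint strings asserted in
+// these tests are read here under the standard LLVM AMDGPU inline-asm
+// conventions, which this gloss assumes rather than anything specific to the
+// water-lower-memory-ops pass. "v" selects a VGPR operand, "s" an SGPR
+// operand, and a leading "=" marks an output. A buffer load checked as
+// "=v,v,s" therefore returns its result in VGPRs and consumes a VGPR offset
+// plus an SGPR-held buffer resource, while the store forms drop the "="
+// because every operand is an input. The mnemonics likewise differ by target:
+// the gfx9 checks expect dword-count suffixes (dword, dwordx2, ...) and the
+// gfx12 checks expect bit-width suffixes (b32, b64, ...) for the same vector
+// width.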
+// CHECK-LABEL: func.func @buffer_store_b96 +func.func @buffer_store_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<3xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b96 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> + return +} + +// CHECK-LABEL: func.func @buffer_store_b128 +func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<4xf32>) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" + vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @mixed_global_and_buffer +func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) { + // Load from global memory (should use global_load) + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to buffer memory (should use buffer_store) + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" + vector.store %global_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Load from buffer memory (should use buffer_load) + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" + %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Store to global memory (should use global_store) + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %buffer_data, %global[%offset] : memref<1024xf32>, vector<4xf32> + + return +} +// ----- +// DS operations tests + +// CHECK-LABEL: func.func @ds_load_b32 +func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<1xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> + return %result : vector<1xf32> +} + +// CHECK-LABEL: func.func @ds_load_b64 +func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<2xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b64 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> + return %result : vector<2xf32> +} + +// CHECK-LABEL: func.func @ds_load_b96 +func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<3xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b96 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> + return %result : vector<3xf32> +} + +// CHECK-LABEL: func.func @ds_load_b128 
+func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<4xf32> { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" + %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @ds_store_b32 +func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<1xf32>) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> + return +} + +// CHECK-LABEL: func.func @ds_store_b64 +func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<2xf32>) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b64 $0, $1", "v,v" + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> + return +} + +// CHECK-LABEL: func.func @ds_store_b96 +func.func @ds_store_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<3xf32>) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b96 $0, $1", "v,v" + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> + return +} + +// CHECK-LABEL: func.func @ds_store_b128 +func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<4xf32>) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" + vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @mixed_global_buffer_and_ds +func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %lds: memref<1024xf32, #gpu.address_space>, %offset: index) { + // Load from global (should use global_load) + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to LDS (should use ds_write) + // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" + vector.store %global_data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Load from LDS (should use ds_read) + // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" + %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Store to buffer (should use buffer_store) + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" + vector.store %lds_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + return +} + +// ----- +// Scalar (memref) operations tests + +// CHECK-LABEL: func.func @scalar_load_global_f32 +func.func @scalar_load_global_f32(%memref: memref<1024xf32>, %offset: index) -> f32 { + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + %result = memref.load %memref[%offset] : memref<1024xf32> + return %result : f32 +} + +// CHECK-LABEL: func.func @scalar_load_global_f64 +func.func @scalar_load_global_f64(%memref: memref<1024xf64>, %offset: index) -> f64 { + // GFX9: llvm.inline_asm 
has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + %result = memref.load %memref[%offset] : memref<1024xf64> + return %result : f64 +} + +// CHECK-LABEL: func.func @scalar_store_global_f32 +func.func @scalar_store_global_f32(%memref: memref<1024xf32>, %offset: index, %data: f32) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xf32> + return +} + +// CHECK-LABEL: func.func @scalar_store_global_f64 +func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %data: f64) { + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xf64> + return +} + +// CHECK-LABEL: func.func @scalar_load_buffer_f32 +func.func @scalar_load_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> f32 { + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" + %result = memref.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> + return %result : f32 +} + +// CHECK-LABEL: func.func @scalar_store_buffer_f32 +func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: f32) { + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" + memref.store %data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> + return +} + +// CHECK-LABEL: func.func @scalar_load_ds_f32 +func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> f32 { + // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" + %result = memref.load %lds[%offset] : memref<1024xf32, #gpu.address_space> + return %result : f32 +} + +// CHECK-LABEL: func.func @scalar_store_ds_f32 +func.func @scalar_store_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: f32) { + // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" + memref.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space> + return +} + +// CHECK-LABEL: func.func @mixed_scalar_and_vector +func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { + // Scalar load + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + %scalar = memref.load %memref[%offset] : memref<1024xf32> + + // Vector load + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Scalar store + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + memref.store %scalar, %memref[%offset] : memref<1024xf32> + + // Vector store + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // 
GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %vector, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + return +} + +// Test copy to register space with pre-numbered allocas + +// CHECK-LABEL: func.func @copy_global_to_reg_scalar +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> + %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" + memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v255", "={v255}" + // GFX12: llvm.inline_asm "; reg_load v255", "={v255}" + %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> + // CHECK-NOT: memref.alloca + return %val : f32 +} + +// CHECK-LABEL: func.func @copy_global_to_reg_vector +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK-NOT: memref.alloca + return %val : vector<4xf32> +} + +// CHECK-LABEL: func.func @copy_buffer_to_reg +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[252:255]},v,s" + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> to memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, 
vector<4xf32> + // CHECK-NOT: memref.alloca + return %val : vector<4xf32> +} + +// CHECK-LABEL: func.func @copy_workgroup_to_reg +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK-NOT: memref.alloca + return %val : vector<4xf32> +} + +// CHECK-LABEL: func.func @store_to_reg +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> + // GFX9: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0" + // GFX12: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0" + memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v255", "={v255}" + // GFX12: llvm.inline_asm "; reg_load v255", "={v255}" + %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32> + // CHECK-NOT: memref.alloca + return %result : f32 +} + +// CHECK-LABEL: func.func @multiple_reg_allocas +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] +func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} { + %c0 = arith.constant 0 : index + %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> + %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v247},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" + %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[248:251]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" + %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> + memref.copy %sv1, %reg1 : memref<4xf32, 
strided<[1], offset: ?>> to memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v247", "={v247}" + // GFX12: llvm.inline_asm "; reg_load v247", "={v247}" + %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> + // GFX9: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}" + // GFX12: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}" + %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK-NOT: memref.alloca + return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32> +} + +// ----- +// Test MFMA hazard handling with s_nop insertion + +// CHECK-LABEL: func.func @mfma_hazard_store +func.func @mfma_hazard_store(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { + %offset = arith.constant 0 : index + + // Perform MFMA operation + %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + + // Store MFMA result - should trigger hazard handling + // CHECK: rocdl.sched.barrier + // CHECK: arith.constant 4 : i16 + // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32> + + return +} + +// CHECK-LABEL: func.func @mfma_hazard_with_extract +func.func @mfma_hazard_with_extract(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { + %offset = arith.constant 0 : index + + // MFMA with vector extract - hazard checking should propagate through extract + %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + + %extracted = vector.extract %result[0] : f32 from vector<4xf32> + + // Store extracted value - should still detect hazard through propagation + // CHECK: rocdl.sched.barrier + // CHECK: arith.constant 4 : i16 + // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + memref.store %extracted, %arg0[%offset] : memref<1024xf32> + + return +} + +// CHECK-LABEL: func.func @no_hazard_with_existing_nop +func.func @no_hazard_with_existing_nop(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) { + %offset = arith.constant 0 : index + + %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + + // Manually insert s.nop + %nop_count = arith.constant 4 : i16 + llvm.call_intrinsic "llvm.amdgcn.s.nop"(%nop_count) : (i16) -> () + + // Store should NOT insert another s.nop since one already exists + // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop" + // CHECK-NOT: rocdl.sched.barrier + // GFX9: 
llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32> + + return +} diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir new file mode 100644 index 000000000..7f789e8c9 --- /dev/null +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -0,0 +1,165 @@ +// RUN: water-opt %s --water-materialize-reg-copy | FileCheck %s + +// CHECK-LABEL: func @test_simple_load +func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f32 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] + // CHECK-SAME: memref<10x20xf32> to memref<1x1xf32, strided<[20, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i, %j] : memref<10x20xf32> + return %0 : f32 +} + +// CHECK-LABEL: func @test_simple_vector_load +func.func @test_simple_vector_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> vector<4xf32> { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 4] [1, 1] + // CHECK-SAME: memref<10x20xf32> to memref<1x4xf32, strided<[20, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x4xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0]]] + // CHECK: return %[[RESULT]] + %0 = vector.load %arg0[%i, %j] : memref<10x20xf32>, vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @test_1d_load +func.func @test_1d_load(%arg0: memref<100xf16>, %i: index) -> f16 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1] [1] [1] + // CHECK-SAME: memref<100xf16> to memref<1xf16, strided<[1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf16, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i] : memref<100xf16> + return %0 : f16 +} + +// CHECK-LABEL: func @test_3d_load +func.func @test_3d_load(%arg0: memref<8x16x32xi32>, %i: index, %j: index, %k: index) -> i32 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] + // CHECK-SAME: memref<8x16x32xi32> to memref<1x1x1xi32, strided<[512, 32, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1x1xi32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]], %[[C0]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i, %j, %k] : memref<8x16x32xi32> + return %0 : i32 +} + +// CHECK-LABEL: func @test_multiple_loads +func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) -> f32 { + // First load: subview, alloca, copy + // CHECK: memref.subview + // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy + %0 = memref.load %arg0[%i, %j] : memref<10x10xf32> + + // Second load: subview, alloca, copy + // CHECK: memref.subview + // CHECK: 
memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy + %1 = memref.load %arg0[%j, %i] : memref<10x10xf32> + + // Now the actual loads happen right before the addf (late as possible) + // CHECK: memref.load + // CHECK: memref.load + // CHECK: arith.addf + %2 = arith.addf %0, %1 : f32 + return %2 : f32 +} + +// CHECK-LABEL: func @test_skip_memspace_128 +func.func @test_skip_memspace_128(%arg0: memref<10xf32>, %arg1: memref<5xf32, 128 : i32>, %i: index) -> f32 { + // This load should be transformed (from default memspace) + // First: subview, alloca, copy + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %0 = memref.load %arg0[%i] : memref<10xf32> + + // This load should NOT be transformed (already from memspace 128) + // It stays in place + // CHECK: %[[VAL1:.*]] = memref.load %arg1[%arg2] : memref<5xf32, 128 : i32> + %1 = memref.load %arg1[%i] : memref<5xf32, 128 : i32> + + // The load from temp happens late (right before addf) + // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]] + // Note: operands may be reordered + // CHECK: arith.addf + %result = arith.addf %0, %1 : f32 + // CHECK: return + return %result : f32 +} + +// CHECK-LABEL: func @test_control_flow +func.func @test_control_flow(%arg0: memref<10xf32>, %cond: i1, %i: index) -> f32 { + // Load happens once, but value is used in multiple blocks + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %val = memref.load %arg0[%i] : memref<10xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // First block: load happens here before the addf + // CHECK: ^bb1: + // CHECK: %[[CONST1:.*]] = arith.constant 1.0 + // CHECK: %[[LOAD1:.*]] = memref.load %[[TEMP]][%[[C0]]] + // CHECK: %[[ADD1:.*]] = arith.addf %[[LOAD1]], %[[CONST1]] + %c1 = arith.constant 1.0 : f32 + %sum1 = arith.addf %val, %c1 : f32 + // CHECK: cf.br ^bb3(%[[ADD1]] + cf.br ^bb3(%sum1 : f32) + +^bb2: + // Second block: another load happens here before the mulf + // CHECK: ^bb2: + // CHECK: %[[CONST2:.*]] = arith.constant 2.0 + // CHECK: %[[LOAD2:.*]] = memref.load %[[TEMP]][%[[C0]]] + // CHECK: %[[MUL:.*]] = arith.mulf %[[LOAD2]], %[[CONST2]] + %c2 = arith.constant 2.0 : f32 + %prod = arith.mulf %val, %c2 : f32 + // CHECK: cf.br ^bb3(%[[MUL]] + cf.br ^bb3(%prod : f32) + +^bb3(%result: f32): + // CHECK: ^bb3(%[[RESULT:.*]]: f32): + // CHECK: return %[[RESULT]] + return %result : f32 +} + +// CHECK-LABEL: func @test_loop_hoist +func.func @test_loop_hoist(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index, %init: f32) -> f32 { + %c0 = arith.constant 0 : index + // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: arith.constant 0 : index + // CHECK: memref.store %arg4, %[[ALLOCA]] + // CHECK: scf.for %[[IV:.*]] = %arg1 to %arg2 step %arg3 iter_args(%[[ITER_ARG:.*]] = %arg4) + %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) { + // CHECK: memref.load %[[ALLOCA]] + // CHECK: memref.store %{{.*}}, %arg0[%c0] + memref.store %arg, %arg0[%c0] : memref<100xf32> + %alloca = memref.alloca() : memref<1xf32, 128 : i32> + %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: 
?>> + memref.copy %subview, %alloca : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32> + // CHECK: memref.subview + // CHECK: memref.copy + // CHECK: memref.load %[[ALLOCA]] + // CHECK: scf.yield + scf.yield %val : f32 + } + // CHECK: memref.load %[[ALLOCA]] + // CHECK: return + return %result : f32 +} diff --git a/water/test/Transforms/number-registers-error.mlir b/water/test/Transforms/number-registers-error.mlir new file mode 100644 index 000000000..4d9f77435 --- /dev/null +++ b/water/test/Transforms/number-registers-error.mlir @@ -0,0 +1,7 @@ +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' --verify-diagnostics + +func.func @test_dynamic_size_error(%n: index) { + // expected-error @+1 {{Cannot allocate dynamic-sized memref in register space}} + %reg = memref.alloca(%n) : memref<?xf32, 128 : i32> + return +} diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir new file mode 100644 index 000000000..ccc32447b --- /dev/null +++ b/water/test/Transforms/number-registers.mlir @@ -0,0 +1,80 @@ +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s + +// CHECK-LABEL: func @test_simple_numbering +// CHECK-SAME: attributes {water.total_vgprs = 8 : i32} +func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 { + %c0 = arith.constant 0 : index + + // 1xf32 = 4 bytes = 1 register, starts at reg 0 + // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} + %reg0 = memref.alloca() : memref<1xf32, 128 : i32> + + // 4xf32 = 16 bytes = 4 registers, starts at reg 4 + // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32} + %reg1 = memref.alloca() : memref<4xf32, 128 : i32> + + // 1xf32 = 4 bytes = 1 register, starts at reg 1 (after reg0) + // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 1 : i32} + %reg2 = memref.alloca() : memref<1xf32, 128 : i32> + + %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + memref.copy %subview0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + + %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> + + return %val0 : f32 +} + +// CHECK-LABEL: func @test_loop_with_registers +// CHECK-SAME: attributes {water.total_vgprs = 1 : i32} +func.func @test_loop_with_registers(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : index + + // Register allocated outside loop + // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} + %reg = memref.alloca() : memref<1xf32, 128 : i32> + + scf.for %iv = %lb to %ub step %step { + %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> + memref.store %val, %arg0[%iv] : memref<100xf32> + } + + return +} + +// CHECK-LABEL: func @test_triple_buffering_numbering +// CHECK-SAME: attributes {water.total_vgprs = 12 : i32} +func.func @test_triple_buffering_numbering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %c0 = arith.constant 0 : index + + // Three register-space buffers for triple buffering, each 4xf32 = 4 registers + // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, 
water.vgpr_number = 0 : i32} + %reg0 = memref.alloca() : memref<4xf32, 128 : i32> + + // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32} + %reg1 = memref.alloca() : memref<4xf32, 128 : i32> + + // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 8 : i32} + %reg2 = memref.alloca() : memref<4xf32, 128 : i32> + + return +} + +// CHECK-LABEL: func @test_mixed_memspaces +// CHECK-SAME: attributes {water.total_vgprs = 1 : i32} +func.func @test_mixed_memspaces(%arg0: memref<100xf32>) { + %c0 = arith.constant 0 : index + + // Non-register space alloca - should not be numbered + // CHECK: memref.alloca() : memref<10xf32> + // CHECK-NOT: water.vgpr_number + %local = memref.alloca() : memref<10xf32> + + // Register space alloca - should be numbered + // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} + %reg = memref.alloca() : memref<1xf32, 128 : i32> + + return +} diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp index 2ab056b5b..dec98b61f 100644 --- a/water/tools/water-opt/water-opt.cpp +++ b/water/tools/water-opt/water-opt.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Transforms/Passes.h" @@ -49,6 +50,7 @@ void registerWaterTestDialect(DialectRegistry ®istry); int main(int argc, char **argv) { mlir::arith::registerArithIntRangeOptsPass(); + mlir::memref::registerExpandStridedMetadataPass(); mlir::registerCSEPass(); mlir::registerCanonicalizerPass(); mlir::registerCompositeFixedPointPass(); diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index a349d315d..93570df0f 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -217,7 +217,7 @@ def make_linear_pass_pipeline( """ def make_pass_arguments( - name: str, args: dict[str, Any], root_op: str | None = None + name: str, args: dict[str, Any], root_op: str | Sequence[str] | None = None ) -> str: ret = ( name @@ -226,7 +226,12 @@ def make_pass_arguments( + "}" ) if root_op: - ret = root_op + "(" + ret + ")" + if isinstance(root_op, str): + ret = root_op + "(" + ret + ")" + elif isinstance(root_op, Sequence): + ret = "(".join(root_op) + "(" + ret + ")" * len(root_op) + else: + raise ValueError(f"Invalid root op: {root_op}") return ret return ( @@ -399,7 +404,15 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu mlir_asm = module.operation.get_asm() target_chip = options.target - def add_opt(pipeline): + enable_asm_lowering = True + + def add_asm_pass(*args: Any) -> list[Any]: + if enable_asm_lowering: + return [args] + + return [] + + def add_opt(pipeline: Any) -> list[Any]: if options.optimization_level: return [pipeline] @@ -445,11 +458,18 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] llvm_opt_level = 3 if options.optimization_level else 0 dump_intermediates = options.dump_intermediates or "" + gpu_func = ("gpu.module", "gpu.func") + pipeline = [ + *add_asm_pass("water-materialize-reg-copy", {}, gpu_func), + *add_asm_pass("water-insert-waitcnt", {}, gpu_func), + "expand-strided-metadata", "lower-affine", *add_opt(canonicalize_cse), *add_opt("loop-invariant-code-motion"), 
*add_opt("int-range-optimizations"), + *add_asm_pass("water-number-registers", {}, gpu_func), + *add_asm_pass("water-lower-memory-ops", {"chipset": target_chip}, gpu_func), "convert-scf-to-cf", ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ("water-alloc-to-alloca", {}, "gpu.module"),