From c9246e32ad9ad358ec6ed220648a75303bf42ee2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 6 Dec 2025 13:27:19 +0100 Subject: [PATCH 001/114] pass stub Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 37 ++++++++++++++++++ water/lib/Transforms/CMakeLists.txt | 3 ++ water/lib/Transforms/WaterInsertWaitcnt.cpp | 37 ++++++++++++++++++ water/lib/Transforms/WaterLowerMemoryOps.cpp | 38 ++++++++++++++++++ water/test/Transforms/insert-waitcnt.mlir | 40 +++++++++++++++++++ water/test/Transforms/lower-memory-ops.mlir | 41 ++++++++++++++++++++ 6 files changed, 196 insertions(+) create mode 100644 water/lib/Transforms/WaterInsertWaitcnt.cpp create mode 100644 water/lib/Transforms/WaterLowerMemoryOps.cpp create mode 100644 water/test/Transforms/insert-waitcnt.mlir create mode 100644 water/test/Transforms/lower-memory-ops.mlir diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index a241b464c..44c69924a 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -153,4 +153,41 @@ def WaterDropTransformOpsPass : Pass<"water-drop-transform-ops"> { }]; } +def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> { + let summary = "Insert wait instructions for asynchronous memory operations"; + let description = [{ + This pass analyzes asynchronous memory operations and inserts appropriate + wait/synchronization instructions to ensure memory operations complete + before their results are used. + + The pass tracks dependencies between memory operations and register uses, + maintaining scoreboards to determine when waits are necessary. It handles: + - Read-after-write (RAW) dependencies + - Write-after-write (WAW) dependencies + - Write-after-read (WAR) dependencies + + This is analogous to LLVM's SIInsertWaitcnts pass but operates at the + MLIR level for AMDGPU dialect operations. + }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + ]; +} + +def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> { + let summary = "Lower high-level memory operations to AMDGPU dialect"; + let description = [{ + This pass lowers high-level memory operations (vector.load, vector.store, + memref operations) to AMDGPU-specific memory operations (buffer loads/stores, + LDS operations, etc.). + + This lowering prepares the IR for subsequent waitcnt insertion and + final code generation. + }]; + let dependentDialects = [ + "::mlir::amdgpu::AMDGPUDialect", + "::mlir::LLVM::LLVMDialect", + ]; +} + #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index baef65d8f..900539ad5 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -7,6 +7,8 @@ add_mlir_dialect_library(MLIRWaterTransforms GPUModuleToBinary.cpp GPUToGPURuntime.cpp SLPVectorizer.cpp + WaterInsertWaitcnt.cpp + WaterLowerMemoryOps.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/water @@ -15,6 +17,7 @@ add_mlir_dialect_library(MLIRWaterTransforms MLIRWaterPassesIncGen LINK_LIBS PUBLIC + MLIRAMDGPUDialect MLIRAnalysis MLIRArithDialect MLIRControlFlowDialect diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp new file mode 100644 index 000000000..df650bf24 --- /dev/null +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -0,0 +1,37 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERINSERTWAITCNT
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Pass that inserts wait/synchronization instructions for asynchronous
+/// memory operations. This is analogous to LLVM's SIInsertWaitcnts pass.
+class WaterInsertWaitcntPass
+    : public water::impl::WaterInsertWaitcntBase<WaterInsertWaitcntPass> {
+public:
+  void runOnOperation() override {
+    // TODO: Implement the pass logic
+    // This will involve:
+    // 1. Tracking asynchronous memory operations
+    // 2. Maintaining scoreboards for dependencies (similar to WaitcntBrackets)
+    // 3. Inserting amdgpu.waitcnt instructions as needed
+    // 4. Handling different counter types (vmcnt, lgkmcnt, expcnt, vscnt)
+  }
+};
+
+} // namespace
diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
new file mode 100644
index 000000000..3d0a0bdf0
--- /dev/null
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -0,0 +1,38 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERLOWERMEMORYOPS
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Pass that lowers high-level memory operations to AMDGPU-specific
+/// memory operations (buffer loads/stores, LDS operations, etc.).
+class WaterLowerMemoryOpsPass
+    : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
+public:
+  void runOnOperation() override {
+    // TODO: Implement the pass logic
+    // This will involve:
+    // 1. Pattern matching on vector.load/store, memref operations
+    // 2. Lowering to amdgpu.raw_buffer_load/store
+    // 3. Lowering to amdgpu.lds_barrier and related ops
+    // 4.
Handling address space conversions + } +}; + +} // namespace diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir new file mode 100644 index 000000000..bae458f02 --- /dev/null +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -0,0 +1,40 @@ +// RUN: water-opt %s --water-insert-waitcnt | FileCheck %s + +// Smoke test to verify waitcnt insertion pass runs without crashing + +// CHECK-LABEL: func.func @simple_function +func.func @simple_function(%arg0: f32) -> f32 { + // CHECK: return %arg0 + return %arg0 : f32 +} + +// CHECK-LABEL: func.func @vector_load +func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store +func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return +} + +// CHECK-LABEL: func.func @load_store_sequence +func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { + // Test a simple load followed by a store + // TODO: Eventually this should insert waitcnt between load and store + + // CHECK: vector.load + %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir new file mode 100644 index 000000000..1df24265e --- /dev/null +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -0,0 +1,41 @@ +// RUN: water-opt %s --water-lower-memory-ops | FileCheck %s + +// Smoke test to verify memory ops lowering pass runs without crashing + +// CHECK-LABEL: func.func @simple_function +func.func @simple_function(%arg0: f32) -> f32 { + // CHECK: return %arg0 + return %arg0 : f32 +} + +// CHECK-LABEL: func.func @vector_load +func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // TODO: Eventually should lower to amdgpu.raw_buffer_load + // CHECK: vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @vector_store +func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { + // TODO: Eventually should lower to amdgpu.raw_buffer_store + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + // CHECK: return + return +} + +// CHECK-LABEL: func.func @load_store_sequence +func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { + // Test lowering of load/store sequence + + // CHECK: vector.load + %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: return + return +} From fbb84abc64aeb10166448db98b697f9981136e66 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 6 Dec 2025 16:10:47 +0100 Subject: [PATCH 002/114] read/write lowering Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 2 + water/lib/Transforms/WaterLowerMemoryOps.cpp | 150 +++++++++++++++++-- 
 water/test/Transforms/lower-memory-ops.mlir  | 25 +++-
 3 files changed, 161 insertions(+), 16 deletions(-)

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index 44c69924a..e213fa8e5 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -187,6 +187,8 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> {
   let dependentDialects = [
     "::mlir::amdgpu::AMDGPUDialect",
     "::mlir::LLVM::LLVMDialect",
+    "::mlir::memref::MemRefDialect",
+    "::mlir::vector::VectorDialect",
   ];
 }

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 3d0a0bdf0..2ca3a0ba2 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -7,9 +7,14 @@
 #include "water/Transforms/Passes.h"

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/Operation.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 using namespace mlir;

@@ -20,18 +25,145 @@ namespace mlir::water {

 namespace {

-/// Pass that lowers high-level memory operations to AMDGPU-specific
-/// memory operations (buffer loads/stores, LDS operations, etc.).
+/// Get the AMDGPU global_load instruction suffix based on bit width
+static StringRef getGlobalLoadSuffix(unsigned bitWidth) {
+  switch (bitWidth) {
+  case 32:
+    return "b32";
+  case 64:
+    return "b64";
+  case 96:
+    return "b96";
+  case 128:
+    return "b128";
+  default:
+    return "";
+  }
+}
+
+/// Get the AMDGPU global_store instruction suffix based on bit width
+static StringRef getGlobalStoreSuffix(unsigned bitWidth) {
+  return getGlobalLoadSuffix(bitWidth);
+}
+
+/// Pattern to lower vector.load to LLVM inline assembly (global_load_*)
+struct VectorLoadToInlineAsmPattern : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto vectorType = loadOp.getVectorType();
+    unsigned bitWidth =
+        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+    StringRef suffix = getGlobalLoadSuffix(bitWidth);
+    if (suffix.empty())
+      return failure();
+
+    Location loc = loadOp.getLoc();
+
+    // Build the inline assembly string: "global_load_b64 $0, $1, off"
+    std::string asmStr = ("global_load_" + suffix + " $0, $1, off").str();
+
+    // Constraints: "=v" for output (VGPR), "v" for input address (VGPR)
+    std::string constraints = "=v,v";
+
+    // Get the base pointer - need to convert memref to pointer
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i64Type = rewriter.getI64Type();
+
+    // Extract pointer as index, cast to i64, then to ptr
+    Value basePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
+        loc, loadOp.getBase());
+    basePtr = rewriter.create<arith::IndexCastOp>(loc, i64Type, basePtr);
+    basePtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, basePtr);
+
+    // Create the inline assembly operation
+    auto asmOp = rewriter.create<LLVM::InlineAsmOp>(
+        loc,
+        /*resultTypes=*/vectorType,
+        /*operands=*/ValueRange{basePtr},
+        /*asm_string=*/asmStr,
+        /*constraints=*/constraints,
+        /*has_side_effects=*/false,
+        /*is_align_stack=*/false,
+        /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None,
+        /*asm_dialect=*/LLVM::AsmDialectAttr{},
+        /*operand_attrs=*/ArrayAttr{});
+
+    rewriter.replaceOp(loadOp, asmOp.getResult(0));
+    return success();
+  }
+};
+
+/// Pattern to lower vector.store to LLVM inline assembly (global_store_*)
+struct VectorStoreToInlineAsmPattern
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
+    unsigned bitWidth =
+        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+    StringRef suffix = getGlobalStoreSuffix(bitWidth);
+    if (suffix.empty())
+      return failure();
+
+    Location loc = storeOp.getLoc();
+
+    // Build the inline assembly string: "global_store_b64 $0, $1, off"
+    std::string asmStr = ("global_store_" + suffix + " $0, $1, off").str();
+
+    // Constraints: "v" for address (VGPR), "v" for data (VGPR)
+    std::string constraints = "v,v";
+
+    // Get the base pointer
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto i64Type = rewriter.getI64Type();
+
+    // Extract pointer as index, cast to i64, then to ptr
+    Value basePtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
+        loc, storeOp.getBase());
+    basePtr = rewriter.create<arith::IndexCastOp>(loc, i64Type, basePtr);
+    basePtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, basePtr);
+
+    // Create the inline assembly operation (no result for store)
+    rewriter.create<LLVM::InlineAsmOp>(
+        loc,
+        /*resultTypes=*/TypeRange{},
+        /*operands=*/ValueRange{basePtr, storeOp.getValueToStore()},
+        /*asm_string=*/asmStr,
+        /*constraints=*/constraints,
+        /*has_side_effects=*/true,
+        /*is_align_stack=*/false,
+        /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None,
+        /*asm_dialect=*/LLVM::AsmDialectAttr{},
+        /*operand_attrs=*/ArrayAttr{});
+
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+};
+
+/// Pass that lowers high-level memory operations to LLVM inline assembly
+/// for AMDGPU global memory instructions.
 class WaterLowerMemoryOpsPass
     : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
 public:
   void runOnOperation() override {
-    // TODO: Implement the pass logic
-    // This will involve:
-    // 1. Pattern matching on vector.load/store, memref operations
-    // 2. Lowering to amdgpu.raw_buffer_load/store
-    // 3. Lowering to amdgpu.lds_barrier and related ops
-    // 4. Handling address space conversions
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    // Add patterns for lowering vector.load/store to inline assembly
+    patterns.add<VectorLoadToInlineAsmPattern, VectorStoreToInlineAsmPattern>(
+        context);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      signalPassFailure();
+    }
   }
 };

diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 1df24265e..61b350c13 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -1,6 +1,6 @@
 // RUN: water-opt %s --water-lower-memory-ops | FileCheck %s

-// Smoke test to verify memory ops lowering pass runs without crashing
+// Test lowering of vector memory operations to AMDGPU global_load/store inline assembly

 // CHECK-LABEL: func.func @simple_function
 func.func @simple_function(%arg0: f32) -> f32 {
@@ -10,8 +10,10 @@ func.func @simple_function(%arg0: f32) -> f32 {

 // CHECK-LABEL: func.func @vector_load
 func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {
-  // TODO: Eventually should lower to amdgpu.raw_buffer_load
-  // CHECK: vector.load
+  // CHECK: memref.extract_aligned_pointer_as_index
+  // CHECK: arith.index_cast
+  // CHECK: llvm.inttoptr
+  // CHECK: llvm.inline_asm "global_load_b128 $0, $1, off", "=v,v"
   %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
   // CHECK: return
   return %result : vector<4xf32>
@@ -19,21 +21,30 @@ func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {

 // CHECK-LABEL: func.func @vector_store
 func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) {
-  // TODO: Eventually should lower to amdgpu.raw_buffer_store
-  // CHECK: vector.store
+  // CHECK: memref.extract_aligned_pointer_as_index
+  // CHECK: arith.index_cast
+  // CHECK: llvm.inttoptr
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
   vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32>
   // CHECK: return
   return
 }

+// CHECK-LABEL: func.func @vector_load_2xf32
+func.func @vector_load_2xf32(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> {
+  // CHECK: llvm.inline_asm "global_load_b64 $0, $1, off", "=v,v"
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32>
+  return %result : vector<2xf32>
+}
+
 // CHECK-LABEL: func.func @load_store_sequence
 func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) {
   // Test lowering of load/store sequence

-  // CHECK: vector.load
+  // CHECK: llvm.inline_asm "global_load_b128 $0, $1, off", "=v,v"
   %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32>

-  // CHECK: vector.store
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
   vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32>

   // CHECK: return
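The mnemonic suffix picked by this patch is a pure function of the vector's total bit width. A minimal standalone sketch of that mapping (illustrative only; the names below are assumed, not taken from the series):

#include <cstdio>

// Mirrors the getGlobalLoadSuffix table; nullptr means the pattern declines.
static const char *sizeSuffix(unsigned bitWidth) {
  switch (bitWidth) {
  case 32:  return "b32";
  case 64:  return "b64";
  case 96:  return "b96";
  case 128: return "b128";
  default:  return nullptr;
  }
}

int main() {
  // vector<4xf32>: 4 elements x 32 bits = 128 bits.
  if (const char *s = sizeSuffix(4 * 32))
    std::printf("global_load_%s $0, $1, off\n", s); // -> global_load_b128 ...
  return 0;
}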
"mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -46,106 +44,97 @@ static StringRef getGlobalStoreSuffix(unsigned bitWidth) { return getGlobalLoadSuffix(bitWidth); } -/// Pattern to lower vector.load to LLVM inline assembly (global_load_*) -struct VectorLoadToInlineAsmPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::LoadOp loadOp, - PatternRewriter &rewriter) const override { - auto vectorType = loadOp.getVectorType(); - unsigned bitWidth = - vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); - - StringRef suffix = getGlobalLoadSuffix(bitWidth); - if (suffix.empty()) - return failure(); - - Location loc = loadOp.getLoc(); - - // Build the inline assembly string: "global_load_b64 $0, $1, off" - std::string asmStr = ("global_load_" + suffix + " $0, $1, off").str(); - - // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) - std::string constraints = "=v,v"; - - // Get the base pointer - need to convert memref to pointer - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); - - // Extract pointer as index, cast to i64, then to ptr - Value basePtr = rewriter.create( - loc, loadOp.getBase()); - basePtr = rewriter.create(loc, i64Type, basePtr); - basePtr = rewriter.create(loc, ptrType, basePtr); - - // Create the inline assembly operation - auto asmOp = rewriter.create( - loc, - /*resultTypes=*/vectorType, - /*operands=*/ValueRange{basePtr}, - /*asm_string=*/asmStr, - /*constraints=*/constraints, - /*has_side_effects=*/false, - /*is_align_stack=*/false, - /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, - /*asm_dialect=*/LLVM::AsmDialectAttr{}, - /*operand_attrs=*/ArrayAttr{}); - - rewriter.replaceOp(loadOp, asmOp.getResult(0)); - return success(); - } -}; +/// Lower vector.load to LLVM inline assembly (global_load_*) +static void lowerVectorLoad(vector::LoadOp loadOp, IRRewriter &rewriter) { + auto vectorType = loadOp.getVectorType(); + unsigned bitWidth = + vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); + + StringRef suffix = getGlobalLoadSuffix(bitWidth); + if (suffix.empty()) + return; + + Location loc = loadOp.getLoc(); + + // Build the inline assembly string: "global_load_b64 $0, $1, off" + std::string asmStr = ("global_load_" + suffix + " $0, $1, off").str(); + + // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) + std::string constraints = "=v,v"; + + // Get the base pointer - need to convert memref to pointer + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Type = rewriter.getI64Type(); + + rewriter.setInsertionPoint(loadOp); + + // Extract pointer as index, cast to i64, then to ptr + Value basePtr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, loadOp.getBase()); + basePtr = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtr); + basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr); + + // Create the inline assembly operation + auto asmOp = LLVM::InlineAsmOp::create( + rewriter, loc, + /*resultTypes=*/vectorType, + /*operands=*/ValueRange{basePtr}, + /*asm_string=*/asmStr, + /*constraints=*/constraints, + /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, + 
/*asm_dialect=*/LLVM::AsmDialectAttr{}, + /*operand_attrs=*/ArrayAttr{}); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); +} -/// Pattern to lower vector.store to LLVM inline assembly (global_store_*) -struct VectorStoreToInlineAsmPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::StoreOp storeOp, - PatternRewriter &rewriter) const override { - auto vectorType = cast(storeOp.getValueToStore().getType()); - unsigned bitWidth = - vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); - - StringRef suffix = getGlobalStoreSuffix(bitWidth); - if (suffix.empty()) - return failure(); - - Location loc = storeOp.getLoc(); - - // Build the inline assembly string: "global_store_b64 $0, $1, off" - std::string asmStr = ("global_store_" + suffix + " $0, $1, off").str(); - - // Constraints: "v" for address (VGPR), "v" for data (VGPR) - std::string constraints = "v,v"; - - // Get the base pointer - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); - - // Extract pointer as index, cast to i64, then to ptr - Value basePtr = rewriter.create( - loc, storeOp.getBase()); - basePtr = rewriter.create(loc, i64Type, basePtr); - basePtr = rewriter.create(loc, ptrType, basePtr); - - // Create the inline assembly operation (no result for store) - rewriter.create( - loc, - /*resultTypes=*/TypeRange{}, - /*operands=*/ValueRange{basePtr, storeOp.getValueToStore()}, - /*asm_string=*/asmStr, - /*constraints=*/constraints, - /*has_side_effects=*/true, - /*is_align_stack=*/false, - /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, - /*asm_dialect=*/LLVM::AsmDialectAttr{}, - /*operand_attrs=*/ArrayAttr{}); - - rewriter.eraseOp(storeOp); - return success(); - } -}; +/// Lower vector.store to LLVM inline assembly (global_store_*) +static void lowerVectorStore(vector::StoreOp storeOp, IRRewriter &rewriter) { + auto vectorType = cast(storeOp.getValueToStore().getType()); + unsigned bitWidth = + vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); + + StringRef suffix = getGlobalStoreSuffix(bitWidth); + if (suffix.empty()) + return; + + Location loc = storeOp.getLoc(); + + // Build the inline assembly string: "global_store_b64 $0, $1, off" + std::string asmStr = ("global_store_" + suffix + " $0, $1, off").str(); + + // Constraints: "v" for address (VGPR), "v" for data (VGPR) + std::string constraints = "v,v"; + + // Get the base pointer + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Type = rewriter.getI64Type(); + + rewriter.setInsertionPoint(storeOp); + + // Extract pointer as index, cast to i64, then to ptr + Value basePtr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, storeOp.getBase()); + basePtr = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtr); + basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr); + + // Create the inline assembly operation (no result for store) + LLVM::InlineAsmOp::create( + rewriter, loc, + /*resultTypes=*/TypeRange{}, + /*operands=*/ValueRange{basePtr, storeOp.getValueToStore()}, + /*asm_string=*/asmStr, + /*constraints=*/constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None, + /*asm_dialect=*/LLVM::AsmDialectAttr{}, + /*operand_attrs=*/ArrayAttr{}); + + rewriter.eraseOp(storeOp); +} /// Pass that lowers high-level memory operations to LLVM inline assembly /// for AMDGPU global memory 
 class WaterLowerMemoryOpsPass
     : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
 public:
   void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    RewritePatternSet patterns(context);
-
-    // Add patterns for lowering vector.load/store to inline assembly
-    patterns.add<VectorLoadToInlineAsmPattern, VectorStoreToInlineAsmPattern>(
-        context);
-
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
-      signalPassFailure();
-    }
+    IRRewriter rewriter(&getContext());
+
+    getOperation()->walk([&](Operation *op) {
+      if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+        lowerVectorLoad(loadOp, rewriter);
+      } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+        lowerVectorStore(storeOp, rewriter);
+      }
+    });
   }
 };

From df6e2e2275650802fb945d8c1876cf2c1af42cab Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sat, 6 Dec 2025 16:27:22 +0100
Subject: [PATCH 004/114] refac

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 53 ++++++++++++++++----
 1 file changed, 43 insertions(+), 10 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 1ed477a2b..e4a392ac1 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -45,14 +45,15 @@ static StringRef getGlobalStoreSuffix(unsigned bitWidth) {
 }

 /// Lower vector.load to LLVM inline assembly (global_load_*)
-static void lowerVectorLoad(vector::LoadOp loadOp, IRRewriter &rewriter) {
+static LogicalResult lowerVectorLoad(vector::LoadOp loadOp,
+                                     IRRewriter &rewriter) {
   auto vectorType = loadOp.getVectorType();
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();

   StringRef suffix = getGlobalLoadSuffix(bitWidth);
   if (suffix.empty())
-    return;
+    return loadOp.emitError("unsupported vector load bit width: ") << bitWidth;

   Location loc = loadOp.getLoc();

@@ -88,17 +89,20 @@ static void lowerVectorLoad(vector::LoadOp loadOp, IRRewriter &rewriter) {
   rewriter.replaceOp(loadOp, asmOp.getResult(0));
+  return success();
 }

 /// Lower vector.store to LLVM inline assembly (global_store_*)
-static void lowerVectorStore(vector::StoreOp storeOp, IRRewriter &rewriter) {
+static LogicalResult lowerVectorStore(vector::StoreOp storeOp,
+                                      IRRewriter &rewriter) {
   auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();

   StringRef suffix = getGlobalStoreSuffix(bitWidth);
   if (suffix.empty())
-    return;
+    return storeOp.emitError("unsupported vector store bit width: ")
+           << bitWidth;

   Location loc = storeOp.getLoc();

@@ -134,8 +138,30 @@ static void lowerVectorStore(vector::StoreOp storeOp, IRRewriter &rewriter) {
   rewriter.eraseOp(storeOp);
+  return success();
 }

+/// Wrapper functions for operation lowering
+static LogicalResult lowerLoadOp(Operation *op, IRRewriter &rewriter) {
+  return lowerVectorLoad(cast<vector::LoadOp>(op), rewriter);
+}
+
+static LogicalResult lowerStoreOp(Operation *op, IRRewriter &rewriter) {
+  return lowerVectorStore(cast<vector::StoreOp>(op), rewriter);
+}
+
+/// Operation lowering handler entry
+struct OpLoweringHandler {
+  TypeID typeID;
+  LogicalResult (*lowerFn)(Operation *, IRRewriter &);
+};
+
+/// Table of lowering handlers for different operation types
+static const OpLoweringHandler loweringHandlers[] = {
+    {TypeID::get<vector::LoadOp>(), lowerLoadOp},
+    {TypeID::get<vector::StoreOp>(), lowerStoreOp},
+};
+
 /// Pass that lowers high-level memory operations to LLVM inline assembly
 /// for AMDGPU global memory instructions.
 class WaterLowerMemoryOpsPass
     : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
@@ -144,13 +170,20 @@ class WaterLowerMemoryOpsPass
   void runOnOperation() override {
     IRRewriter rewriter(&getContext());

-    getOperation()->walk([&](Operation *op) {
-      if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
-        lowerVectorLoad(loadOp, rewriter);
-      } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
-        lowerVectorStore(storeOp, rewriter);
-      }
-    });
+    auto walkFn = [&](Operation *op) {
+      TypeID opTypeID = op->getName().getTypeID();
+      for (const auto &handler : loweringHandlers) {
+        if (handler.typeID == opTypeID) {
+          if (failed(handler.lowerFn(op, rewriter)))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        }
+      }
+      return WalkResult::advance();
+    };
+
+    if (getOperation()->walk(walkFn).wasInterrupted())
+      signalPassFailure();
   }
 };

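The handler table keys dispatch off the TypeID of the operation's registered name, so walkFn stays untouched as the lowering grows. A sketch of adding one more entry, assuming the declarations from this patch are in scope (the vector.transfer_read handler is hypothetical, not part of the series):

// Hypothetical handler for one more op; the signature must match
// OpLoweringHandler::lowerFn.
static LogicalResult lowerTransferReadOp(Operation *op, IRRewriter &rewriter);

static const OpLoweringHandler extendedHandlers[] = {
    {TypeID::get<vector::LoadOp>(), lowerLoadOp},
    {TypeID::get<vector::StoreOp>(), lowerStoreOp},
    // New entry: picked up by walkFn's linear scan with no other changes.
    {TypeID::get<vector::TransferReadOp>(), lowerTransferReadOp},
};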
From 35f563bbf67e71fc13924c9934dd7e233ea709fa Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 7 Dec 2025 21:45:36 +0100
Subject: [PATCH 005/114] cleanup

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 72 +++++++++-----------
 water/test/Transforms/lower-memory-ops.mlir  |  6 +-
 2 files changed, 35 insertions(+), 43 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index e4a392ac1..79be11051 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -23,25 +23,34 @@ namespace mlir::water {

 namespace {

-/// Get the AMDGPU global_load instruction suffix based on bit width
-static StringRef getGlobalLoadSuffix(unsigned bitWidth) {
+/// Get the AMDGPU instruction suffix based on bit width
+static FailureOr<StringRef> getSizeSuffix(unsigned bitWidth) {
   switch (bitWidth) {
   case 32:
-    return "b32";
+    return StringRef("b32");
   case 64:
-    return "b64";
+    return StringRef("b64");
   case 96:
-    return "b96";
+    return StringRef("b96");
   case 128:
-    return "b128";
+    return StringRef("b128");
   default:
-    return "";
+    return failure();
   }
 }

-/// Get the AMDGPU global_store instruction suffix based on bit width
-static StringRef getGlobalStoreSuffix(unsigned bitWidth) {
-  return getGlobalLoadSuffix(bitWidth);
+/// Create an LLVM inline assembly operation with standard attributes
+static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc,
+                                         TypeRange resultTypes,
+                                         ValueRange operands, StringRef asmStr,
+                                         StringRef constraints,
+                                         bool hasSideEffects) {
+  return LLVM::InlineAsmOp::create(
+      rewriter, loc, resultTypes, operands, asmStr, constraints,
+      hasSideEffects,
+      /*is_align_stack=*/false,
+      /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None,
+      /*asm_dialect=*/LLVM::AsmDialectAttr{},
+      /*operand_attrs=*/ArrayAttr{});
 }

 /// Lower vector.load to LLVM inline assembly (global_load_*)
@@ -51,17 +60,17 @@ static LogicalResult lowerVectorLoad(vector::LoadOp loadOp,
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();

-  StringRef suffix = getGlobalLoadSuffix(bitWidth);
-  if (suffix.empty())
+  FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
+  if (failed(suffix))
     return loadOp.emitError("unsupported vector load bit width: ") << bitWidth;

   Location loc = loadOp.getLoc();

   // Build the inline assembly string: "global_load_b64 $0, $1, off"
-  std::string asmStr = ("global_load_" + suffix + " $0, $1, off").str();
+  std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str();

   // Constraints: "=v" for output (VGPR), "v" for input address (VGPR)
-  std::string constraints = "=v,v";
+  StringRef constraints = "=v,v";

   // Get the base pointer - need to convert memref to pointer
   auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
@@ -76,17 +85,8 @@ static LogicalResult lowerVectorLoad(vector::LoadOp loadOp,
   basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr);

   // Create the inline assembly operation
-  auto asmOp = LLVM::InlineAsmOp::create(
-      rewriter, loc,
-      /*resultTypes=*/vectorType,
-      /*operands=*/ValueRange{basePtr},
-      /*asm_string=*/asmStr,
-      /*constraints=*/constraints,
-      /*has_side_effects=*/false,
-      /*is_align_stack=*/false,
-      /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None,
-      /*asm_dialect=*/LLVM::AsmDialectAttr{},
-      /*operand_attrs=*/ArrayAttr{});
+  auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{basePtr},
+                               asmStr, constraints, /*hasSideEffects=*/true);

   rewriter.replaceOp(loadOp, asmOp.getResult(0));
   return success();
@@ -99,18 +99,18 @@ static LogicalResult lowerVectorStore(vector::StoreOp storeOp,
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();

-  StringRef suffix = getGlobalStoreSuffix(bitWidth);
-  if (suffix.empty())
+  FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
+  if (failed(suffix))
     return storeOp.emitError("unsupported vector store bit width: ")
            << bitWidth;

   Location loc = storeOp.getLoc();

   // Build the inline assembly string: "global_store_b64 $0, $1, off"
-  std::string asmStr = ("global_store_" + suffix + " $0, $1, off").str();
+  std::string asmStr = ("global_store_" + *suffix + " $0, $1, off").str();

   // Constraints: "v" for address (VGPR), "v" for data (VGPR)
-  std::string constraints = "v,v";
+  StringRef constraints = "v,v";

   // Get the base pointer
   auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
@@ -125,17 +125,9 @@ static LogicalResult lowerVectorStore(vector::StoreOp storeOp,
   basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr);

   // Create the inline assembly operation (no result for store)
-  LLVM::InlineAsmOp::create(
-      rewriter, loc,
-      /*resultTypes=*/TypeRange{},
-      /*operands=*/ValueRange{basePtr, storeOp.getValueToStore()},
-      /*asm_string=*/asmStr,
-      /*constraints=*/constraints,
-      /*has_side_effects=*/true,
-      /*is_align_stack=*/false,
-      /*tail_call_kind=*/LLVM::tailcallkind::TailCallKind::None,
-      /*asm_dialect=*/LLVM::AsmDialectAttr{},
-      /*operand_attrs=*/ArrayAttr{});
+  createInlineAsm(rewriter, loc, TypeRange{},
+                  ValueRange{basePtr, storeOp.getValueToStore()}, asmStr,
+                  constraints, /*hasSideEffects=*/true);

   rewriter.eraseOp(storeOp);
   return success();
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 61b350c13..69f1617ed 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -13,7 +13,7 @@ func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {
   // CHECK: memref.extract_aligned_pointer_as_index
   // CHECK: arith.index_cast
   // CHECK: llvm.inttoptr
-  // CHECK: llvm.inline_asm "global_load_b128 $0, $1, off", "=v,v"
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
   %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
   // CHECK: return
   return %result : vector<4xf32>
@@ -32,7 +32,7 @@ func.func @vector_load_2xf32(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> {
-  // CHECK: llvm.inline_asm "global_load_b64 $0, $1, off", "=v,v"
+  // CHECK: llvm.inline_asm
has_side_effects "global_load_b64 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32> return %result : vector<2xf32> } @@ -41,7 +41,7 @@ func.func @vector_load_2xf32(%memref: memref<1024xf32>, %offset: index) -> vecto func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { // Test lowering of load/store sequence - // CHECK: llvm.inline_asm "global_load_b128 $0, $1, off", "=v,v" + // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" From 86de819c43fbbd7daa9992705391a8740104e7dd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 7 Dec 2025 21:54:00 +0100 Subject: [PATCH 006/114] indices support Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 76 ++++++++++++++------ 1 file changed, 56 insertions(+), 20 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 79be11051..72f3ae351 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -53,6 +53,52 @@ static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, /*operand_attrs=*/ArrayAttr{}); } +/// Compute the final address for a memref access with indices +static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Type = rewriter.getI64Type(); + + // Extract strided metadata to get base pointer, offset, sizes, and strides + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value basePtr = metadataOp.getBaseBuffer(); + Value offset = metadataOp.getOffset(); + + // Compute linear index from multidimensional indices + Value linearIndex = offset; + for (size_t i = 0; i < indices.size(); ++i) { + Value stride = metadataOp.getStrides()[i]; + Value indexTimesStride = arith::MulIOp::create( + rewriter, loc, indices[i], stride, arith::IntegerOverflowFlags::nsw); + linearIndex = + arith::AddIOp::create(rewriter, loc, linearIndex, indexTimesStride, + arith::IntegerOverflowFlags::nsw); + } + + // Convert base pointer to i64 + Value basePtrInt = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); + basePtrInt = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtrInt); + + // Convert linear index to i64 and scale by element size + unsigned elementBytes = elementBitWidth / 8; + Value elementSize = + arith::ConstantIndexOp::create(rewriter, loc, elementBytes); + Value byteOffset = + arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, + arith::IntegerOverflowFlags::nsw); + Value byteOffsetI64 = + arith::IndexCastOp::create(rewriter, loc, i64Type, byteOffset); + + // Add byte offset to base pointer + Value finalAddr = + arith::AddIOp::create(rewriter, loc, basePtrInt, byteOffsetI64, + arith::IntegerOverflowFlags::nsw); + return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr); +} + /// Lower vector.load to LLVM inline assembly (global_load_*) static LogicalResult lowerVectorLoad(vector::LoadOp loadOp, IRRewriter &rewriter) { @@ -72,20 +118,15 @@ static LogicalResult lowerVectorLoad(vector::LoadOp loadOp, // Constraints: "=v" for output (VGPR), "v" for input address (VGPR) StringRef constraints = 
"=v,v"; - // Get the base pointer - need to convert memref to pointer - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); - rewriter.setInsertionPoint(loadOp); - // Extract pointer as index, cast to i64, then to ptr - Value basePtr = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, loadOp.getBase()); - basePtr = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtr); - basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr); + // Compute the final address + Value addr = + computeMemrefAddress(rewriter, loc, loadOp.getBase(), loadOp.getIndices(), + vectorType.getElementTypeBitWidth()); // Create the inline assembly operation - auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{basePtr}, + auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{addr}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.replaceOp(loadOp, asmOp.getResult(0)); @@ -112,21 +153,16 @@ static LogicalResult lowerVectorStore(vector::StoreOp storeOp, // Constraints: "v" for address (VGPR), "v" for data (VGPR) StringRef constraints = "v,v"; - // Get the base pointer - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); - rewriter.setInsertionPoint(storeOp); - // Extract pointer as index, cast to i64, then to ptr - Value basePtr = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, storeOp.getBase()); - basePtr = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtr); - basePtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, basePtr); + // Compute the final address + Value addr = computeMemrefAddress(rewriter, loc, storeOp.getBase(), + storeOp.getIndices(), + vectorType.getElementTypeBitWidth()); // Create the inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{basePtr, storeOp.getValueToStore()}, asmStr, + ValueRange{addr, storeOp.getValueToStore()}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); From 4f8b7cc05fc141d5d98f1098308728dbe4cee550 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 7 Dec 2025 21:57:48 +0100 Subject: [PATCH 007/114] seq Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 72f3ae351..a2d2a8cf8 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -68,7 +68,7 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, // Compute linear index from multidimensional indices Value linearIndex = offset; - for (size_t i = 0; i < indices.size(); ++i) { + for (auto i : llvm::seq(0, indices.size())) { Value stride = metadataOp.getStrides()[i]; Value indexTimesStride = arith::MulIOp::create( rewriter, loc, indices[i], stride, arith::IntegerOverflowFlags::nsw); From 971c2f13fe377e05e5216be03bb319ae48898943 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 7 Dec 2025 22:28:00 +0100 Subject: [PATCH 008/114] waitcnt pass wip Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 128 +++++++++++++++++++- 1 file changed, 122 insertions(+), 6 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index df650bf24..e1322dcd7 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ 
From 971c2f13e377e05e5216be03bb319ae48898943 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 7 Dec 2025 22:28:00 +0100
Subject: [PATCH 008/114] waitcnt pass wip

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 128 +++++++++++++++++++-
 1 file changed, 122 insertions(+), 6 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index df650bf24..e1322dcd7 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -6,11 +6,15 @@

 #include "water/Transforms/Passes.h"

+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallVector.h"

 using namespace mlir;
+using namespace mlir::dataflow;

 namespace mlir::water {
 #define GEN_PASS_DEF_WATERINSERTWAITCNT
@@ -19,18 +23,130 @@ namespace mlir::water {

 namespace {

+/// Shared pending operations list for structural sharing
+struct PendingOperations {
+  SmallVector<Operation *> ops;
+};
+
+/// Lattice state tracking pending asynchronous operations
+class WaitcntState : public AbstractDenseLattice {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WaitcntState)
+
+  using AbstractDenseLattice::AbstractDenseLattice;
+
+  ChangeResult join(const AbstractDenseLattice &rhs) override {
+    const auto &rhsState = static_cast<const WaitcntState &>(rhs);
+
+    // If rhs has no pending ops, no change needed
+    if (rhsState.isEmpty())
+      return ChangeResult::NoChange;
+
+    // If we have no pending ops, just share rhs's state
+    if (isEmpty()) {
+      pendingOps = rhsState.pendingOps;
+      return ChangeResult::Change;
+    }
+
+    // Conservative union: merge all pending operations
+    bool changed = false;
+    for (Operation *op : rhsState.getPendingOps()) {
+      if (!contains(op)) {
+        addPendingOp(op);
+        changed = true;
+      }
+    }
+
+    return changed ? ChangeResult::Change : ChangeResult::NoChange;
+  }
+
+  void print(raw_ostream &os) const override {
+    os << "WaitcntState: " << size() << " pending ops";
+  }
+
+  /// Add a pending operation (copy-on-write)
+  void addPendingOp(Operation *op) {
+    auto newPending = std::make_shared<PendingOperations>();
+    if (pendingOps)
+      newPending->ops = pendingOps->ops;
+    newPending->ops.push_back(op);
+    pendingOps = newPending;
+  }
+
+  /// Get pending operations (read-only)
+  ArrayRef<Operation *> getPendingOps() const {
+    return pendingOps ? ArrayRef<Operation *>(pendingOps->ops)
+                      : ArrayRef<Operation *>();
+  }
+
+  /// Check if this operation is in the pending list
+  bool contains(Operation *op) const {
+    return pendingOps && llvm::is_contained(pendingOps->ops, op);
+  }
+
+  /// Check if empty
+  bool isEmpty() const { return !pendingOps || pendingOps->ops.empty(); }
+
+  /// Get size
+  size_t size() const { return pendingOps ? pendingOps->ops.size() : 0; }
+
+  /// Initialize to empty state
+  void clear() { pendingOps = std::make_shared<PendingOperations>(); }
+
+private:
+  /// Pending asynchronous operations (vector loads/stores)
+  std::shared_ptr<PendingOperations> pendingOps;
+};
+
+/// Dense forward dataflow analysis for waitcnt insertion
+class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
+public:
+  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
+
+  void setToEntryState(WaitcntState *lattice) override {
+    // Entry state: empty pending operations
+    lattice->clear();
+  }
+
+  LogicalResult visitOperation(Operation *op, const WaitcntState &before,
+                               WaitcntState *after) override {
+    // Start with the state before this operation
+    *after = before;
+
+    // Check if this is an async memory operation (vector load/store)
+    if (isa<vector::LoadOp, vector::StoreOp>(op)) {
+      // Add this operation to the pending list
+      after->addPendingOp(op);
+    }
+
+    // Check if this operation uses a value produced by a pending operation
+    // If so, we need to insert a waitcnt before this operation
+    // (We'll handle actual insertion in a separate pass over the IR)
+
+    return success();
+  }
+};
+
 /// Pass that inserts wait/synchronization instructions for asynchronous
 /// memory operations. This is analogous to LLVM's SIInsertWaitcnts pass.
 class WaterInsertWaitcntPass
     : public water::impl::WaterInsertWaitcntBase<WaterInsertWaitcntPass> {
 public:
   void runOnOperation() override {
-    // TODO: Implement the pass logic
-    // This will involve:
-    // 1. Tracking asynchronous memory operations
-    // 2. Maintaining scoreboards for dependencies (similar to WaitcntBrackets)
-    // 3. Inserting amdgpu.waitcnt instructions as needed
-    // 4. Handling different counter types (vmcnt, lgkmcnt, expcnt, vscnt)
+    Operation *op = getOperation();
+
+    // Set up the dataflow solver
+    DataFlowSolver solver;
+    solver.load<WaitcntAnalysis>();
+
+    // Run the analysis
+    if (failed(solver.initializeAndRun(op))) {
+      signalPassFailure();
+      return;
+    }
+
+    // TODO: Use the analysis results to insert waitcnt instructions
+    // For now, just run the analysis to test the framework
   }
 };

From 015d87f6bfda17aacf2e99187ce73ac09b5df085 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 7 Dec 2025 22:43:33 +0100
Subject: [PATCH 009/114] reqs

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 88 +++++++++++++++++----
 1 file changed, 72 insertions(+), 16 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index e1322dcd7..a3e16768e 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -28,6 +28,45 @@ struct PendingOperations {
   SmallVector<Operation *> ops;
 };

+/// Waitcnt requirement for synchronization
+struct WaitcntRequirement {
+  std::optional<unsigned> vmcnt;   // Vector memory operations counter
+  std::optional<unsigned> lgkmcnt; // LDS/GDS operations counter
+  std::optional<unsigned> vscnt;   // Vector store operations counter
+
+  bool hasRequirement() const {
+    return vmcnt.has_value() || lgkmcnt.has_value() || vscnt.has_value();
+  }
+
+  /// Merge with another requirement (take minimum for conservative join)
+  /// Returns true if this requirement changed
+  bool merge(const WaitcntRequirement &other) {
+    bool changed = false;
+
+    // Take minimum of each counter (lower value = more restrictive)
+    if (other.vmcnt.has_value()) {
+      if (!vmcnt.has_value() || *other.vmcnt < *vmcnt) {
+        vmcnt = other.vmcnt;
+        changed = true;
+      }
+    }
+    if (other.lgkmcnt.has_value()) {
+      if (!lgkmcnt.has_value() || *other.lgkmcnt < *lgkmcnt) {
+        lgkmcnt = other.lgkmcnt;
+        changed = true;
+      }
+    }
+    if (other.vscnt.has_value()) {
+      if (!vscnt.has_value() || *other.vscnt < *vscnt) {
+        vscnt = other.vscnt;
+        changed = true;
+      }
+    }
+
+    return changed;
+  }
+};
+
 /// Lattice state tracking pending asynchronous operations
 class WaitcntState : public AbstractDenseLattice {
 public:
@@ -37,26 +76,28 @@ class WaitcntState : public AbstractDenseLattice {

   ChangeResult join(const AbstractDenseLattice &rhs) override {
     const auto &rhsState = static_cast<const WaitcntState &>(rhs);
-
-    // If rhs has no pending ops, no change needed
-    if (rhsState.isEmpty())
-      return ChangeResult::NoChange;
-
-    // If we have no pending ops, just share rhs's state
-    if (isEmpty()) {
-      pendingOps = rhsState.pendingOps;
-      return ChangeResult::Change;
-    }
-
-    // Conservative union: merge all pending operations
     bool changed = false;
-    for (Operation *op : rhsState.getPendingOps()) {
-      if (!contains(op)) {
-        addPendingOp(op);
+
+    // Merge pending operations
+    if (!rhsState.isEmpty()) {
+      if (isEmpty()) {
+        pendingOps = rhsState.pendingOps;
         changed = true;
+      } else {
+        // Conservative union: merge all pending operations
+        for (Operation *op : rhsState.getPendingOps()) {
+          if (!contains(op)) {
+            addPendingOp(op);
+            changed = true;
+          }
+        }
       }
     }

+    // Merge requirements (take minimum for conservative join)
+    if (requirement.merge(rhsState.requirement))
+      changed = true;
+
     return changed ? ChangeResult::Change : ChangeResult::NoChange;
   }

@@ -91,11 +132,26 @@ class WaitcntState : public AbstractDenseLattice {
   size_t size() const { return pendingOps ? pendingOps->ops.size() : 0; }

   /// Initialize to empty state
-  void clear() { pendingOps = std::make_shared<PendingOperations>(); }
+  void clear() {
+    pendingOps = std::make_shared<PendingOperations>();
+    requirement = WaitcntRequirement();
+  }
+
+  /// Set the required waitcnt values
+  void setRequirement(const WaitcntRequirement &req) { requirement = req; }
+
+  /// Get the required waitcnt values
+  const WaitcntRequirement &getRequirement() const { return requirement; }
+
+  /// Check if there's a waitcnt requirement
+  bool hasRequirement() const { return requirement.hasRequirement(); }

 private:
   /// Pending asynchronous operations (vector loads/stores)
   std::shared_ptr<PendingOperations> pendingOps;
+
+  /// Required waitcnt before this state (for inserting actual waitcnt ops)
+  WaitcntRequirement requirement;
 };

From 1fdeec299ea5e3da809978b8695a3a03cdfa6409 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 7 Dec 2025 22:54:52 +0100
Subject: [PATCH 010/114] reqs

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 46 +++++++++++++++++++--
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index a3e16768e..b397caadd 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -146,6 +146,32 @@ class WaitcntState : public AbstractDenseLattice {
   /// Check if there's a waitcnt requirement
   bool hasRequirement() const { return requirement.hasRequirement(); }

+  /// Check if a value depends on pending operations and compute required wait
+  std::optional<WaitcntRequirement> checkRequirement(Value val) const {
+    if (!pendingOps || pendingOps->ops.empty())
+      return std::nullopt;
+
+    // Check if val is produced by any pending operation
+    Operation *defOp = val.getDefiningOp();
+    if (!defOp)
+      return std::nullopt;
+
+    // Find the operation in the pending list
+    auto it = llvm::find(pendingOps->ops, defOp);
+    if (it == pendingOps->ops.end())
+      return std::nullopt;
+
+    // Compute distance from the end of the list
+    size_t distanceFromEnd = std::distance(it, pendingOps->ops.end()) - 1;
+
+    WaitcntRequirement req;
+    // For vector loads/stores, use vmcnt
+    if (isa<vector::LoadOp, vector::StoreOp>(defOp))
+      req.vmcnt = distanceFromEnd;
+
+    return req;
+  }
+
 private:
   /// Pending asynchronous operations (vector loads/stores)
   std::shared_ptr<PendingOperations> pendingOps;
@@ -169,16 +195,28 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     // Start with the state before this operation
     *after = before;

+    // Always reset requirement - it should not propagate from previous state
+    after->setRequirement(WaitcntRequirement());
+
+    // Check if any operands depend on pending operations
+    WaitcntRequirement opRequirement;
+    for (Value operand : op->getOperands()) {
+      if (auto req = before.checkRequirement(operand)) {
+        // Merge this requirement (take minimum for conservative wait)
+        opRequirement.merge(*req);
+      }
+    }
+
+    // Set the requirement for this operation
+    if (opRequirement.hasRequirement())
+      after->setRequirement(opRequirement);
+
     // Check if this is an async memory operation (vector load/store)
     if (isa<vector::LoadOp, vector::StoreOp>(op)) {
       // Add this operation to the pending list
       after->addPendingOp(op);
     }

-    // Check if this operation uses a value produced by a pending operation
-    // If so, we need to insert a waitcnt before this operation
-    // (We'll handle actual insertion in a separate pass over the IR)
-
     return success();
   }
 };
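The distance computed in checkRequirement is the number of pending operations issued after the producer, which is exactly the counter value the hardware must drain to before the use. A self-contained sketch of that computation, assuming in-order completion of vector memory operations (the property vmcnt-style counters rely on):

#include <cassert>
#include <iterator>
#include <vector>

int main() {
  // Pending list after three loads: [a, b, c]; the next op uses a's result.
  std::vector<char> pendingOps = {'a', 'b', 'c'};
  auto it = pendingOps.begin(); // defining op of the used value

  // Same formula as checkRequirement: ops issued after the producer.
  size_t distanceFromEnd = std::distance(it, pendingOps.end()) - 1;
  assert(distanceFromEnd == 2);

  // Waiting until at most 2 ops remain in flight lets b and c stay
  // outstanding while guaranteeing a has completed.
  return 0;
}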
if this operation uses a value produced by a pending operation - // If so, we need to insert a waitcnt before this operation - // (We'll handle actual insertion in a separate pass over the IR) - return success(); } }; From 83b50f59aeb78fecf9f720c23a5fd9f3876db5cc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 7 Dec 2025 23:09:34 +0100 Subject: [PATCH 011/114] state Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 38 +++++++++++++++------ 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index b397caadd..7fb7c07e0 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -32,10 +32,9 @@ struct PendingOperations { struct WaitcntRequirement { std::optional vmcnt; // Vector memory operations counter std::optional lgkmcnt; // LDS/GDS operations counter - std::optional vscnt; // Vector store operations counter bool hasRequirement() const { - return vmcnt.has_value() || lgkmcnt.has_value() || vscnt.has_value(); + return vmcnt.has_value() || lgkmcnt.has_value(); } /// Merge with another requirement (take minimum for conservative join) @@ -56,15 +55,13 @@ struct WaitcntRequirement { changed = true; } } - if (other.vscnt.has_value()) { - if (!vscnt.has_value() || *other.vscnt < *vscnt) { - vscnt = other.vscnt; - changed = true; - } - } return changed; } + + std::optional getLoadCnt() const { return vmcnt; } + std::optional getStoreCnt() const { return vmcnt; } + std::optional getDsCnt() const { return lgkmcnt; } }; /// Lattice state tracking pending asynchronous operations @@ -239,8 +236,29 @@ class WaterInsertWaitcntPass return; } - // TODO: Use the analysis results to insert waitcnt instructions - // For now, just run the analysis to test the framework + // Insert waitcnt operations based on analysis results + IRRewriter rewriter(&getContext()); + op->walk([&](Operation *operation) { + // Query the state before this operation + const WaitcntState *state = solver.lookupState( + solver.getProgramPointAfter(operation)); + if (!state || !state->hasRequirement()) + return; + + const WaitcntRequirement &req = state->getRequirement(); + + auto getAttr = [&](std::optional cnt) -> IntegerAttr { + if (!cnt.has_value()) + return nullptr; + return rewriter.getI32IntegerAttr(*cnt); + }; + + // Insert wait operation before this operation + rewriter.setInsertionPoint(operation); + amdgpu::MemoryCounterWaitOp::create( + rewriter, operation->getLoc(), getAttr(req.getLoadCnt()), + getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr); + }); } }; From cecc76c820e8ec462299acdfc2ffed84cc1d5311 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 00:04:23 +0100 Subject: [PATCH 012/114] fix test Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 31 ++++++++++------- water/test/Transforms/insert-waitcnt.mlir | 38 +++------------------ 2 files changed, 24 insertions(+), 45 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 7fb7c07e0..71ce648d9 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -7,11 +7,15 @@ #include "water/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" +#include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include 
"mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "water-insert-waitcnt" using namespace mlir; using namespace mlir::dataflow; @@ -129,9 +133,12 @@ class WaitcntState : public AbstractDenseLattice { size_t size() const { return pendingOps ? pendingOps->ops.size() : 0; } /// Initialize to empty state - void clear() { - pendingOps = std::make_shared(); - requirement = WaitcntRequirement(); + ChangeResult reset() { + if (!pendingOps) + return ChangeResult::NoChange; + + pendingOps.reset(); + return ChangeResult::Change; } /// Set the required waitcnt values @@ -183,17 +190,15 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; void setToEntryState(WaitcntState *lattice) override { - // Entry state: empty pending operations - lattice->clear(); + propagateIfChanged(lattice, lattice->reset()); } LogicalResult visitOperation(Operation *op, const WaitcntState &before, WaitcntState *after) override { - // Start with the state before this operation - *after = before; + LDBG() << "Visiting: " << *op; - // Always reset requirement - it should not propagate from previous state - after->setRequirement(WaitcntRequirement()); + // Start with the state before this operation + WaitcntState newState = before; // Check if any operands depend on pending operations WaitcntRequirement opRequirement; @@ -206,14 +211,15 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Set the requirement for this operation if (opRequirement.hasRequirement()) - after->setRequirement(opRequirement); + newState.setRequirement(opRequirement); // Check if this is an async memory operation (vector load/store) if (isa(op)) { // Add this operation to the pending list - after->addPendingOp(op); + newState.addPendingOp(op); } + propagateIfChanged(after, after->join(newState)); return success(); } }; @@ -228,6 +234,7 @@ class WaterInsertWaitcntPass // Set up the dataflow solver DataFlowSolver solver; + loadBaselineAnalyses(solver); solver.load(); // Run the analysis @@ -239,7 +246,7 @@ class WaterInsertWaitcntPass // Insert waitcnt operations based on analysis results IRRewriter rewriter(&getContext()); op->walk([&](Operation *operation) { - // Query the state before this operation + // Query the state after this operation const WaitcntState *state = solver.lookupState( solver.getProgramPointAfter(operation)); if (!state || !state->hasRequirement()) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index bae458f02..ed1bcfa0c 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -1,40 +1,12 @@ // RUN: water-opt %s --water-insert-waitcnt | FileCheck %s -// Smoke test to verify waitcnt insertion pass runs without crashing +// Test waitcnt insertion for vector memory operations -// CHECK-LABEL: func.func @simple_function -func.func @simple_function(%arg0: f32) -> f32 { - // CHECK: return %arg0 - return %arg0 : f32 -} - -// CHECK-LABEL: func.func @vector_load -func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { +// CHECK-LABEL: func.func @single_load_use +func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { // CHECK: vector.load %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: return + // CHECK: amdgpu.memory_counter_wait load(0) store(0) + // CHECK-NEXT: 
return return %result : vector<4xf32> } - -// CHECK-LABEL: func.func @vector_store -func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { - // CHECK: vector.store - vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: return - return -} - -// CHECK-LABEL: func.func @load_store_sequence -func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { - // Test a simple load followed by a store - // TODO: Eventually this should insert waitcnt between load and store - - // CHECK: vector.load - %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: vector.store - vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> - - // CHECK: return - return -} From bd348a52da6ad9c0f813d7b912ddef90f65a33d0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 01:37:41 +0100 Subject: [PATCH 013/114] fixes Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 108 +++++++++++++++++--- water/test/Transforms/insert-waitcnt.mlir | 22 ++++ 2 files changed, 114 insertions(+), 16 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 71ce648d9..105c01693 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -29,6 +29,8 @@ namespace { /// Shared pending operations list for structural sharing struct PendingOperations { + PendingOperations() = default; + PendingOperations(SmallVector &&ops) : ops(std::move(ops)) {} SmallVector ops; }; @@ -66,8 +68,44 @@ struct WaitcntRequirement { std::optional getLoadCnt() const { return vmcnt; } std::optional getStoreCnt() const { return vmcnt; } std::optional getDsCnt() const { return lgkmcnt; } + + bool isSameCounterType(const WaitcntRequirement &other) const { + return vmcnt.has_value() == other.vmcnt.has_value() || + lgkmcnt.has_value() == other.lgkmcnt.has_value(); + } + + static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) { + WaitcntRequirement req; + if (isa(op)) + req.vmcnt = zero ? 
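
The move to `propagateIfChanged(after, after->join(newState))` in this patch is load-bearing for termination: the dense dataflow driver only re-queues dependents when a lattice reports Change, so a join must report Change exactly when the state actually grew. A minimal contract sketch on a toy set lattice (editor's illustration in plain C++, not the MLIR API):

    #include <cassert>
    #include <set>

    enum class ChangeResult { NoChange, Change };

    // Toy join: reporting Change spuriously can make the fixpoint loop spin
    // forever; reporting NoChange when something changed drops updates.
    struct SetLattice {
      std::set<int> elems;
      ChangeResult join(const SetLattice &rhs) {
        auto before = elems.size();
        elems.insert(rhs.elems.begin(), rhs.elems.end());
        return elems.size() == before ? ChangeResult::NoChange
                                      : ChangeResult::Change;
      }
    };

    int main() {
      SetLattice a{{1, 2}}, b{{2, 3}};
      assert(a.join(b) == ChangeResult::Change);   // grew to {1,2,3}
      assert(a.join(b) == ChangeResult::NoChange); // second join is a no-op
    }
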

From bd348a52da6ad9c0f813d7b912ddef90f65a33d0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 01:37:41 +0100
Subject: [PATCH 013/114] fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 108 +++++++++++++++++---
 water/test/Transforms/insert-waitcnt.mlir   |  22 ++++
 2 files changed, 114 insertions(+), 16 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 71ce648d9..105c01693 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -29,6 +29,8 @@ namespace {
 
 /// Shared pending operations list for structural sharing
 struct PendingOperations {
+  PendingOperations() = default;
+  PendingOperations(SmallVector<Operation *> &&ops) : ops(std::move(ops)) {}
   SmallVector<Operation *> ops;
 };
@@ -66,8 +68,44 @@ struct WaitcntRequirement {
   std::optional<unsigned> getLoadCnt() const { return vmcnt; }
   std::optional<unsigned> getStoreCnt() const { return vmcnt; }
   std::optional<unsigned> getDsCnt() const { return lgkmcnt; }
+
+  bool isSameCounterType(const WaitcntRequirement &other) const {
+    return vmcnt.has_value() == other.vmcnt.has_value() ||
+           lgkmcnt.has_value() == other.lgkmcnt.has_value();
+  }
+
+  static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) {
+    WaitcntRequirement req;
+    if (isa<vector::LoadOp, vector::StoreOp>(op))
+      req.vmcnt = zero ? 0 : 1;
+    return req;
+  }
+
+  WaitcntRequirement operator+(const WaitcntRequirement &other) const {
+    WaitcntRequirement result;
+    if (vmcnt || other.vmcnt)
+      result.vmcnt = vmcnt.value_or(0) + other.vmcnt.value_or(0);
+    if (lgkmcnt || other.lgkmcnt)
+      result.lgkmcnt = lgkmcnt.value_or(0) + other.lgkmcnt.value_or(0);
+    return result;
+  }
+
+  bool operator>(const WaitcntRequirement &other) const {
+    return vmcnt.value_or(0) > other.vmcnt.value_or(0) ||
+           lgkmcnt.value_or(0) > other.lgkmcnt.value_or(0);
+  }
+
+  void print(raw_ostream &os) const {
+    os << "WaitcntRequirement: vmcnt=" << vmcnt << " lgkmcnt=" << lgkmcnt;
+  }
 };
+
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const WaitcntRequirement &result) {
+  result.print(os);
+  return os;
+}
@@ -129,16 +141,14 @@ class WaitcntState : public AbstractDenseLattice {
   }
 
   void print(raw_ostream &os) const override {
-    os << "WaitcntState: " << size() << " pending ops";
+    os << "WaitcntState: " << size()
+       << " pending ops, requirement: " << requirement;
   }
 
   /// Add a pending operation (copy-on-write)
   void addPendingOp(Operation *op) {
-    auto newPending = std::make_shared<PendingOperations>();
-    if (pendingOps)
-      newPending->ops = pendingOps->ops;
-    newPending->ops.push_back(op);
-    pendingOps = newPending;
+    cow();
+    pendingOps->ops.push_back(op);
   }
 
   /// Get pending operations (read-only)
@@ -142,7 +178,27 @@ class WaitcntState : public AbstractDenseLattice {
   }
 
   /// Set the required waitcnt values
-  void setRequirement(const WaitcntRequirement &req) { requirement = req; }
+  void setRequirement(const WaitcntRequirement &req) {
+    requirement = req;
+    SmallVector<Operation *> newPending;
+    WaitcntRequirement runningRequirement;
+    for (Operation *op : llvm::reverse(pendingOps->ops)) {
+      WaitcntRequirement opReq =
+          WaitcntRequirement::getOperationRequirement(op, false);
+      runningRequirement = runningRequirement + opReq;
+      if (runningRequirement > requirement)
+        continue;
+
+      newPending.push_back(op);
+    }
+    if (newPending.size() == pendingOps->ops.size())
+      return;
+
+    std::reverse(newPending.begin(), newPending.end());
+    pendingOps = std::make_shared<PendingOperations>(std::move(newPending));
+  }
+
+  void resetRequirement() { requirement = {}; }
 
   /// Get the required waitcnt values
   const WaitcntRequirement &getRequirement() const { return requirement; }
@@ -162,16 +218,18 @@ class WaitcntState : public AbstractDenseLattice {
 
     // Find the operation in the pending list
     auto it = llvm::find(pendingOps->ops, defOp);
-    if (it == pendingOps->ops.end())
+    auto end = pendingOps->ops.end();
+    if (it == end)
       return std::nullopt;
 
-    // Compute distance from the end of the list
-    size_t distanceFromEnd = std::distance(it, pendingOps->ops.end()) - 1;
+    auto req = WaitcntRequirement::getOperationRequirement(defOp, true);
+    for (Operation *op : llvm::make_range(std::next(it), end)) {
+      auto opReq = WaitcntRequirement::getOperationRequirement(op, false);
+      if (!req.isSameCounterType(opReq))
+        continue;
 
-    WaitcntRequirement req;
-    // For vector loads/stores, use vmcnt
-    if (isa<vector::LoadOp, vector::StoreOp>(defOp))
-      req.vmcnt = distanceFromEnd;
+      req = req + opReq;
+    }
 
     return req;
   }
@@ -182,6 +240,16 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Required waitcnt before this state (for inserting actual waitcnt ops)
   WaitcntRequirement requirement;
+
+  void cow() {
+    if (!pendingOps || pendingOps.use_count() > 1) {
+      auto newPending = std::make_shared<PendingOperations>();
+      if (pendingOps)
+        newPending->ops = pendingOps->ops;
+
+      pendingOps = newPending;
+    }
+  }
 };
 
 /// Dense forward dataflow analysis for waitcnt insertion
@@ -278,7 +346,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     }
 
     // Set the requirement for this operation
-    if (opRequirement.hasRequirement())
+    if (opRequirement.hasRequirement()) {
       newState.setRequirement(opRequirement);
+      LDBG() << "Operation requirement: " << opRequirement;
+    } else {
+      newState.resetRequirement();
+      LDBG() << "No operation requirement";
+    }
 
     // Check if this is an async memory operation (vector load/store)
-    if (isa<vector::LoadOp, vector::StoreOp>(op)) {
+    if (WaitcntRequirement::getOperationRequirement(op, false)
+            .hasRequirement()) {
       // Add this operation to the pending list
       newState.addPendingOp(op);
     }
 
+    LDBG() << "New state: " << newState;
     propagateIfChanged(after, after->join(newState));
     return success();
   }
@@ -305,6 +373,7 @@ class WaterInsertWaitcntPass
 public:
   void runOnOperation() override {
+    LDBG() << "Running WaterInsertWaitcntPass";
     Operation *op = getOperation();
 
     // Set up the dataflow solver
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index ed1bcfa0c..109ad6b78 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -10,3 +10,25 @@ func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<
   // CHECK-NEXT: return
   return %result : vector<4xf32>
 }
+
+// CHECK-LABEL: func.func @two_loads_use_in_reverse_order
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index)
+func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> {
+  // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]]
+  // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]]
+  %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32>
+  %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait load(1) store(1)
+  // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]]
+  %addA = arith.addf %loadA, %loadA : vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait load(0) store(0)
+  // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]]
+  %addB = arith.addf %loadB, %addA : vector<4xf32>
+
+  // CHECK-NOT: amdgpu.memory_counter_wait
+
+  // CHECK: return %[[ADD_B]]
+  return %addB : vector<4xf32>
+}

From b7a04426b66a917b76c0688d74fe8589ed73254b Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 01:55:32 +0100
Subject: [PATCH 014/114] fix

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 105c01693..2161732f6 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -170,10 +170,11 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Initialize to empty state
   ChangeResult reset() {
-    if (!pendingOps)
+    if (!pendingOps && !requirement.hasRequirement())
      return ChangeResult::NoChange;
 
     pendingOps.reset();
+    requirement = {};
     return ChangeResult::Change;
   }
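
The trimming loop added to setRequirement doubles as scoreboard pruning: once a wait of N is owed, every pending op older than the N newest ops of that counter class is known to have retired and can be dropped from the list. A standalone sketch of that rule on plain integers (editor's illustration, single shared counter assumed):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Keep only the ops that may still be outstanding after waiting until at
    // most `maxOutstanding` remain; pending ops are ordered oldest -> newest
    // and each contributes 1, mirroring the running-count walk above.
    std::vector<int> prunePending(const std::vector<int> &pending,
                                  unsigned maxOutstanding) {
      std::vector<int> kept;
      unsigned running = 0;
      for (auto it = pending.rbegin(); it != pending.rend(); ++it) {
        ++running;                  // this op plus all younger ones
        if (running > maxOutstanding)
          continue;                 // older than the wait allows: retired
        kept.push_back(*it);
      }
      std::reverse(kept.begin(), kept.end());
      return kept;
    }

    int main() {
      // Three pending loads; a wait of load(1) retires all but the newest.
      assert(prunePending({10, 11, 12}, 1) == std::vector<int>({12}));
      // A wait of load(0) drains the queue completely.
      assert(prunePending({10, 11, 12}, 0).empty());
    }
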

From 0b08cf362c64c14c44cb63dffa7d85750357fea0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 01:58:18 +0100
Subject: [PATCH 015/114] rename cnt

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 44 ++++++++++-----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 2161732f6..4c9800bbc 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -36,11 +36,11 @@ struct PendingOperations {
 
 /// Waitcnt requirement for synchronization
 struct WaitcntRequirement {
-  std::optional<unsigned> vmcnt;   // Vector memory operations counter
-  std::optional<unsigned> lgkmcnt; // LDS/GDS operations counter
+  std::optional<unsigned> load_cnt;
+  std::optional<unsigned> ds_cnt;
 
   bool hasRequirement() const {
-    return vmcnt.has_value() || lgkmcnt.has_value();
+    return load_cnt.has_value() || ds_cnt.has_value();
   }
 
   /// Merge with another requirement (take minimum for conservative join)
@@ -49,15 +49,15 @@ struct WaitcntRequirement {
     bool changed = false;
 
     // Take minimum of each counter (lower value = more restrictive)
-    if (other.vmcnt.has_value()) {
-      if (!vmcnt.has_value() || *other.vmcnt < *vmcnt) {
-        vmcnt = other.vmcnt;
+    if (other.load_cnt.has_value()) {
+      if (!load_cnt.has_value() || *other.load_cnt < *load_cnt) {
+        load_cnt = other.load_cnt;
         changed = true;
       }
     }
-    if (other.lgkmcnt.has_value()) {
-      if (!lgkmcnt.has_value() || *other.lgkmcnt < *lgkmcnt) {
-        lgkmcnt = other.lgkmcnt;
+    if (other.ds_cnt.has_value()) {
+      if (!ds_cnt.has_value() || *other.ds_cnt < *ds_cnt) {
+        ds_cnt = other.ds_cnt;
         changed = true;
       }
     }
@@ -65,38 +65,38 @@ struct WaitcntRequirement {
     return changed;
   }
 
-  std::optional<unsigned> getLoadCnt() const { return vmcnt; }
-  std::optional<unsigned> getStoreCnt() const { return vmcnt; }
-  std::optional<unsigned> getDsCnt() const { return lgkmcnt; }
+  std::optional<unsigned> getLoadCnt() const { return load_cnt; }
+  std::optional<unsigned> getStoreCnt() const { return load_cnt; }
+  std::optional<unsigned> getDsCnt() const { return ds_cnt; }
 
   bool isSameCounterType(const WaitcntRequirement &other) const {
-    return vmcnt.has_value() == other.vmcnt.has_value() ||
-           lgkmcnt.has_value() == other.lgkmcnt.has_value();
+    return load_cnt.has_value() == other.load_cnt.has_value() ||
+           ds_cnt.has_value() == other.ds_cnt.has_value();
   }
 
   static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) {
     WaitcntRequirement req;
     if (isa<vector::LoadOp, vector::StoreOp>(op))
-      req.vmcnt = zero ? 0 : 1;
+      req.load_cnt = zero ? 0 : 1;
     return req;
   }
 
   WaitcntRequirement operator+(const WaitcntRequirement &other) const {
     WaitcntRequirement result;
-    if (vmcnt || other.vmcnt)
-      result.vmcnt = vmcnt.value_or(0) + other.vmcnt.value_or(0);
-    if (lgkmcnt || other.lgkmcnt)
-      result.lgkmcnt = lgkmcnt.value_or(0) + other.lgkmcnt.value_or(0);
+    if (load_cnt || other.load_cnt)
+      result.load_cnt = load_cnt.value_or(0) + other.load_cnt.value_or(0);
+    if (ds_cnt || other.ds_cnt)
+      result.ds_cnt = ds_cnt.value_or(0) + other.ds_cnt.value_or(0);
     return result;
   }
 
   bool operator>(const WaitcntRequirement &other) const {
-    return vmcnt.value_or(0) > other.vmcnt.value_or(0) ||
-           lgkmcnt.value_or(0) > other.lgkmcnt.value_or(0);
+    return load_cnt.value_or(0) > other.load_cnt.value_or(0) ||
+           ds_cnt.value_or(0) > other.ds_cnt.value_or(0);
   }
 
   void print(raw_ostream &os) const {
-    os << "WaitcntRequirement: vmcnt=" << vmcnt << " lgkmcnt=" << lgkmcnt;
+    os << "WaitcntRequirement: load_cnt=" << load_cnt << " ds_cnt=" << ds_cnt;
   }
 };

From a6bd9eb1abeaa3e70c5b2401dcd0a06022ff6b77 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 02:00:35 +0100
Subject: [PATCH 016/114] fix

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 2 +-
 water/test/Transforms/insert-waitcnt.mlir   | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 4c9800bbc..114b9d807 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -66,7 +66,7 @@ struct WaitcntRequirement {
   }
 
   std::optional<unsigned> getLoadCnt() const { return load_cnt; }
-  std::optional<unsigned> getStoreCnt() const { return load_cnt; }
+  std::optional<unsigned> getStoreCnt() const { return std::nullopt; }
   std::optional<unsigned> getDsCnt() const { return ds_cnt; }
 
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 109ad6b78..b30e1f1ae 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -6,7 +6,7 @@
 func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {
   // CHECK: vector.load
   %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
-  // CHECK: amdgpu.memory_counter_wait load(0) store(0)
+  // CHECK: amdgpu.memory_counter_wait load(0)
   // CHECK-NEXT: return
   return %result : vector<4xf32>
 }
@@ -19,11 +19,11 @@ func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB:
   %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32>
   %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32>
 
-  // CHECK: amdgpu.memory_counter_wait load(1) store(1)
+  // CHECK: amdgpu.memory_counter_wait load(1)
   // CHECK-NEXT: %[[ADD_A:.*]] = arith.addf %[[LOAD_A]], %[[LOAD_A]]
   %addA = arith.addf %loadA, %loadA : vector<4xf32>
 
-  // CHECK: amdgpu.memory_counter_wait load(0) store(0)
+  // CHECK: amdgpu.memory_counter_wait load(0)
   // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]]
   %addB = arith.addf %loadB, %addA : vector<4xf32>

From debaf1b0c99feebef24e40999e8ab001f44b3a4b Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 11:31:19 +0100
Subject: [PATCH 017/114] addr space

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 39 +++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 114b9d807..96af227e6 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
 #include "mlir/Analysis/DataFlow/Utils.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
@@ -27,6 +28,34 @@ namespace mlir::water {
 
 namespace {
 
+static std::optional<Value> isLoadOp(Operation *op) {
+  if (auto load = dyn_cast<vector::LoadOp>(op))
+    return load.getBase();
+
+  return std::nullopt;
+}
+
+static std::optional<Value> isStoreOp(Operation *op) {
+  if (auto store = dyn_cast<vector::StoreOp>(op))
+    return store.getBase();
+
+  return std::nullopt;
+}
+
+static std::optional<Value> isLoadOrStoreOp(Operation *op) {
+  if (auto load = isLoadOp(op))
+    return load;
+  if (auto store = isStoreOp(op))
+    return store;
+
+  return std::nullopt;
+}
+
+static bool isWorkgroupAddressSpace(MemRefType type) {
+  auto attr = dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
+  return attr && attr.getValue() == gpu::AddressSpace::Workgroup;
+}
+
 /// Shared pending operations list for structural sharing
 struct PendingOperations {
   PendingOperations() = default;
@@ -76,8 +105,14 @@ struct WaitcntRequirement {
 
   static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) {
     WaitcntRequirement req;
-    if (isa<vector::LoadOp, vector::StoreOp>(op))
-      req.load_cnt = zero ? 0 : 1;
+    if (std::optional<Value> base = isLoadOrStoreOp(op)) {
+      auto memrefType = cast<MemRefType>(base->getType());
+      if (isWorkgroupAddressSpace(memrefType)) {
+        req.ds_cnt = zero ? 0 : 1;
+      } else {
+        req.load_cnt = zero ? 0 : 1;
+      }
+    }
     return req;
   }

From 3a9ea1dc0b74dfcd1addaa9a62a8b0a1feed12b6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 12:00:27 +0100
Subject: [PATCH 018/114] mem deps

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 76 ++++++++++++++++++---
 water/test/Transforms/insert-waitcnt.mlir   | 48 +++++++++++++
 2 files changed, 116 insertions(+), 8 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 96af227e6..2508aafe8 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -252,14 +252,14 @@ class WaitcntState : public AbstractDenseLattice {
     if (!defOp)
       return std::nullopt;
 
-    // Find the operation in the pending list
-    auto it = llvm::find(pendingOps->ops, defOp);
-    auto end = pendingOps->ops.end();
-    if (it == end)
-      return std::nullopt;
-
+    // Search from the back to find the most recent dependency
     auto req = WaitcntRequirement::getOperationRequirement(defOp, true);
-    for (Operation *op : llvm::make_range(std::next(it), end)) {
+    bool found = false;
+    for (Operation *op : llvm::reverse(pendingOps->ops)) {
+      if (op == defOp) {
+        found = true;
+        break;
+      }
       auto opReq = WaitcntRequirement::getOperationRequirement(op, false);
       if (!req.isSameCounterType(opReq))
         continue;
 
       req = req + opReq;
     }
 
+    if (!found)
+      return std::nullopt;
+
     return req;
   }
 
+  /// Check for memory dependencies (RAW, WAR, WAW)
+  std::optional<WaitcntRequirement> checkMemoryDependency(Operation *op) const {
+    if (!pendingOps || pendingOps->ops.empty())
+      return std::nullopt;
+
+    // Check if this is a load or store operation
+    std::optional<Value> currentBase = isLoadOrStoreOp(op);
+    if (!currentBase)
+      return std::nullopt;
+
+    bool isCurrentLoad = isLoadOp(op).has_value();
+    bool isCurrentStore = isStoreOp(op).has_value();
+
+    // Search from the back to find the most recent dependency
+    for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) {
+      std::optional<Value> pendingBase = isLoadOrStoreOp(pendingOp);
+      if (!pendingBase)
+        continue;
+
+      // Conservative aliasing check: same base memref means potential alias
+      if (*currentBase != *pendingBase)
+        continue;
+
+      bool isPendingLoad = isLoadOp(pendingOp).has_value();
+      bool isPendingStore = isStoreOp(pendingOp).has_value();
+
+      // Check for dependencies:
+      // RAW: current load after pending store
+      // WAR: current store after pending load
+      // WAW: current store after pending store
+      bool hasRAW = isCurrentLoad && isPendingStore;
+      bool hasWAR = isCurrentStore && isPendingLoad;
+      bool hasWAW = isCurrentStore && isPendingStore;
+
+      if (hasRAW || hasWAR || hasWAW) {
+        // Found dependency - compute requirement by counting forward from here
+        auto it = llvm::find(pendingOps->ops, pendingOp);
+        auto req = WaitcntRequirement::getOperationRequirement(pendingOp, true);
+        for (Operation *countOp :
+             llvm::make_range(std::next(it), pendingOps->ops.end())) {
+          auto opReq =
+              WaitcntRequirement::getOperationRequirement(countOp, false);
+          if (!req.isSameCounterType(opReq))
+            continue;
+          req = req + opReq;
+        }
+        return req;
+      }
+    }
+
+    return std::nullopt;
+  }
+
 private:
   /// Pending asynchronous operations (vector loads/stores)
   std::shared_ptr<PendingOperations> pendingOps;
@@ -362,7 +418,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     // Start with the state before this operation
     WaitcntState newState = before;
 
-    // Check if any operands depend on pending operations
+    // Check if any operands depend on pending operations (value dependency)
     WaitcntRequirement opRequirement;
     for (Value operand : op->getOperands()) {
       if (auto req = before.checkRequirement(operand)) {
        // Merge this requirement (take minimum for conservative wait)
        opRequirement.merge(*req);
      }
    }
 
+    // Check for memory dependencies (RAW, WAR, WAW)
+    if (auto memReq = before.checkMemoryDependency(op))
+      opRequirement.merge(*memReq);
+
     // Set the requirement for this operation
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index b30e1f1ae..43bcb41ff 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -32,3 +32,51 @@ func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB:
   // CHECK: return %[[ADD_B]]
   return %addB : vector<4xf32>
 }
+
+// CHECK-LABEL: func.func @raw_dependency
+// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index)
+func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  // Store to memory
+  // CHECK: vector.store %[[DATA]], %[[MEM]]
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Load from same memory - RAW dependency, must wait for store
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]]
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: return %[[LOAD]]
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @war_dependency
+// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index)
+func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  // Load from memory
+  // CHECK: %[[LOAD:.*]] = vector.load %[[MEM]]
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Store to same memory - WAR dependency, must wait for load
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: vector.store %[[DATA]], %[[MEM]]
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: return %[[LOAD]]
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @waw_dependency
+// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA1:.*]]: vector<4xf32>, %[[DATA2:.*]]: vector<4xf32>, %{{.*}}: index)
+func.func @waw_dependency(%memref: memref<1024xf32>, %data1: vector<4xf32>, %data2: vector<4xf32>, %offset: index) {
+  // First store
+  // CHECK: vector.store %[[DATA1]], %[[MEM]]
+  vector.store %data1, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Second store to same memory - WAW dependency, must wait for first store
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: vector.store %[[DATA2]], %[[MEM]]
+  vector.store %data2, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: return
+  return
+}
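
The hazard classification in checkMemoryDependency boils down to a two-bit truth table over the access kinds: a wait is owed whenever at least one side of the (pending, current) pair writes, and read-after-read is the only safe combination. A standalone sketch (editor's illustration, not the pass code):

    #include <cassert>

    // RAW, WAR and WAW all require a wait; equivalently, a wait is needed
    // iff pendingIsStore || currentIsStore.
    bool needsWait(bool pendingIsStore, bool currentIsStore) {
      bool raw = !currentIsStore && pendingIsStore; // read after write
      bool war = currentIsStore && !pendingIsStore; // write after read
      bool waw = currentIsStore && pendingIsStore;  // write after write
      return raw || war || waw;
    }

    int main() {
      assert(needsWait(true, false));   // RAW
      assert(needsWait(false, true));   // WAR
      assert(needsWait(true, true));    // WAW
      assert(!needsWait(false, false)); // RAR needs no wait
    }
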

From eb89f51286717f27570ca9c5ec202b5e9fe620ef Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 12:09:00 +0100
Subject: [PATCH 019/114] AliasAnalysis

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 22 ++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 2508aafe8..e93f2e65c 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -6,6 +6,7 @@
 
 #include "water/Transforms/Passes.h"
 
+#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
 #include "mlir/Analysis/DataFlow/Utils.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -274,7 +275,8 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Check for memory dependencies (RAW, WAR, WAW)
-  std::optional<WaitcntRequirement> checkMemoryDependency(Operation *op) const {
+  std::optional<WaitcntRequirement>
+  checkMemoryDependency(Operation *op, AliasAnalysis &aliasAnalysis) const {
     if (!pendingOps || pendingOps->ops.empty())
       return std::nullopt;
 
@@ -292,8 +294,7 @@ class WaitcntState : public AbstractDenseLattice {
       if (!pendingBase)
         continue;
 
-      // Conservative aliasing check: same base memref means potential alias
-      if (*currentBase != *pendingBase)
+      if (aliasAnalysis.alias(*currentBase, *pendingBase).isNo())
         continue;
 
       bool isPendingLoad = isLoadOp(pendingOp).has_value();
@@ -347,7 +348,8 @@ class WaitcntState : public AbstractDenseLattice {
 /// Dense forward dataflow analysis for waitcnt insertion
 class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
 public:
-  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
+  WaitcntAnalysis(DataFlowSolver &solver, AliasAnalysis &aliasAnalysis)
+      : DenseForwardDataFlowAnalysis(solver), aliasAnalysis(aliasAnalysis) {}
 
   void setToEntryState(WaitcntState *lattice) override {
     propagateIfChanged(lattice, lattice->reset());
@@ -370,7 +372,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     }
 
     // Check for memory dependencies (RAW, WAR, WAW)
-    if (auto memReq = before.checkMemoryDependency(op))
+    if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis))
       opRequirement.merge(*memReq);
 
     // Set the requirement for this operation
@@ -393,6 +395,9 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     propagateIfChanged(after, after->join(newState));
     return success();
   }
+
+private:
+  AliasAnalysis &aliasAnalysis;
 };
 
 /// Pass that inserts wait/synchronization instructions for asynchronous
@@ -404,12 +409,12 @@ class WaterInsertWaitcntPass
     LDBG() << "Running WaterInsertWaitcntPass";
     Operation *op = getOperation();
 
-    // Set up the dataflow solver
+    auto &aliasAnalysis = getAnalysis<AliasAnalysis>();
+
     DataFlowSolver solver;
     loadBaselineAnalyses(solver);
-    solver.load<WaitcntAnalysis>();
+    solver.load<WaitcntAnalysis>(aliasAnalysis);
 
-    // Run the analysis
     if (failed(solver.initializeAndRun(op))) {
       signalPassFailure();
       return;
@@ -418,7 +423,6 @@ class WaterInsertWaitcntPass
     // Insert waitcnt operations based on analysis results
     IRRewriter rewriter(&getContext());
     op->walk([&](Operation *operation) {
-      // Query the state after this operation
       const WaitcntState *state = solver.lookupState<WaitcntState>(
          solver.getProgramPointAfter(operation));
       if (!state || !state->hasRequirement())

From 1d879bb8c4a08e8575bcc1c43a4d263c1ff8cee8 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 12:13:51 +0100
Subject: [PATCH 020/114] test

Signed-off-by: Ivan Butygin
---
 water/test/Transforms/insert-waitcnt.mlir | 25 +++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 43bcb41ff..5a342303a 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -80,3 +80,28 @@ func.func @waw_dependency(%memref: memref<1024xf32>, %data1: vector<4xf32>, %dat
   // CHECK: return
   return
 }
+
+// CHECK-LABEL: func.func @raw_dependency_non_zero_waitcnt
+func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  // Allocate two distinct memrefs to guarantee no aliasing
+  // CHECK: %[[MEM_A:.*]] = memref.alloc()
+  %memrefA = memref.alloc() : memref<1024xf32>
+  // CHECK: %[[MEM_B:.*]] = memref.alloc()
+  %memrefB = memref.alloc() : memref<1024xf32>
+
+  // Store to memory A
+  // CHECK: vector.store %{{.*}}, %[[MEM_A]]
+  vector.store %data, %memrefA[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Store to memory B (intervening operation, different memref)
+  // CHECK: vector.store %{{.*}}, %[[MEM_B]]
+  vector.store %data, %memrefB[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Load from memory A - RAW dependency with store to A at distance 1
+  // CHECK: amdgpu.memory_counter_wait load(1)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM_A]]
+  %result = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: return %[[LOAD]]
+  return %result : vector<4xf32>
+}

From 6c076124c75409e8c15329328d07c17b41f714d3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 12:22:17 +0100
Subject: [PATCH 021/114] update print

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index e93f2e65c..392af6466 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -177,8 +177,12 @@ class WaitcntState : public AbstractDenseLattice {
   }
 
   void print(raw_ostream &os) const override {
-    os << "WaitcntState: " << size()
-       << " pending ops, requirement: " << requirement;
+    os << "WaitcntState: " << size() << " pending ops [";
+    if (pendingOps) {
+      llvm::interleaveComma(pendingOps->ops, os,
+                            [&](Operation *op) { os << *op; });
+    }
+    os << "], requirement: " << requirement;
   }
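
Note the asymmetry the AliasAnalysis patch introduces: a dependency is skipped only on a definite no-alias answer, while may-alias and must-alias both fall through to the conservative path. A standalone sketch of that rule (editor's illustration; the enum mirrors MLIR's AliasResult kinds but is hypothetical here):

    #include <cassert>

    enum class Alias { No, May, Must };

    // Only a proven "no alias" licenses skipping the hazard check; an
    // unknown answer must be treated exactly like a proven overlap.
    bool mustConsiderDependency(Alias result) { return result != Alias::No; }

    int main() {
      assert(!mustConsiderDependency(Alias::No));
      assert(mustConsiderDependency(Alias::May)); // unknown => conservative
      assert(mustConsiderDependency(Alias::Must));
    }
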

From 2e6fd0af1755eb198f9e402be3df34018e68300a Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 13:27:02 +0100
Subject: [PATCH 022/114] ops lists

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 247 +++++++++++---------
 1 file changed, 137 insertions(+), 110 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 392af6466..cd49c0e7e 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -61,6 +61,19 @@ static bool isWorkgroupAddressSpace(MemRefType type) {
 struct PendingOperations {
   PendingOperations() = default;
   PendingOperations(SmallVector<Operation *> &&ops) : ops(std::move(ops)) {}
+
+  size_t size() const { return ops.size(); }
+  bool empty() const { return ops.empty(); }
+
+  bool hasSameTail(const PendingOperations &other) const {
+    for (const auto &[op1, op2] :
+         llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops))) {
+      if (op1 != op2)
+        return false;
+    }
+    return true;
+  }
+
   SmallVector<Operation *> ops;
 };
@@ -153,20 +166,30 @@ class WaitcntState : public AbstractDenseLattice {
     const auto &rhsState = static_cast<const WaitcntState &>(rhs);
     bool changed = false;
 
-    // Merge pending operations
-    if (!rhsState.isEmpty()) {
-      if (isEmpty()) {
-        pendingOps = rhsState.pendingOps;
-        changed = true;
-      } else {
-        // Conservative union: merge all pending operations
-        for (Operation *op : rhsState.getPendingOps()) {
-          if (!contains(op)) {
-            addPendingOp(op);
+    SmallVector<std::shared_ptr<PendingOperations>, 4> toAppend;
+    // Check if any pending operations list has the same tail of operations as
+    // the rhs list and keep the longer one.
+    for (auto &rhsPendingOps : rhsState.pendingOpsLists) {
+      bool found = false;
+      for (auto &pendingOps : pendingOpsLists) {
+        if (pendingOps->hasSameTail(*rhsPendingOps)) {
+          if (rhsPendingOps->size() > pendingOps->size()) {
+            pendingOps = rhsPendingOps;
             changed = true;
           }
+          found = true;
+          break;
         }
       }
+      if (!found)
+        toAppend.push_back(rhsPendingOps);
+    }
+
+    // Append any rhs lists that do not share a tail with an existing list.
+    if (!toAppend.empty()) {
+      pendingOpsLists.append(toAppend);
+      changed = true;
     }
 
     // Merge requirements (take minimum for conservative join)
@@ -177,12 +200,16 @@ class WaitcntState : public AbstractDenseLattice {
   }
 
   void print(raw_ostream &os) const override {
-    os << "WaitcntState: " << size() << " pending ops [";
-    if (pendingOps) {
-      llvm::interleaveComma(pendingOps->ops, os,
-                            [&](Operation *op) { os << *op; });
+    os << "WaitcntState: pending ops [";
+    for (auto &pendingOps : pendingOpsLists) {
+      os << "[";
+      llvm::interleaveComma(pendingOps->ops, os,
+                            [&](Operation *op) { os << *op; });
+      os << "]";
     }
     os << "], requirement: " << requirement;
   }
 
-  /// Add a pending operation (copy-on-write)
   void addPendingOp(Operation *op) {
-    cow();
-    pendingOps->ops.push_back(op);
-  }
-
-  /// Get pending operations (read-only)
-  ArrayRef<Operation *> getPendingOps() const {
-    return pendingOps ? ArrayRef<Operation *>(pendingOps->ops)
-                      : ArrayRef<Operation *>();
-  }
-
-  /// Check if this operation is in the pending list
-  bool contains(Operation *op) const {
-    return pendingOps && llvm::is_contained(pendingOps->ops, op);
+    if (pendingOpsLists.empty()) {
+      pendingOpsLists.push_back(std::make_shared<PendingOperations>());
+    } else {
+      cow();
+    }
+    for (auto &pendingOps : pendingOpsLists)
+      pendingOps->ops.push_back(op);
   }
 
   /// Initialize to empty state
   ChangeResult reset() {
-    if (!pendingOps && !requirement.hasRequirement())
+    if (pendingOpsLists.empty() && !requirement.hasRequirement())
       return ChangeResult::NoChange;
 
-    pendingOps.reset();
+    pendingOpsLists.clear();
     requirement = {};
     return ChangeResult::Change;
   }
@@ -233,24 +260,26 @@ class WaitcntState : public AbstractDenseLattice {
   /// Set the required waitcnt values
   void setRequirement(const WaitcntRequirement &req) {
     requirement = req;
-    SmallVector<Operation *> newPending;
-    WaitcntRequirement runningRequirement;
-    for (Operation *op : llvm::reverse(pendingOps->ops)) {
-      WaitcntRequirement opReq =
-          WaitcntRequirement::getOperationRequirement(op, false);
-      runningRequirement = runningRequirement + opReq;
-      if (runningRequirement > requirement)
-        continue;
+    for (auto &pendingOps : pendingOpsLists) {
+      SmallVector<Operation *> newPending;
+      WaitcntRequirement runningRequirement;
+      for (Operation *op : llvm::reverse(pendingOps->ops)) {
+        WaitcntRequirement opReq =
+            WaitcntRequirement::getOperationRequirement(op, false);
+        runningRequirement = runningRequirement + opReq;
+        if (runningRequirement > requirement)
+          continue;
+
+        newPending.push_back(op);
+      }
+      if (newPending.size() == pendingOps->ops.size())
+        return;
 
-      newPending.push_back(op);
+      std::reverse(newPending.begin(), newPending.end());
+      pendingOps = std::make_shared<PendingOperations>(std::move(newPending));
     }
-    if (newPending.size() == pendingOps->ops.size())
-      return;
-
-    std::reverse(newPending.begin(), newPending.end());
-    pendingOps = std::make_shared<PendingOperations>(std::move(newPending));
   }
 
   void resetRequirement() { requirement = {}; }
@@ -263,41 +292,41 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Check if a value depends on pending operations and compute required wait
   std::optional<WaitcntRequirement> checkRequirement(Value val) const {
-    if (!pendingOps || pendingOps->ops.empty())
-      return std::nullopt;
-
     // Check if val is produced by any pending operation
     Operation *defOp = val.getDefiningOp();
     if (!defOp)
       return std::nullopt;
 
-    // Search from the back to find the most recent dependency
-    auto req = WaitcntRequirement::getOperationRequirement(defOp, true);
-    bool found = false;
-    for (Operation *op : llvm::reverse(pendingOps->ops)) {
-      if (op == defOp) {
-        found = true;
-        break;
-      }
-      auto opReq = WaitcntRequirement::getOperationRequirement(op, false);
-      if (!req.isSameCounterType(opReq))
-        continue;
+    WaitcntRequirement result;
+    for (auto &pendingOps : pendingOpsLists) {
+      if (pendingOps->empty())
+        continue;
 
-      req = req + opReq;
+      // Search from the back to find the most recent dependency
+      auto req = WaitcntRequirement::getOperationRequirement(defOp, true);
+      for (Operation *op : llvm::reverse(pendingOps->ops)) {
+        if (op == defOp)
+          break;
+
+        auto opReq = WaitcntRequirement::getOperationRequirement(op, false);
+        if (!req.isSameCounterType(opReq))
+          continue;
+
+        req = req + opReq;
+      }
+
+      result.merge(req);
     }
 
-    if (!found)
+    if (!result.hasRequirement())
       return std::nullopt;
 
-    return req;
+    return result;
   }
 
   /// Check for memory dependencies (RAW, WAR, WAW)
   std::optional<WaitcntRequirement>
   checkMemoryDependency(Operation *op, AliasAnalysis &aliasAnalysis) const {
-    if (!pendingOps || pendingOps->ops.empty())
-      return std::nullopt;
-
     // Check if this is a load or store operation
     std::optional<Value> currentBase = isLoadOrStoreOp(op);
     if (!currentBase)
@@ -306,59 +335,71 @@ class WaitcntState : public AbstractDenseLattice {
     bool isCurrentLoad = isLoadOp(op).has_value();
     bool isCurrentStore = isStoreOp(op).has_value();
 
-    // Search from the back to find the most recent dependency
-    for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) {
-      std::optional<Value> pendingBase = isLoadOrStoreOp(pendingOp);
-      if (!pendingBase)
-        continue;
-
-      if (aliasAnalysis.alias(*currentBase, *pendingBase).isNo())
+    WaitcntRequirement result;
+    for (auto &pendingOps : pendingOpsLists) {
+      if (pendingOps->empty())
         continue;
 
-      bool isPendingLoad = isLoadOp(pendingOp).has_value();
-      bool isPendingStore = isStoreOp(pendingOp).has_value();
-
-      // Check for dependencies:
-      // RAW: current load after pending store
-      // WAR: current store after pending load
-      // WAW: current store after pending store
-      bool hasRAW = isCurrentLoad && isPendingStore;
-      bool hasWAR = isCurrentStore && isPendingLoad;
-      bool hasWAW = isCurrentStore && isPendingStore;
-
-      if (hasRAW || hasWAR || hasWAW) {
-        // Found dependency - compute requirement by counting forward from here
-        auto it = llvm::find(pendingOps->ops, pendingOp);
-        auto req = WaitcntRequirement::getOperationRequirement(pendingOp, true);
-        for (Operation *countOp :
-             llvm::make_range(std::next(it), pendingOps->ops.end())) {
-          auto opReq =
-              WaitcntRequirement::getOperationRequirement(countOp, false);
-          if (!req.isSameCounterType(opReq))
-            continue;
-          req = req + opReq;
+      // Search from the back to find the most recent dependency
+      for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) {
+        std::optional<Value> pendingBase = isLoadOrStoreOp(pendingOp);
+        if (!pendingBase)
+          continue;
+
+        if (aliasAnalysis.alias(*currentBase, *pendingBase).isNo())
+          continue;
+
+        bool isPendingLoad = isLoadOp(pendingOp).has_value();
+        bool isPendingStore = isStoreOp(pendingOp).has_value();
+
+        // Check for dependencies:
+        // RAW: current load after pending store
+        // WAR: current store after pending load
+        // WAW: current store after pending store
+        bool hasRAW = isCurrentLoad && isPendingStore;
+        bool hasWAR = isCurrentStore && isPendingLoad;
+        bool hasWAW = isCurrentStore && isPendingStore;
+
+        if (hasRAW || hasWAR || hasWAW) {
+          // Found dependency - compute requirement by counting forward from
+          // here
+          auto it = llvm::find(pendingOps->ops, pendingOp);
+          auto req =
+              WaitcntRequirement::getOperationRequirement(pendingOp, true);
+          for (Operation *countOp :
+               llvm::make_range(std::next(it), pendingOps->ops.end())) {
+            auto opReq =
+                WaitcntRequirement::getOperationRequirement(countOp, false);
+            if (!req.isSameCounterType(opReq))
+              continue;
+            req = req + opReq;
+          }
+          result.merge(req);
         }
-        return req;
       }
     }
 
-    return std::nullopt;
+    if (!result.hasRequirement())
+      return std::nullopt;
+
+    return result;
   }
 
 private:
-  /// Pending asynchronous operations (vector loads/stores)
-  std::shared_ptr<PendingOperations> pendingOps;
+  /// Pending asynchronous operations
+  SmallVector<std::shared_ptr<PendingOperations>, 4> pendingOpsLists;
 
-  /// Required waitcnt before this state (for inserting actual waitcnt ops)
+  /// Required waitcnt after this state
   WaitcntRequirement requirement;
 
   void cow() {
-    if (!pendingOps || pendingOps.use_count() > 1) {
-      auto newPending = std::make_shared<PendingOperations>();
-      if (pendingOps)
-        newPending->ops = pendingOps->ops;
-
-      pendingOps = newPending;
+    for (auto &pendingOps : pendingOpsLists) {
+      if (!pendingOps || pendingOps.use_count() > 1) {
+        auto newPending = std::make_shared<PendingOperations>();
+        if (pendingOps)
+          newPending->ops = pendingOps->ops;
+        pendingOps = newPending;
+      }
     }
   }
 };
@@ -388,6 +429,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
   LogicalResult visitOperation(Operation *op, const WaitcntState &before,
                                WaitcntState *after) override {
     LDBG() << "Visiting: " << *op;
+    LDBG() << "  Before: " << before;
 
     // Start with the state before this operation
     WaitcntState newState = before;
@@ -409,10 +451,10 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     // Set the requirement for this operation
     if (opRequirement.hasRequirement()) {
       newState.setRequirement(opRequirement);
-      LDBG() << "Operation requirement: " << opRequirement;
+      LDBG() << "  Operation requirement: " << opRequirement;
     } else {
       newState.resetRequirement();
-      LDBG() << "No operation requirement";
+      LDBG() << "  No operation requirement";
     }
 
     // Check if this is an async memory operation (vector load/store)
@@ -422,7 +464,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
       newState.addPendingOp(op);
     }
 
-    LDBG() << "New state: " << newState;
+    LDBG() << "  New state: " << newState;
     propagateIfChanged(after, after->join(newState));
     return success();
   }
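
The tail-based join is what keeps the later control-flow tests honest: at a merge point, two per-path histories either share a suffix (one path simply issued more ops, so the longer list subsumes the shorter) or they genuinely diverge and both must be tracked, with their requirements merged by minimum afterwards. A standalone sketch of the suffix test (editor's illustration, strings standing in for ops):

    #include <cassert>
    #include <string>
    #include <vector>

    // Mirrors PendingOperations::hasSameTail: compare from the newest entry
    // backwards until the shorter history is exhausted.
    bool hasSameTail(const std::vector<std::string> &a,
                     const std::vector<std::string> &b) {
      auto ia = a.rbegin();
      auto ib = b.rbegin();
      for (; ia != a.rend() && ib != b.rend(); ++ia, ++ib)
        if (*ia != *ib)
          return false;
      return true;
    }

    int main() {
      // One path issued an extra older op: common suffix, keep the longer.
      assert(hasSameTail({"s1", "s2"}, {"s2"}));
      // Histories ending in different ops diverge, so both lists survive;
      // that is why @control_flow_merge later waits on load(0) while
      // @control_flow_merge_same_lists can keep load(1).
      assert(!hasSameTail({"s1", "s2"}, {"s1"}));
    }
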

From ac855496e188dd6c2b587e106ab211dae8d6e9d2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 13:48:31 +0100
Subject: [PATCH 023/114] fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 12 +++++++++---
 water/test/Transforms/insert-waitcnt.mlir   |  4 ++--
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index cd49c0e7e..36a0b42f7 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -274,10 +274,13 @@ class WaitcntState : public AbstractDenseLattice {
         continue;
 
       // Search from the back to find the most recent dependency
+      bool found = false;
       auto req = WaitcntRequirement::getOperationRequirement(defOp, true);
       for (Operation *op : llvm::reverse(pendingOps->ops)) {
-        if (op == defOp)
+        if (op == defOp) {
+          found = true;
           break;
+        }
 
         auto opReq = WaitcntRequirement::getOperationRequirement(op, false);
         if (!req.isSameCounterType(opReq))
@@ -286,7 +289,8 @@ class WaitcntState : public AbstractDenseLattice {
         req = req + opReq;
       }
 
-      result.merge(req);
+      if (found)
+        result.merge(req);
     }
 
     if (!result.hasRequirement())
@@ -403,8 +407,10 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     }
 
     // Check for memory dependencies (RAW, WAR, WAW)
-    if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis))
+    if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis)) {
+      LDBG() << "  Memory dependency: " << *memReq;
       opRequirement.merge(*memReq);
+    }
 
     // Set the requirement for this operation
     if (opRequirement.hasRequirement()) {
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 5a342303a..49e71a342 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -1,7 +1,5 @@
 // RUN: water-opt %s --water-insert-waitcnt | FileCheck %s
 
-// Test waitcnt insertion for vector memory operations
-
 // CHECK-LABEL: func.func @single_load_use
 func.func @single_load_use(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {
   // CHECK: vector.load
@@ -61,6 +59,7 @@ func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offs
   // CHECK-NEXT: vector.store %[[DATA]], %[[MEM]]
   vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32>
 
+  // CHECK-NOT: amdgpu.memory_counter_wait
   // CHECK: return %[[LOAD]]
   return %result : vector<4xf32>
 }
@@ -102,6 +101,7 @@ func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index)
   // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM_A]]
   %result = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32>
 
+  // CHECK: amdgpu.memory_counter_wait load(0)
   // CHECK: return %[[LOAD]]
   return %result : vector<4xf32>
 }

From b13b83a2ff2b43bd196512c2d0e9a9af7dddcbf2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 13:50:08 +0100
Subject: [PATCH 024/114] refac

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 23 ++++++++-------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 36a0b42f7..7c92fcab6 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -143,6 +143,7 @@ struct WaitcntRequirement {
     return load_cnt.value_or(0) > other.load_cnt.value_or(0) ||
            ds_cnt.value_or(0) > other.ds_cnt.value_or(0);
   }
+  operator bool() const { return hasRequirement(); }
 
   void print(raw_ostream &os) const {
     os << "WaitcntRequirement: load_cnt=" << load_cnt << " ds_cnt=" << ds_cnt;
@@ -262,11 +263,11 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Check if a value depends on pending operations and compute required wait
-  std::optional<WaitcntRequirement> checkRequirement(Value val) const {
+  WaitcntRequirement checkRequirement(Value val) const {
     // Check if val is produced by any pending operation
     Operation *defOp = val.getDefiningOp();
     if (!defOp)
-      return std::nullopt;
+      return {};
 
     WaitcntRequirement result;
     for (auto &pendingOps : pendingOpsLists) {
@@ -293,19 +294,16 @@ class WaitcntState : public AbstractDenseLattice {
       result.merge(req);
     }
 
-    if (!result.hasRequirement())
-      return std::nullopt;
-
     return result;
   }
 
   /// Check for memory dependencies (RAW, WAR, WAW)
-  std::optional<WaitcntRequirement>
-  checkMemoryDependency(Operation *op, AliasAnalysis &aliasAnalysis) const {
+  WaitcntRequirement checkMemoryDependency(Operation *op,
+                                           AliasAnalysis &aliasAnalysis) const {
     // Check if this is a load or store operation
     std::optional<Value> currentBase = isLoadOrStoreOp(op);
     if (!currentBase)
-      return std::nullopt;
+      return {};
 
     bool isCurrentLoad = isLoadOp(op).has_value();
     bool isCurrentStore = isStoreOp(op).has_value();
@@ -354,9 +352,6 @@ class WaitcntState : public AbstractDenseLattice {
       }
     }
 
-    if (!result.hasRequirement())
-      return std::nullopt;
-
     return result;
   }
@@ -402,14 +397,14 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     for (Value operand : op->getOperands()) {
       if (auto req = before.checkRequirement(operand)) {
         // Merge this requirement (take minimum for conservative wait)
-        opRequirement.merge(*req);
+        opRequirement.merge(req);
       }
     }
 
     // Check for memory dependencies (RAW, WAR, WAW)
     if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis)) {
-      LDBG() << "  Memory dependency: " << *memReq;
-      opRequirement.merge(*memReq);
+      LDBG() << "  Memory dependency: " << memReq;
+      opRequirement.merge(memReq);
     }
 
     // Set the requirement for this operation

From 3a94af6b775016d440a8cd526255e32ddb627ecb Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 13:56:49 +0100
Subject: [PATCH 025/114] LDS tests

Signed-off-by: Ivan Butygin
---
 water/test/Transforms/insert-waitcnt.mlir | 48 +++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 49e71a342..ae7319bf2 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -105,3 +105,51 @@ func.func @raw_dependency_non_zero_waitcnt(%data: vector<4xf32>, %offset: index)
   // CHECK: return %[[LOAD]]
   return %result : vector<4xf32>
 }
+
+// CHECK-LABEL: func.func @workgroup_memory_raw
+func.func @workgroup_memory_raw(%data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  // Allocate workgroup (LDS) memory
+  // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %lds = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // Store to LDS
+  // CHECK: vector.store %{{.*}}, %[[LDS]]
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Load from LDS - RAW dependency, should use dsCnt not loadCnt
+  // CHECK: amdgpu.memory_counter_wait ds(0)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[LDS]]
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait ds(0)
+  // CHECK-NEXT: return %[[LOAD]]
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @mixed_global_and_workgroup
+func.func @mixed_global_and_workgroup(%data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  // Allocate global and workgroup memory
+  // CHECK: %[[GLOBAL:.*]] = memref.alloc() : memref<1024xf32>
+  %global = memref.alloc() : memref<1024xf32>
+  // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %lds = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // Store to global memory
+  // CHECK: vector.store %{{.*}}, %[[GLOBAL]]
+  vector.store %data, %global[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Store to LDS (different counter, no dependency)
+  // CHECK-NOT: amdgpu.memory_counter_wait
+  // CHECK: vector.store %{{.*}}, %[[LDS]]
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Load from global - RAW dependency with global store at distance 0
+  // (LDS store doesn't count because it's a different counter type)
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[GLOBAL]]
+  %result = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: return %[[LOAD]]
+  return %result : vector<4xf32>
+}
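
With the address-space split in place, each pending history carries two independent counter streams, and isSameCounterType keeps ops of the other stream from inflating a distance, which is what the @mixed_global_and_workgroup test exercises. A standalone sketch of per-stream counting (editor's illustration only):

    #include <cassert>
    #include <utility>
    #include <vector>

    enum class Counter { Load, Ds }; // global memory vs. LDS stream

    // Count younger ops of the *same* counter stream between the producer
    // and the end of the pending list; other-stream ops are invisible to
    // this wait, mirroring the isSameCounterType filter.
    unsigned distanceInStream(
        const std::vector<std::pair<int, Counter>> &pending, int producer,
        Counter stream) {
      unsigned younger = 0;
      for (auto it = pending.rbegin(); it != pending.rend(); ++it) {
        if (it->first == producer)
          return younger;
        if (it->second == stream)
          ++younger; // only same-stream ops delay this wait
      }
      return younger; // not pending: nothing to wait for
    }

    int main() {
      // Global store (1) then LDS store (2): the LDS op does not count
      // toward the load(...) distance, so both waits are zero.
      std::vector<std::pair<int, Counter>> pending = {
          {1, Counter::Load}, {2, Counter::Ds}};
      assert(distanceInStream(pending, 1, Counter::Load) == 0); // load(0)
      assert(distanceInStream(pending, 2, Counter::Ds) == 0);   // ds(0)
    }
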

From 24606592323751c9ea30ad6b9f88bf5f11fb652f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 14:13:44 +0100
Subject: [PATCH 026/114] branch tests

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 10 ++-
 water/test/Transforms/insert-waitcnt.mlir   | 79 +++++++++++++++++++--
 2 files changed, 81 insertions(+), 8 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 7c92fcab6..657b31b2b 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -156,6 +156,14 @@ inline raw_ostream &operator<<(raw_ostream &os,
   return os;
 }
 
+static bool mayAlias(Value lhs, Value rhs, AliasAnalysis &aliasAnalysis) {
+  if (isWorkgroupAddressSpace(cast<MemRefType>(lhs.getType())) !=
+      isWorkgroupAddressSpace(cast<MemRefType>(rhs.getType())))
+    return false;
+
+  return !aliasAnalysis.alias(lhs, rhs).isNo();
+}
+
 /// Lattice state tracking pending asynchronous operations
 class WaitcntState : public AbstractDenseLattice {
 public:
@@ -319,7 +327,7 @@ class WaitcntState : public AbstractDenseLattice {
         if (!pendingBase)
           continue;
 
-        if (aliasAnalysis.alias(*currentBase, *pendingBase).isNo())
+        if (!mayAlias(*currentBase, *pendingBase, aliasAnalysis))
           continue;
 
         bool isPendingLoad = isLoadOp(pendingOp).has_value();
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index ae7319bf2..02e3e8bf5 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -127,13 +127,8 @@ func.func @workgroup_memory_raw(%data: vector<4xf32>, %offset: index) -> vector<
 }
 
 // CHECK-LABEL: func.func @mixed_global_and_workgroup
-func.func @mixed_global_and_workgroup(%data: vector<4xf32>, %offset: index) -> vector<4xf32> {
-  // Allocate global and workgroup memory
-  // CHECK: %[[GLOBAL:.*]] = memref.alloc() : memref<1024xf32>
-  %global = memref.alloc() : memref<1024xf32>
-  // CHECK: %[[LDS:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
-  %lds = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
-
+// CHECK-SAME: (%[[GLOBAL:.*]]: memref<1024xf32>, %[[LDS:.*]]: memref<1024xf32, #gpu.address_space<workgroup>>, %{{.*}}: vector<4xf32>, %{{.*}}: index)
+func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space<workgroup>>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> {
   // Store to global memory
   // CHECK: vector.store %{{.*}}, %[[GLOBAL]]
   vector.store %data, %global[%offset] : memref<1024xf32>, vector<4xf32>
@@ -153,3 +148,73 @@ func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<10
   // CHECK-NEXT: return %[[LOAD]]
   return %result : vector<4xf32>
 }
+
+// CHECK-LABEL: func.func @control_flow_merge
+func.func @control_flow_merge(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  %memref1 = memref.alloc() : memref<1024xf32>
+  %memref2 = memref.alloc() : memref<1024xf32>
+
+  // Common operation before branching
+  // CHECK: vector.store
+  vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: cf.cond_br
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // Extra operation in this path
+  // CHECK: vector.store
+  vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32>
+  // CHECK: cf.br
+  cf.br ^bb3
+
+^bb2:
+  // No extra operations, just branch to merge point
+  // CHECK: cf.br
+  cf.br ^bb3
+
+^bb3:
+  // The bb1 branch has distance 1 but bb2 has distance 0, so we must
+  // conservatively take 0
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load
+  %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: return %[[LOAD]]
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @control_flow_merge_same_lists
+func.func @control_flow_merge_same_lists(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> {
+  %memref1 = memref.alloc() : memref<1024xf32>
+  %memref2 = memref.alloc() : memref<1024xf32>
+
+  // Common operation before branching
+  // CHECK: vector.store
+  vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: cf.cond_br
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // CHECK: vector.store
+  vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32>
+  // CHECK: cf.br
+  cf.br ^bb3
+
+^bb2:
+  vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32>
+  // CHECK: cf.br
+  cf.br ^bb3
+
+^bb3:
+  // Both branches have the same distance 1
+  // CHECK: amdgpu.memory_counter_wait load(1)
+  // CHECK-NEXT: %[[LOAD:.*]] = vector.load
+  %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: return %[[LOAD]]
+  return %result : vector<4xf32>
+}

From 10cb6c3c36e46976b5b9efae92b34cf2e13cc370 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 8 Dec 2025 14:24:12 +0100
Subject: [PATCH 027/114] lists bookkeeping

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 26 +++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 657b31b2b..07ab61590 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -255,11 +255,33 @@ class WaitcntState : public AbstractDenseLattice {
         newPending.push_back(op);
       }
       if (newPending.size() == pendingOps->ops.size())
-        return;
+        continue;
 
       std::reverse(newPending.begin(), newPending.end());
       pendingOps = std::make_shared<PendingOperations>(std::move(newPending));
     }
+
+    // Remove empty lists
+    pendingOpsLists.erase(std::remove_if(pendingOpsLists.begin(),
+                                         pendingOpsLists.end(),
+                                         [](const auto &pendingOps) {
+                                           return pendingOps->empty();
+                                         }),
+                          pendingOpsLists.end());
+
+    // Merge lists with the same tail (keep the longer one)
+    for (size_t i = 0; i < pendingOpsLists.size(); ++i) {
+      for (size_t j = i + 1; j < pendingOpsLists.size();) {
+        if (pendingOpsLists[i]->hasSameTail(*pendingOpsLists[j])) {
+          if (pendingOpsLists[j]->size() > pendingOpsLists[i]->size()) {
+            pendingOpsLists[i] = pendingOpsLists[j];
+          }
+          pendingOpsLists.erase(pendingOpsLists.begin() + j);
+        } else {
+          ++j;
+        }
+      }
+    }
   }
 
   void cow() {
     for (auto &pendingOps : pendingOpsLists) {
-      if (!pendingOps || pendingOps.use_count() > 1) {
+      if (pendingOps.use_count() > 1) {
         auto newPending = std::make_shared<PendingOperations>();
         if (pendingOps)
           newPending->ops = pendingOps->ops;
         pendingOps = newPending;
       }
     }
   }
ChangeResult::Change; } @@ -282,6 +288,8 @@ class WaitcntState : public AbstractDenseLattice { } } } + + resetPendingOpsSet(); } void resetRequirement() { requirement = {}; } @@ -299,6 +307,9 @@ class WaitcntState : public AbstractDenseLattice { if (!defOp) return {}; + if (!isPendingOp(defOp)) + return {}; + WaitcntRequirement result; for (auto &pendingOps : pendingOpsLists) { if (pendingOps->empty()) @@ -392,6 +403,8 @@ class WaitcntState : public AbstractDenseLattice { /// Required waitcnt after this state WaitcntRequirement requirement; + mutable llvm::SmallDenseSet pendingOpsSet; + void cow() { for (auto &pendingOps : pendingOpsLists) { if (pendingOps.use_count() > 1) { @@ -402,6 +415,29 @@ class WaitcntState : public AbstractDenseLattice { } } } + + bool isPendingOp(Operation *op) const { + if (pendingOpsLists.empty()) + return false; + + // Build the set of pending operations lazily + bool found = false; + if (pendingOpsSet.empty()) { + for (const auto &pendingOps : pendingOpsLists) { + for (Operation *pendingOp : pendingOps->ops) { + if (pendingOp == op) + found = true; + + pendingOpsSet.insert(pendingOp); + } + } + return found; + } + + return pendingOpsSet.contains(op); + } + + void resetPendingOpsSet() { pendingOpsSet.clear(); } }; /// Dense forward dataflow analysis for waitcnt insertion From 7d2346039a907973a283db1e58e66091ec443ec9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 15:01:13 +0100 Subject: [PATCH 029/114] explicit waits Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 19 +++++++- water/test/Transforms/insert-waitcnt.mlir | 52 +++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 8eb4c6313..7f7e16b1d 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -82,6 +82,15 @@ struct WaitcntRequirement { std::optional load_cnt; std::optional ds_cnt; + WaitcntRequirement() = default; + + WaitcntRequirement(amdgpu::MemoryCounterWaitOp waitOp) { + if (auto loadCnt = waitOp.getLoadAttr()) + load_cnt = loadCnt.getInt(); + if (auto dsCnt = waitOp.getDsAttr()) + ds_cnt = dsCnt.getInt(); + } + bool hasRequirement() const { return load_cnt.has_value() || ds_cnt.has_value(); } @@ -473,6 +482,12 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { opRequirement.merge(memReq); } + // Check if this is an existing memory_counter_wait operation + if (auto waitOp = dyn_cast(op)) { + LDBG() << " Existing waitcnt operation: " << *waitOp; + opRequirement.merge(WaitcntRequirement(waitOp)); + } + // Set the requirement for this operation if (opRequirement.hasRequirement()) { newState.setRequirement(opRequirement); @@ -534,7 +549,9 @@ class WaterInsertWaitcntPass return rewriter.getI32IntegerAttr(*cnt); }; - // Insert wait operation before this operation + // Insert wait operation before the current operation. + // If the current operation is already a memory_counter_wait operation + // they will be merged later. 
rewriter.setInsertionPoint(operation); amdgpu::MemoryCounterWaitOp::create( rewriter, operation->getLoc(), getAttr(req.getLoadCnt()), diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 02e3e8bf5..dded5f2b4 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -149,6 +149,58 @@ func.func @mixed_global_and_workgroup(%global: memref<1024xf32>, %lds: memref<10 return %result : vector<4xf32> } +// CHECK-LABEL: func.func @existing_waitcnt +func.func @existing_waitcnt(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to memory + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // CHECK: amdgpu.memory_counter_wait load(0) + amdgpu.memory_counter_wait load(0) + + // Another store after the wait + // CHECK: vector.store + vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load requires wait for the second store only (first was already waited on) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = vector.load + %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @existing_waitcnt_more_strict +func.func @existing_waitcnt_more_strict(%data: vector<4xf32>, %offset: index) -> vector<4xf32> { + %memref1 = memref.alloc() : memref<1024xf32> + %memref2 = memref.alloc() : memref<1024xf32> + + // Store to memory + // CHECK: vector.store + // CHECK: vector.store + vector.store %data, %memref1[%offset] : memref<1024xf32>, vector<4xf32> + vector.store %data, %memref2[%offset] : memref<1024xf32>, vector<4xf32> + + // Existing wait operation - should clear pending operations + // Normally, the distance will be 1, but explicit amdgpu.memory_counter_wait + // overrides it. 
+ // CHECK-NOT: amdgpu.memory_counter_wait load(1) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NOT: amdgpu.memory_counter_wait load(1) + amdgpu.memory_counter_wait load(0) + + // CHECK: %[[LOAD:.*]] = vector.load + %result = vector.load %memref1[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[LOAD]] + return %result : vector<4xf32> +} + + // CHECK-LABEL: func.func @control_flow_merge func.func @control_flow_merge(%cond: i1, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { %memref1 = memref.alloc() : memref<1024xf32> From 5c8262091b260a0b33361a03beb61ff75059fdec Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 15:13:05 +0100 Subject: [PATCH 030/114] tests Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 50 +++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index dded5f2b4..15419b9c5 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -270,3 +270,53 @@ func.func @control_flow_merge_same_lists(%cond: i1, %data: vector<4xf32>, %offse // CHECK-NEXT: return %[[LOAD]] return %result : vector<4xf32> } + +// CHECK-LABEL: func.func @loop_carried_dependency +func.func @loop_carried_dependency(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Store in each iteration + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Load in the same iteration - RAW dependency with store from this iteration + // In steady state, the backedge brings pending operations from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses the load result, which is async, so need to wait for it + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @loop_load_before_store +func.func @loop_load_before_store(%lb: index, %ub: index, %step: index, %memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // CHECK: scf.for + %result = scf.for %i = %lb to %ub step %step iter_args(%arg = %data) -> (vector<4xf32>) { + // Load first - in steady state, has RAW dependency with store from previous iteration + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOADED:.*]] = vector.load + %loaded = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Store after load - WAR dependency with load in same iteration + // The wait for the load clears it from pending, so this wait is for the load + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: vector.store + vector.store %arg, %memref[%offset] : memref<1024xf32>, vector<4xf32> + + // Yield uses load result - load was already waited on by the store, no additional wait needed + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield %[[LOADED]] + scf.yield %loaded : vector<4xf32> + } + + // CHECK: return + return %result : vector<4xf32> +} From 
39a2f8eae8d4c431901b8086f618b69c17346c36 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 15:24:36 +0100 Subject: [PATCH 031/114] memref.load/store Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 5 +++++ water/test/Transforms/insert-waitcnt.mlir | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 7f7e16b1d..6a53c2f3a 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -11,6 +11,7 @@ #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -32,6 +33,8 @@ namespace { static std::optional isLoadOp(Operation *op) { if (auto load = dyn_cast(op)) return load.getBase(); + if (auto load = dyn_cast(op)) + return load.getMemRef(); return std::nullopt; } @@ -39,6 +42,8 @@ static std::optional isLoadOp(Operation *op) { static std::optional isStoreOp(Operation *op) { if (auto store = dyn_cast(op)) return store.getBase(); + if (auto store = dyn_cast(op)) + return store.getMemRef(); return std::nullopt; } diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 15419b9c5..7a389bad0 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -47,6 +47,22 @@ func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offs return %result : vector<4xf32> } +// CHECK-LABEL: func.func @raw_dependency_memref +// CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: f32, %{{.*}}: index) +func.func @raw_dependency_memref(%memref: memref<1024xf32>, %data: f32, %offset: index) -> f32 { + // Store to memory + // CHECK: memref.store %[[DATA]], %[[MEM]] + memref.store %data, %memref[%offset] : memref<1024xf32> + + // Load from same memory - RAW dependency, must wait for store + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]] + %result = memref.load %memref[%offset] : memref<1024xf32> + + // CHECK: return %[[LOAD]] + return %result : f32 +} + // CHECK-LABEL: func.func @war_dependency // CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) func.func @war_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { From caa9980fb3c07e0102d3ba8cfe4540d0b24806c6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 19:28:06 +0100 Subject: [PATCH 032/114] copy WIP Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 6a53c2f3a..6627e1eaa 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -30,29 +30,39 @@ namespace mlir::water { namespace { +/// Check if the operation is a load operation and return the base memref. static std::optional isLoadOp(Operation *op) { + // TODO: replace with the interface when available. 
if (auto load = dyn_cast(op)) return load.getBase(); if (auto load = dyn_cast(op)) return load.getMemRef(); + if (auto copy = dyn_cast(op)) + return copy.getSource(); return std::nullopt; } +/// Check if the operation is a store operation and return the base memref. static std::optional isStoreOp(Operation *op) { + // TODO: replace with the interface when available. if (auto store = dyn_cast(op)) return store.getBase(); if (auto store = dyn_cast(op)) return store.getMemRef(); + if (auto copy = dyn_cast(op)) + return copy.getTarget(); return std::nullopt; } +/// Check if the operation is a load or store operation and return the base +/// memref. static std::optional isLoadOrStoreOp(Operation *op) { - if (auto load = isLoadOp(op)) - return load; if (auto store = isStoreOp(op)) return store; + if (auto load = isLoadOp(op)) + return load; return std::nullopt; } From 8a633574e9741513e6cec9aca9c881a1afa832a5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 21:36:28 +0100 Subject: [PATCH 033/114] copy handling Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 101 +++++++++++--------- 1 file changed, 54 insertions(+), 47 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 6627e1eaa..d4cb94a17 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -362,61 +362,68 @@ class WaitcntState : public AbstractDenseLattice { return result; } - /// Check for memory dependencies (RAW, WAR, WAW) + /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait WaitcntRequirement checkMemoryDependency(Operation *op, AliasAnalysis &aliasAnalysis) const { - // Check if this is a load or store operation - std::optional currentBase = isLoadOrStoreOp(op); - if (!currentBase) - return {}; - - bool isCurrentLoad = isLoadOp(op).has_value(); - bool isCurrentStore = isStoreOp(op).has_value(); - - WaitcntRequirement result; - for (auto &pendingOps : pendingOpsLists) { - if (pendingOps->empty()) - continue; - // Search from the back to find the most recent dependency - for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) { - std::optional pendingBase = isLoadOrStoreOp(pendingOp); - if (!pendingBase) - continue; - - if (!mayAlias(*currentBase, *pendingBase, aliasAnalysis)) + // std::optional currentBase = isLoadOrStoreOp(op); + auto checkMemref = [&](Value memref, bool isCurrentLoad, + bool isCurrentStore) -> WaitcntRequirement { + WaitcntRequirement result; + for (auto &pendingOps : pendingOpsLists) { + if (pendingOps->empty()) continue; - bool isPendingLoad = isLoadOp(pendingOp).has_value(); - bool isPendingStore = isStoreOp(pendingOp).has_value(); - - // Check for dependencies: - // RAW: current load after pending store - // WAR: current store after pending load - // WAW: current store after pending store - bool hasRAW = isCurrentLoad && isPendingStore; - bool hasWAR = isCurrentStore && isPendingLoad; - bool hasWAW = isCurrentStore && isPendingStore; - - if (hasRAW || hasWAR || hasWAW) { - // Found dependency - compute requirement by counting forward from - // here - auto it = llvm::find(pendingOps->ops, pendingOp); - auto req = - WaitcntRequirement::getOperationRequirement(pendingOp, true); - for (Operation *countOp : - llvm::make_range(std::next(it), pendingOps->ops.end())) { - auto opReq = - WaitcntRequirement::getOperationRequirement(countOp, false); - if (!req.isSameCounterType(opReq)) - continue; - req = req + opReq; - } - 
result.merge(req); + // Search from the back to find the most recent dependency + for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) { + auto checkPendingMemref = + [&](Value pendingMemref, bool isPendingLoad, + bool isPendingStore) -> WaitcntRequirement { + WaitcntRequirement pendingResult; + if (!mayAlias(memref, pendingMemref, aliasAnalysis)) + return pendingResult; + + // Check for dependencies: + // RAW: current load after pending store + // WAR: current store after pending load + // WAW: current store after pending store + bool hasRAW = isCurrentLoad && isPendingStore; + bool hasWAR = isCurrentStore && isPendingLoad; + bool hasWAW = isCurrentStore && isPendingStore; + + if (hasRAW || hasWAR || hasWAW) { + // Found dependency - compute requirement by counting forward from + // here + auto it = llvm::find(pendingOps->ops, pendingOp); + auto req = + WaitcntRequirement::getOperationRequirement(pendingOp, true); + for (Operation *countOp : + llvm::make_range(std::next(it), pendingOps->ops.end())) { + auto opReq = + WaitcntRequirement::getOperationRequirement(countOp, false); + if (!req.isSameCounterType(opReq)) + continue; + req = req + opReq; + } + pendingResult.merge(req); + } + return pendingResult; + }; + if (auto loadBase = isLoadOp(pendingOp)) + result.merge(checkPendingMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(pendingOp)) + result.merge(checkPendingMemref(*storeBase, false, true)); } } - } + return result; + }; + // TODO: atomics will have both load and store flags set + WaitcntRequirement result; + if (auto loadBase = isLoadOp(op)) + result.merge(checkMemref(*loadBase, true, false)); + if (auto storeBase = isStoreOp(op)) + result.merge(checkMemref(*storeBase, false, true)); return result; } From fa7e489a859e909ff5beec2f7790d5906f77ef55 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 21:49:52 +0100 Subject: [PATCH 034/114] memref.copy tests Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 74 +++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 7a389bad0..c6b68f751 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -336,3 +336,77 @@ func.func @loop_load_before_store(%lb: index, %ub: index, %step: index, %memref: // CHECK: return return %result : vector<4xf32> } + +// CHECK-LABEL: func.func @memref_copy_raw_source +func.func @memref_copy_raw_source(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy from source - RAW dependency (reads from source that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + +// CHECK-LABEL: func.func @memref_copy_waw_target +func.func @memref_copy_waw_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) { + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAW dependency (writes to target that was just written) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return + return +} + 
+// CHECK-LABEL: func.func @memref_copy_war_target +func.func @memref_copy_war_target(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // Load from destination + // CHECK: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy to destination - WAR dependency (writes to target that was just read) + // The copy's wait also synchronizes the load, so return doesn't need another wait + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // CHECK: return %[[RESULT]] + return %result : vector<4xf32> +} + +// CHECK-LABEL: func.func @memref_copy_both_dependencies +func.func @memref_copy_both_dependencies(%src: memref<1024xf32>, %dst: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { + // Store to source + // CHECK: vector.store + vector.store %data, %src[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to destination + // CHECK: vector.store + vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // Copy needs to wait for both stores: + // - RAW on source (copy reads from source) + // - WAW on target (copy writes to destination) + // Both stores alias with their respective memrefs, so we need wait(0) + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: memref.copy + memref.copy %src, %dst : memref<1024xf32> to memref<1024xf32> + + // Load from destination after copy - RAW dependency with copy + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %dst[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} From 83f6ef1ff25bc50560a70be9210e0fb3c15fdfbf Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Dec 2025 22:08:02 +0100 Subject: [PATCH 035/114] gather-to-lds Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 4 ++++ water/test/Transforms/insert-waitcnt.mlir | 23 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index d4cb94a17..fd694d20b 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -39,6 +39,8 @@ static std::optional isLoadOp(Operation *op) { return load.getMemRef(); if (auto copy = dyn_cast(op)) return copy.getSource(); + if (auto gather = dyn_cast(op)) + return gather.getSrc(); return std::nullopt; } @@ -52,6 +54,8 @@ static std::optional isStoreOp(Operation *op) { return store.getMemRef(); if (auto copy = dyn_cast(op)) return copy.getTarget(); + if (auto gather = dyn_cast(op)) + return gather.getDst(); return std::nullopt; } diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index c6b68f751..ed6b2ec7b 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -410,3 +410,26 @@ func.func @memref_copy_both_dependencies(%src: memref<1024xf32>, %dst: memref<10 // CHECK-NEXT: return %[[RESULT]] return %result : vector<4xf32> } + +// CHECK-LABEL: func.func @gather_to_lds +func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu.address_space>, %data: vector<4xf32>, %src_offset: index, %dst_offset: index) -> vector<4xf32> { + // Store to global memory + // CHECK: 
vector.store + vector.store %data, %global[%src_offset] : memref<1024xf32>, vector<4xf32> + + // Gather from global to LDS - has both RAW (reads from global) and acts as store to LDS + // Should wait for global store using load counter + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.gather_to_lds + amdgpu.gather_to_lds %global[%src_offset], %lds[%dst_offset] : f32, memref<1024xf32>, memref<1024xf32, #gpu.address_space> + + // Load from LDS - RAW dependency with gather writing to LDS + // Should wait for LDS operation using ds counter + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: %[[RESULT:.*]] = vector.load + %result = vector.load %lds[%dst_offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK-NEXT: return %[[RESULT]] + return %result : vector<4xf32> +} From de069b9e105224546c851d04092daada74fc8aed Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 9 Dec 2025 01:19:41 +0100 Subject: [PATCH 036/114] double buffering test WIP Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index ed6b2ec7b..6616b6bcc 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -433,3 +433,45 @@ func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu. // CHECK-NEXT: return %[[RESULT]] return %result : vector<4xf32> } + +// CHECK-LABEL: func.func @double_buffering +func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + // Allocate LDS buffers to guarantee no aliasing with global src + %dst0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %dst1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + // Allocate output buffer to guarantee no aliasing with src + %out = memref.alloc() : memref<1024xf32> + + // Initial copy to buffer 0 before the loop + // CHECK: memref.copy + memref.copy %src, %dst0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %dst0, %next = %dst1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // Use current buffer which was copied in the previous iteration + // RAW dependency with the copy that wrote to current buffer (LDS operation) + // Distance = 1 because there's 1 copy after it (the copy to next buffer) + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK-NEXT: %[[DATA:.*]] = vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Copy to next buffer for the next iteration (asynchronous, can overlap with computation) + // Reads from global (src), writes to LDS (next) + // No wait needed - src is read-only, no dependency with previous iteration + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Process the data - write to separate output buffer (no dependency with src) + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // Swap buffers for next iteration + // CHECK: scf.yield + scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + 
// CHECK: return + return +} From b089085aff97fe62074dae7fe43d7d95aa123518 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 9 Dec 2025 17:56:16 +0100 Subject: [PATCH 037/114] tokens WIP Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 34 +++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index fd694d20b..771bbd0ac 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -79,21 +79,38 @@ static bool isWorkgroupAddressSpace(MemRefType type) { /// Shared pending operations list for structural sharing struct PendingOperations { PendingOperations() = default; - PendingOperations(SmallVector &&ops) : ops(std::move(ops)) {} + PendingOperations(SmallVector &&ops, + SmallVector> &&opsTokens) + : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {} + + void addOp(Operation *op) { + ops.push_back(op); + if (auto memref = isStoreOp(op)) { + opsTokens.push_back({*memref}); + } else if (isLoadOp(op)) { + opsTokens.push_back({op->getResult(0)}); + } else { + assert(false && "Expected load or store operation"); + } + } size_t size() const { return ops.size(); } bool empty() const { return ops.empty(); } bool hasSameTail(const PendingOperations &other) const { - for (const auto &[op1, op2] : - llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops))) { + for (const auto &[op1, op2, tok1, tok2] : + llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops), + llvm::reverse(opsTokens), llvm::reverse(other.opsTokens))) { if (op1 != op2) return false; + if (tok1 != tok2) + return false; } return true; } SmallVector ops; + SmallVector> opsTokens; }; /// Waitcnt requirement for synchronization @@ -257,7 +274,7 @@ class WaitcntState : public AbstractDenseLattice { cow(); } for (auto &pendingOps : pendingOpsLists) - pendingOps->ops.push_back(op); + pendingOps->addOp(op); pendingOpsSet.insert(op); } @@ -278,8 +295,10 @@ class WaitcntState : public AbstractDenseLattice { requirement = req; for (auto &pendingOps : pendingOpsLists) { SmallVector newPending; + SmallVector> newPendingTokens; WaitcntRequirement runningRequirement; - for (Operation *op : llvm::reverse(pendingOps->ops)) { + for (auto [op, tok] : llvm::zip(llvm::reverse(pendingOps->ops), + llvm::reverse(pendingOps->opsTokens))) { WaitcntRequirement opReq = WaitcntRequirement::getOperationRequirement(op, false); runningRequirement = runningRequirement + opReq; @@ -287,12 +306,15 @@ class WaitcntState : public AbstractDenseLattice { continue; newPending.push_back(op); + newPendingTokens.push_back(tok); } if (newPending.size() == pendingOps->ops.size()) continue; std::reverse(newPending.begin(), newPending.end()); - pendingOps = std::make_shared(std::move(newPending)); + std::reverse(newPendingTokens.begin(), newPendingTokens.end()); + pendingOps = std::make_shared( + std::move(newPending), std::move(newPendingTokens)); } // Remove empty lists From 84cb14bb3428721c818d3b36907d5e202cc789b0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 9 Dec 2025 18:07:03 +0100 Subject: [PATCH 038/114] printing Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 25 ++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 771bbd0ac..74ac6f1ee 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ 
b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -76,6 +76,10 @@ static bool isWorkgroupAddressSpace(MemRefType type) { return attr && attr.getValue() == gpu::AddressSpace::Workgroup; } +template static void print_range(raw_ostream &os, T &&range) { + llvm::interleaveComma(range, os, [&](const auto &item) { os << item; }); +} + /// Shared pending operations list for structural sharing struct PendingOperations { PendingOperations() = default; @@ -97,6 +101,8 @@ struct PendingOperations { size_t size() const { return ops.size(); } bool empty() const { return ops.empty(); } + auto opsAndTokens() const { return llvm::zip(ops, opsTokens); } + bool hasSameTail(const PendingOperations &other) const { for (const auto &[op1, op2, tok1, tok2] : llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops), @@ -109,6 +115,15 @@ struct PendingOperations { return true; } + void print(raw_ostream &os) const { + os << "PendingOperations: ops=["; + llvm::interleaveComma(opsAndTokens(), os, [&](const auto &opAndTok) { + os << std::get<0>(opAndTok) << "|"; + print_range(os, std::get<1>(opAndTok)); + }); + os << "]"; + } + SmallVector ops; SmallVector> opsTokens; }; @@ -260,8 +275,7 @@ class WaitcntState : public AbstractDenseLattice { os << "WaitcntState: pending ops ["; for (auto &pendingOps : pendingOpsLists) { os << "["; - llvm::interleaveComma(pendingOps->ops, os, - [&](Operation *op) { os << *op; }); + pendingOps->print(os); os << "]"; } os << "], requirement: " << requirement; @@ -297,8 +311,9 @@ class WaitcntState : public AbstractDenseLattice { SmallVector newPending; SmallVector> newPendingTokens; WaitcntRequirement runningRequirement; - for (auto [op, tok] : llvm::zip(llvm::reverse(pendingOps->ops), - llvm::reverse(pendingOps->opsTokens))) { + for (const auto &[op, tok] : + llvm::zip(llvm::reverse(pendingOps->ops), + llvm::reverse(pendingOps->opsTokens))) { WaitcntRequirement opReq = WaitcntRequirement::getOperationRequirement(op, false); runningRequirement = runningRequirement + opReq; @@ -308,7 +323,7 @@ class WaitcntState : public AbstractDenseLattice { newPending.push_back(op); newPendingTokens.push_back(tok); } - if (newPending.size() == pendingOps->ops.size()) + if (newPending.size() == pendingOps->size()) continue; std::reverse(newPending.begin(), newPending.end()); From 351c89f869db30047e5fb59f421482b1885c4a46 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 9 Dec 2025 20:37:29 +0100 Subject: [PATCH 039/114] control flow WIP Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 74ac6f1ee..d281e56f0 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -512,6 +512,15 @@ class WaitcntState : public AbstractDenseLattice { void resetPendingOpsSet() { pendingOpsSet.clear(); } }; +static ValueRange getRegionResults(ArrayRef successors, + Region *region) { + for (const auto &successor : successors) { + if (successor.getSuccessor() == region) + return successor.getSuccessorInputs(); + } + assert(false && "Region not found, malfoemrf SCF op?"); +} + /// Dense forward dataflow analysis for waitcnt insertion class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { public: @@ -572,6 +581,31 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { return success(); } + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + 
std::optional regionFrom, + std::optional regionTo, + const WaitcntState &before, + WaitcntState *after) override { + LDBG() << "Visiting region branch control flow transfer: " << *branch; + LDBG() << " Region from: " << regionFrom; + LDBG() << " Region to: " << regionTo; + LDBG() << " Before: " << before; + LDBG() << " After: " << *after; + + SmallVector successors; + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + + ValueRange newValues; + if (regionTo) { + Region ®ion = branch->getRegions()[*regionTo]; + newValues = getRegionResults(successors, ®ion); + } else { + newValues = getRegionResults(successors, nullptr); + } + + propagateIfChanged(after, after->join(before)); + } + private: AliasAnalysis &aliasAnalysis; }; From 96ee3333e9ef3bd35bfe4c195fa6478705a04497 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Dec 2025 13:21:40 +0100 Subject: [PATCH 040/114] tokens bookkeeping Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 87 ++++++++++++++++++--- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index d281e56f0..0bfd1ffeb 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallVector.h" @@ -82,9 +83,11 @@ template static void print_range(raw_ostream &os, T &&range) { /// Shared pending operations list for structural sharing struct PendingOperations { + using TokenContainer = SmallVector; + PendingOperations() = default; PendingOperations(SmallVector &&ops, - SmallVector> &&opsTokens) + SmallVector &&opsTokens) : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {} void addOp(Operation *op) { @@ -115,6 +118,17 @@ struct PendingOperations { return true; } + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (TokenContainer &tokens : opsTokens) { + TokenContainer newTok; + for (Value tok : tokens) + updateFunc(tok, newTok); + + tokens = std::move(newTok); + } + } + void print(raw_ostream &os) const { os << "PendingOperations: ops=["; llvm::interleaveComma(opsAndTokens(), os, [&](const auto &opAndTok) { @@ -125,7 +139,7 @@ struct PendingOperations { } SmallVector ops; - SmallVector> opsTokens; + SmallVector opsTokens; }; /// Waitcnt requirement for synchronization @@ -309,7 +323,7 @@ class WaitcntState : public AbstractDenseLattice { requirement = req; for (auto &pendingOps : pendingOpsLists) { SmallVector newPending; - SmallVector> newPendingTokens; + SmallVector newPendingTokens; WaitcntRequirement runningRequirement; for (const auto &[op, tok] : llvm::zip(llvm::reverse(pendingOps->ops), @@ -357,6 +371,12 @@ class WaitcntState : public AbstractDenseLattice { resetPendingOpsSet(); } + void updateTokens( + llvm::function_ref &)> updateFunc) { + for (auto &pendingOps : pendingOpsLists) + pendingOps->updateTokens(updateFunc); + } + void resetRequirement() { requirement = {}; } /// Get the required waitcnt values @@ -512,13 +532,13 @@ class WaitcntState : public AbstractDenseLattice { void resetPendingOpsSet() { pendingOpsSet.clear(); } }; -static ValueRange getRegionResults(ArrayRef successors, - Region *region) { +static RegionSuccessor getRegionResults(ArrayRef successors, + Region *region) { for 
(const auto &successor : successors) { if (successor.getSuccessor() == region) - return successor.getSuccessorInputs(); + return successor; } - assert(false && "Region not found, malfoemrf SCF op?"); + llvm_unreachable("Region not found, malformed SCF op?"); } /// Dense forward dataflow analysis for waitcnt insertion @@ -595,15 +615,56 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { SmallVector successors; branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); - ValueRange newValues; - if (regionTo) { - Region ®ion = branch->getRegions()[*regionTo]; - newValues = getRegionResults(successors, ®ion); + auto destSuccessor = [&]() -> RegionSuccessor { + if (regionTo) { + Region ®ion = branch->getRegions()[*regionTo]; + return getRegionResults(successors, ®ion); + } else { + return getRegionResults(successors, nullptr); + } + }(); + // Dest values are either nested block args or branch op results. + ValueRange destValues = destSuccessor.getSuccessorInputs(); + + // Map from input values to dest values. + llvm::SmallDenseMap valuesMapping; + if (regionFrom) { + Region ®ion = branch->getRegions()[*regionFrom]; + for (Block &block : region) { + auto term = + dyn_cast(block.getTerminator()); + if (!term) + continue; + + ValueRange source = + term.getMutableSuccessorOperands(destSuccessor).getAsOperandRange(); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; + } } else { - newValues = getRegionResults(successors, nullptr); + ValueRange source = branch.getEntrySuccessorOperands(destSuccessor); + for (auto [source, dest] : llvm::zip(source, destValues)) + valuesMapping[source] = dest; } - propagateIfChanged(after, after->join(before)); + DominanceInfo dom; + + WaitcntState newState = before; + auto tokenUpdateFunc = [&](Value value, SmallVectorImpl &newTokens) { + // Keep the token if it dominates current op as user can use it directly. + if (dom.properlyDominates(value, branch)) + newTokens.push_back(value); + + // Add token propagated through region control flow. 
+ if (Value value = valuesMapping.lookup(value)) + if (newTokens.empty() || newTokens.back() != value) + newTokens.push_back(value); + }; + newState.updateTokens(tokenUpdateFunc); + + LDBG() << " New state: " << newState; + + propagateIfChanged(after, after->join(newState)); } private: From ebd30358e7b641f4c04ca9b75a925936c08bf2be Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Dec 2025 14:44:15 +0100 Subject: [PATCH 041/114] token stuff Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 39 ++++++++++--------- water/test/Transforms/insert-waitcnt.mlir | 42 --------------------- 2 files changed, 22 insertions(+), 59 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 0bfd1ffeb..fb1e6f6d6 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -77,8 +77,10 @@ static bool isWorkgroupAddressSpace(MemRefType type) { return attr && attr.getValue() == gpu::AddressSpace::Workgroup; } -template static void print_range(raw_ostream &os, T &&range) { +template +static raw_ostream &print_range(raw_ostream &os, T &&range) { llvm::interleaveComma(range, os, [&](const auto &item) { os << item; }); + return os; } /// Shared pending operations list for structural sharing @@ -92,13 +94,12 @@ struct PendingOperations { void addOp(Operation *op) { ops.push_back(op); - if (auto memref = isStoreOp(op)) { - opsTokens.push_back({*memref}); - } else if (isLoadOp(op)) { - opsTokens.push_back({op->getResult(0)}); - } else { - assert(false && "Expected load or store operation"); - } + auto &back = opsTokens.emplace_back(); + if (auto memref = isStoreOp(op)) + back.push_back(*memref); + + if (auto memref = isLoadOp(op)) + back.push_back(*memref); } size_t size() const { return ops.size(); } @@ -230,12 +231,12 @@ inline raw_ostream &operator<<(raw_ostream &os, return os; } -static bool mayAlias(Value lhs, Value rhs, AliasAnalysis &aliasAnalysis) { +static bool mayAlias(Value lhs, Value rhs, ArrayRef tokens) { if (isWorkgroupAddressSpace(cast(lhs.getType())) != isWorkgroupAddressSpace(cast(rhs.getType()))) return false; - return !aliasAnalysis.alias(lhs, rhs).isNo(); + return llvm::is_contained(tokens, lhs); } /// Lattice state tracking pending asynchronous operations @@ -436,12 +437,14 @@ class WaitcntState : public AbstractDenseLattice { continue; // Search from the back to find the most recent dependency - for (Operation *pendingOp : llvm::reverse(pendingOps->ops)) { + for (const auto &[pendingOp, pendingTokens] : + llvm::zip(llvm::reverse(pendingOps->ops), + llvm::reverse(pendingOps->opsTokens))) { auto checkPendingMemref = [&](Value pendingMemref, bool isPendingLoad, bool isPendingStore) -> WaitcntRequirement { WaitcntRequirement pendingResult; - if (!mayAlias(memref, pendingMemref, aliasAnalysis)) + if (!mayAlias(memref, pendingMemref, pendingTokens)) return pendingResult; // Check for dependencies: @@ -502,8 +505,8 @@ class WaitcntState : public AbstractDenseLattice { if (pendingOps.use_count() > 1) { auto newPending = std::make_shared(); if (pendingOps) - newPending->ops = pendingOps->ops; - pendingOps = newPending; + *newPending = *pendingOps; + pendingOps = std::move(newPending); } } } @@ -572,6 +575,8 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis)) { LDBG() << " Memory dependency: " << memReq; opRequirement.merge(memReq); + } else { + LDBG() << " No memory dependency"; 
} // Check if this is an existing memory_counter_wait operation @@ -656,9 +661,9 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { newTokens.push_back(value); // Add token propagated through region control flow. - if (Value value = valuesMapping.lookup(value)) - if (newTokens.empty() || newTokens.back() != value) - newTokens.push_back(value); + if (Value mappedValue = valuesMapping.lookup(value)) + if (newTokens.empty() || newTokens.back() != mappedValue) + newTokens.push_back(mappedValue); }; newState.updateTokens(tokenUpdateFunc); diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 6616b6bcc..ed6b2ec7b 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -433,45 +433,3 @@ func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu. // CHECK-NEXT: return %[[RESULT]] return %result : vector<4xf32> } - -// CHECK-LABEL: func.func @double_buffering -func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { - // Allocate LDS buffers to guarantee no aliasing with global src - %dst0 = memref.alloc() : memref<1024xf32, #gpu.address_space> - %dst1 = memref.alloc() : memref<1024xf32, #gpu.address_space> - // Allocate output buffer to guarantee no aliasing with src - %out = memref.alloc() : memref<1024xf32> - - // Initial copy to buffer 0 before the loop - // CHECK: memref.copy - memref.copy %src, %dst0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // CHECK: scf.for - scf.for %i = %lb to %ub step %step iter_args(%current = %dst0, %next = %dst1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { - // Use current buffer which was copied in the previous iteration - // RAW dependency with the copy that wrote to current buffer (LDS operation) - // Distance = 1 because there's 1 copy after it (the copy to next buffer) - // CHECK: amdgpu.memory_counter_wait ds(1) - // CHECK-NEXT: %[[DATA:.*]] = vector.load - %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> - - // Copy to next buffer for the next iteration (asynchronous, can overlap with computation) - // Reads from global (src), writes to LDS (next) - // No wait needed - src is read-only, no dependency with previous iteration - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: memref.copy - memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> - - // Process the data - write to separate output buffer (no dependency with src) - // CHECK-NOT: amdgpu.memory_counter_wait - // CHECK: vector.store - vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> - - // Swap buffers for next iteration - // CHECK: scf.yield - scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> - } - - // CHECK: return - return -} From e4eb580587d43b8c3f12c6d183338dd6fc8efdb0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Dec 2025 14:46:49 +0100 Subject: [PATCH 042/114] -aliasAnalysis Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index fb1e6f6d6..0ac10c611 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -6,7 +6,6 @@ #include 
"water/Transforms/Passes.h" -#include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" @@ -425,8 +424,7 @@ class WaitcntState : public AbstractDenseLattice { } /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait - WaitcntRequirement checkMemoryDependency(Operation *op, - AliasAnalysis &aliasAnalysis) const { + WaitcntRequirement checkMemoryDependency(Operation *op) const { // std::optional currentBase = isLoadOrStoreOp(op); auto checkMemref = [&](Value memref, bool isCurrentLoad, @@ -547,8 +545,8 @@ static RegionSuccessor getRegionResults(ArrayRef successors, /// Dense forward dataflow analysis for waitcnt insertion class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { public: - WaitcntAnalysis(DataFlowSolver &solver, AliasAnalysis &aliasAnalysis) - : DenseForwardDataFlowAnalysis(solver), aliasAnalysis(aliasAnalysis) {} + explicit WaitcntAnalysis(DataFlowSolver &solver) + : DenseForwardDataFlowAnalysis(solver) {} void setToEntryState(WaitcntState *lattice) override { propagateIfChanged(lattice, lattice->reset()); @@ -572,7 +570,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { } // Check for memory dependencies (RAW, WAR, WAW) - if (auto memReq = before.checkMemoryDependency(op, aliasAnalysis)) { + if (auto memReq = before.checkMemoryDependency(op)) { LDBG() << " Memory dependency: " << memReq; opRequirement.merge(memReq); } else { @@ -671,9 +669,6 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { propagateIfChanged(after, after->join(newState)); } - -private: - AliasAnalysis &aliasAnalysis; }; /// Pass that inserts wait/synchronization instructions for asynchronous @@ -685,11 +680,9 @@ class WaterInsertWaitcntPass LDBG() << "Running WaterInsertWaitcntPass"; Operation *op = getOperation(); - auto &aliasAnalysis = getAnalysis(); - DataFlowSolver solver; loadBaselineAnalyses(solver); - solver.load(aliasAnalysis); + solver.load(); if (failed(solver.initializeAndRun(op))) { signalPassFailure(); From 9e09741a0936e2319dc68089dd58bacec8f11119 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Dec 2025 14:53:05 +0100 Subject: [PATCH 043/114] double-buffering test Signed-off-by: Ivan Butygin --- water/test/Transforms/insert-waitcnt.mlir | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index ed6b2ec7b..52a739d07 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -433,3 +433,39 @@ func.func @gather_to_lds(%global: memref<1024xf32>, %lds: memref<1024xf32, #gpu. 
// CHECK-NEXT: return %[[RESULT]] return %result : vector<4xf32> } + +// CHECK-LABEL: func.func @double_buffering +func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) { + %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space> + %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space> + + %out = memref.alloc() : memref<1024xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // CHECK: scf.for + scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1) -> (memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space>) { + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: memref.copy + memref.copy %src, %next : memref<1024xf32> to memref<1024xf32, #gpu.address_space> + + // Skip the second buffer copy + // CHECK: amdgpu.memory_counter_wait ds(1) + // CHECK: vector.load + %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Cannot skip unfortunately + // CHECK: amdgpu.memory_counter_wait ds(0) + // CHECK: vector.store + vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + // CHECK: scf.yield + scf.yield %next, %current : memref<1024xf32, #gpu.address_space>, memref<1024xf32, #gpu.address_space> + } + + // CHECK: return + return +} From 186b4b6eb178f5d55714ca6abd6ace8f95fb9b82 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Dec 2025 15:32:17 +0100 Subject: [PATCH 044/114] fix reqs and trips buffering Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 18 +++++---- water/test/Transforms/insert-waitcnt.mlir | 41 +++++++++++++++++++++ 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 0ac10c611..7adbc9010 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -104,7 +104,11 @@ struct PendingOperations { size_t size() const { return ops.size(); } bool empty() const { return ops.empty(); } - auto opsAndTokens() const { return llvm::zip(ops, opsTokens); } + auto opsAndTokens() const { + assert(ops.size() == opsTokens.size() && + "ops and opsTokens must have the same size"); + return llvm::zip(ops, opsTokens); + } bool hasSameTail(const PendingOperations &other) const { for (const auto &[op1, op2, tok1, tok2] : @@ -132,7 +136,7 @@ struct PendingOperations { void print(raw_ostream &os) const { os << "PendingOperations: ops=["; llvm::interleaveComma(opsAndTokens(), os, [&](const auto &opAndTok) { - os << std::get<0>(opAndTok) << "|"; + os << *std::get<0>(opAndTok) << "|"; print_range(os, std::get<1>(opAndTok)); }); os << "]"; @@ -288,11 +292,11 @@ class WaitcntState : public AbstractDenseLattice { void print(raw_ostream &os) const override { os << "WaitcntState: pending ops ["; for (auto &pendingOps : pendingOpsLists) { - os << "["; + os << "\n ["; pendingOps->print(os); os << "]"; } - os << "], requirement: " << requirement; + os << "\n ], requirement: " << requirement; } void addPendingOp(Operation *op) { @@ -325,9 +329,7 @@ class WaitcntState : public AbstractDenseLattice { SmallVector newPending; SmallVector newPendingTokens; WaitcntRequirement runningRequirement; - for (const auto &[op, tok] : - llvm::zip(llvm::reverse(pendingOps->ops), - llvm::reverse(pendingOps->opsTokens))) { + 
      for (const auto &[op, tok] :
           llvm::reverse(pendingOps->opsAndTokens())) {
        WaitcntRequirement opReq =
            WaitcntRequirement::getOperationRequirement(op, false);
        runningRequirement = runningRequirement + opReq;
@@ -561,7 +563,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis<WaitcntState> {
     WaitcntState newState = before;
 
     // Check if any operands depend on pending operations (value dependency)
-    WaitcntRequirement opRequirement;
+    WaitcntRequirement opRequirement = after->getRequirement();
     for (Value operand : op->getOperands()) {
       if (auto req = before.checkRequirement(operand)) {
         // Merge this requirement (take minimum for conservative wait)
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 52a739d07..81fcf6777 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -469,3 +469,44 @@ func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
   // CHECK: return
   return
 }
+
+// CHECK-LABEL: func.func @triple_buffering
+func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
+  %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+
+  %out = memref.alloc() : memref<1024xf32>
+
+  // CHECK-NOT: amdgpu.memory_counter_wait
+  // CHECK: memref.copy
+  memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // CHECK-NOT: amdgpu.memory_counter_wait
+  // CHECK: memref.copy
+  memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // CHECK: scf.for
+  scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>) {
+    // Wait for the copy into the current buffer; the second buffer's copy may
+    // stay in flight.
+    // CHECK: amdgpu.memory_counter_wait ds(1)
+    // CHECK: vector.load
+    %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: memref.copy
+    memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+    // Wait for the read above; the copy just issued may stay in flight.
+    // CHECK: amdgpu.memory_counter_wait ds(1)
+    // CHECK: vector.store
+    vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32>
+
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: scf.yield
+    scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>
+  }
+
+  // CHECK: return
+  return
+}

From 0c42d8eafccc61729dad848a1ff8427bd8b313b3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 10 Dec 2025 16:02:34 +0100
Subject: [PATCH 045/114] token cache

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 42 +++++++++++++++++----
 1 file changed, 34 insertions(+), 8 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 7adbc9010..bb2fa03c1 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -91,7 +91,7 @@ struct PendingOperations {
                     SmallVector<TokenContainer> &&opsTokens)
       : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {}
 
-  void addOp(Operation *op) {
+  TokenContainer &addOp(Operation *op) {
     ops.push_back(op);
     auto &back = opsTokens.emplace_back();
     if (auto memref = isStoreOp(op))
@@ -99,6 +99,8 @@ struct PendingOperations {
 
     if (auto memref = isLoadOp(op))
       back.push_back(*memref);
+
+    return back;
   }
 
   size_t size() const { return ops.size(); }
@@ -305,8 +307,11 @@ class WaitcntState : public AbstractDenseLattice {
     } else {
       cow();
     }
-    for (auto &pendingOps : pendingOpsLists)
-      pendingOps->addOp(op);
+    for (auto &pendingOps : pendingOpsLists) {
+      auto &tokens = pendingOps->addOp(op);
+      for (Value token : tokens)
+        pendingOpsTokens.insert(token);
+    }
     pendingOpsSet.insert(op);
   }
@@ -432,6 +437,9 @@ class WaitcntState : public AbstractDenseLattice {
     auto checkMemref = [&](Value memref, bool isCurrentLoad,
                            bool isCurrentStore) -> WaitcntRequirement {
       WaitcntRequirement result;
+      if (!isPendingOp(memref))
+        return result;
+
       for (auto &pendingOps : pendingOpsLists) {
         if (pendingOps->empty())
           continue;
@@ -499,6 +507,7 @@ class WaitcntState : public AbstractDenseLattice {
   WaitcntRequirement requirement;
 
   mutable llvm::SmallDenseSet<Operation *> pendingOpsSet;
+  mutable llvm::SmallDenseSet<Value> pendingOpsTokens;
 
   void cow() {
     for (auto &pendingOps : pendingOpsLists) {
@@ -511,28 +520,45 @@ class WaitcntState : public AbstractDenseLattice {
     }
   }
 
-  bool isPendingOp(Operation *op) const {
+  bool isPendingOp(llvm::PointerUnion<Operation *, Value> opOrVal) const {
    if (pendingOpsLists.empty())
      return false;

    // Build the set of pending operations lazily
    bool found = false;
    if (pendingOpsSet.empty()) {
+      assert(pendingOpsTokens.empty() && "pendingOpsTokens must be empty");
+      Operation *op = dyn_cast<Operation *>(opOrVal);
+      Value val = dyn_cast<Value>(opOrVal);
      for (const auto &pendingOps : pendingOpsLists) {
-        for (Operation *pendingOp : pendingOps->ops) {
+        for (const auto &[pendingOp, pendingTokens] :
+             pendingOps->opsAndTokens()) {
          if (pendingOp == op)
            found = true;

          pendingOpsSet.insert(pendingOp);
+          for (Value token : pendingTokens) {
+            if (token == val)
+              found = true;
+
+            pendingOpsTokens.insert(token);
+          }
        }
      }
-      return found;
    }

-    return pendingOpsSet.contains(op);
+    if (found)
+      return true;
+
+    return isa<Operation *>(opOrVal)
+               ? pendingOpsSet.contains(cast<Operation *>(opOrVal))
+               : pendingOpsTokens.contains(cast<Value>(opOrVal));
  }

-  void resetPendingOpsSet() { pendingOpsSet.clear(); }
+  void resetPendingOpsSet() {
+    pendingOpsSet.clear();
+    pendingOpsTokens.clear();
+  }
 };
 
 static RegionSuccessor getRegionResults(ArrayRef<RegionSuccessor> successors,

From 3f9e8b63086ff4f29e86a68821ed649b39bee905 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 11 Dec 2025 22:05:55 +0100
Subject: [PATCH 046/114] fix waitcnt ops

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index bb2fa03c1..499d99035 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -739,7 +739,8 @@ class WaterInsertWaitcntPass
       rewriter.setInsertionPoint(operation);
       amdgpu::MemoryCounterWaitOp::create(
           rewriter, operation->getLoc(), getAttr(req.getLoadCnt()),
-          getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr);
+          getAttr(req.getStoreCnt()), getAttr(req.getDsCnt()), nullptr,
+          nullptr);
     });
   }
 };
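Patch 045 caches the result tokens of pending async operations alongside the operations themselves, so that value-based dependency lookups hit a set instead of rescanning every pending list; patch 046 then tracks an upstream change that added one more optional counter operand to `amdgpu.memory_counter_wait` (the extra trailing `nullptr`). For orientation, a minimal sketch of the IR shape the insert-waitcnt pass produces, following the conventions of the triple_buffering test above (the `ds(1)` form is taken from that test; the exact set of optional counters depends on the upstream AMDGPU dialect version):

    // Async copy into LDS, then a dependent read: the inserted wait allows at
    // most one DS operation to remain outstanding.
    memref.copy %src, %buff : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
    amdgpu.memory_counter_wait ds(1)
    %v = vector.load %buff[%i] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>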
From 3f23651ee3d0768012661d4e468fc9ef04bfd7c4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 11 Dec 2025 22:06:13 +0100
Subject: [PATCH 047/114] more global ops tests

Signed-off-by: Ivan Butygin
---
 water/test/Transforms/lower-memory-ops.mlir | 53 ++++++++++++++++++++-
 1 file changed, 51 insertions(+), 2 deletions(-)

diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 69f1617ed..76c2b00dd 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -30,13 +30,62 @@ func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) {
   return
 }
 
-// CHECK-LABEL: func.func @vector_load_2xf32
-func.func @vector_load_2xf32(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> {
+// CHECK-LABEL: func.func @vector_load_b32
+func.func @vector_load_b32(%memref: memref<1024xf32>, %offset: index) -> vector<1xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<1xf32>
+  return %result : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @vector_load_b64
+func.func @vector_load_b64(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> {
   // CHECK: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v"
   %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32>
   return %result : vector<2xf32>
 }
 
+// CHECK-LABEL: func.func @vector_load_b96
+func.func @vector_load_b96(%memref: memref<1024xf32>, %offset: index) -> vector<3xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v"
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<3xf32>
+  return %result : vector<3xf32>
+}
+
+// CHECK-LABEL: func.func @vector_load_b128
+func.func @vector_load_b128(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
+  %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_store_b32
+func.func @vector_store_b32(%memref: memref<1024xf32>, %offset: index, %data: vector<1xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_b64
+func.func @vector_store_b64(%memref: memref<1024xf32>, %offset: index, %data: vector<2xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v"
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<2xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_b96
+func.func @vector_store_b96(%memref: memref<1024xf32>, %offset: index, %data: vector<3xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v"
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_b128
+func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+  return
+}
+
 // CHECK-LABEL: func.func @load_store_sequence
 func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) {
   // Test lowering of load/store sequence

From 35bc8050ba93ef3635ad090733814141c69178bb Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 11 Dec 2025 22:44:24 +0100
Subject: [PATCH 048/114] buffer ops

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 191 ++++++++++++++++---
 water/test/Transforms/lower-memory-ops.mlir  |  80 ++++++++
 2 files changed, 241 insertions(+), 30 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index a2d2a8cf8..7e93bfdc0 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -6,6 +6,7 @@
 
 #include "water/Transforms/Passes.h"
 
+#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -99,9 +100,94 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc,
   return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr);
 }
 
+/// Get buffer instruction suffix based on bit width
+static FailureOr<StringRef> getBufferSuffix(unsigned bitWidth) {
+  switch (bitWidth) {
+  case 32:
+    return StringRef("dword");
+  case 64:
+    return StringRef("dwordx2");
+  case 96:
+    return StringRef("dwordx3");
+  case 128:
+    return StringRef("dwordx4");
+  default:
+    return failure();
+  }
+}
+
+/// Extract buffer descriptor pointer from a fat_raw_buffer memref
+static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc,
+                                     Value memref) {
+  // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes...,
+  // strides...}
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  auto i64Type = rewriter.getI64Type();
+  SmallVector<Type> descriptorFields{ptrType, ptrType,
+                                     i64Type}; // allocated, aligned, offset
+  // Add sizes and strides for each dimension
+  for (int64_t i = 0; i < memrefType.getRank(); ++i)
+    descriptorFields.push_back(i64Type); // size
+
+  for (int64_t i = 0; i < memrefType.getRank(); ++i)
+    descriptorFields.push_back(i64Type); // stride
+
+  auto memrefDescType =
+      LLVM::LLVMStructType::getLiteral(rewriter.getContext(), descriptorFields);
+
+  Value memrefDescVal =
+      UnrealizedConversionCastOp::create(rewriter, loc, memrefDescType, memref)
+          .getResult(0);
+
+  // Use MemRefDescriptor to extract aligned pointer
+  MemRefDescriptor memrefDesc(memrefDescVal);
+  return memrefDesc.alignedPtr(rewriter, loc);
+}
+
+/// Lower vector.load to AMDGPU buffer load inline assembly
+static LogicalResult lowerVectorLoadBuffer(vector::LoadOp loadOp,
+                                           IRRewriter &rewriter) {
+  auto vectorType = loadOp.getVectorType();
+  unsigned bitWidth =
+      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+  FailureOr<StringRef> suffix = getBufferSuffix(bitWidth);
+  if (failed(suffix))
+    return loadOp.emitError("unsupported vector buffer load bit width: ")
+           << bitWidth;
+
+  Location loc = loadOp.getLoc();
+  rewriter.setInsertionPoint(loadOp);
+
+  // Build inline assembly: "buffer_load_dwordx4 $0, $1, $2, 0 offen"
+  std::string asmStr =
+      ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str();
+
+  // Constraints: "=v" for output (VGPR), "v" for offset (VGPR), "s" for
+  // descriptor (SGPR[4])
+  StringRef constraints = "=v,v,s";
+
+  // Compute offset in bytes
+  Value addr =
+      computeMemrefAddress(rewriter, loc, loadOp.getBase(), loadOp.getIndices(),
+                           vectorType.getElementTypeBitWidth());
+
+  // Extract buffer descriptor pointer from memref
+  Value bufferDesc = extractBufferDescriptor(rewriter, loc, loadOp.getBase());
+
+  // Create inline assembly operation
+  auto asmOp =
+      createInlineAsm(rewriter, loc, vectorType, ValueRange{addr, bufferDesc},
+                      asmStr, constraints, /*hasSideEffects=*/true);
+
+  rewriter.replaceOp(loadOp, asmOp.getResult(0));
+  return success();
+}
+
 /// Lower vector.load to LLVM inline assembly (global_load_*)
-static LogicalResult lowerVectorLoad(vector::LoadOp loadOp,
-                                     IRRewriter &rewriter) {
+static LogicalResult lowerVectorLoadGlobal(vector::LoadOp loadOp,
+                                           IRRewriter &rewriter) {
   auto vectorType = loadOp.getVectorType();
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
@@ -133,9 +219,49 @@ static LogicalResult lowerVectorLoad(vector::LoadOp loadOp,
   return success();
 }
 
+/// Lower vector.store to AMDGPU buffer store inline assembly
+static LogicalResult lowerVectorStoreBuffer(vector::StoreOp storeOp,
+                                            IRRewriter &rewriter) {
+  auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
+  unsigned bitWidth =
+      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+  FailureOr<StringRef> suffix = getBufferSuffix(bitWidth);
+  if (failed(suffix))
+    return storeOp.emitError("unsupported vector buffer store bit width: ")
+           << bitWidth;
+
+  Location loc = storeOp.getLoc();
+  rewriter.setInsertionPoint(storeOp);
+
+  // Build inline assembly: "buffer_store_dwordx4 $0, $1, $2, 0 offen"
+  std::string asmStr =
+      ("buffer_store_" + *suffix + " $0, $1, $2, 0 offen").str();
+
+  // Constraints: "v" for data (VGPR), "v" for offset (VGPR), "s" for descriptor
+  // (SGPR[4])
+  StringRef constraints = "v,v,s";
+
+  // Compute offset in bytes
+  Value addr = computeMemrefAddress(rewriter, loc, storeOp.getBase(),
+                                    storeOp.getIndices(),
+                                    vectorType.getElementTypeBitWidth());
+
+  // Extract buffer descriptor pointer from memref
+  Value bufferDesc = extractBufferDescriptor(rewriter, loc, storeOp.getBase());
+
+  // Create inline assembly operation (no result for store)
+  createInlineAsm(rewriter, loc, TypeRange{},
+                  ValueRange{storeOp.getValueToStore(), addr, bufferDesc},
+                  asmStr, constraints, /*hasSideEffects=*/true);
+
+  rewriter.eraseOp(storeOp);
+  return success();
+}
+
 /// Lower vector.store to LLVM inline assembly (global_store_*)
-static LogicalResult lowerVectorStore(vector::StoreOp storeOp,
-                                      IRRewriter &rewriter) {
+static LogicalResult lowerVectorStoreGlobal(vector::StoreOp storeOp,
+                                            IRRewriter &rewriter) {
   auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
   unsigned bitWidth =
       vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
@@ -169,29 +295,26 @@ static LogicalResult lowerVectorStore(vector::StoreOp storeOp,
   return success();
 }
 
-/// Wrapper functions for operation lowering
-static LogicalResult lowerLoadOp(Operation *op, IRRewriter &rewriter) {
-  return lowerVectorLoad(cast<vector::LoadOp>(op), rewriter);
-}
+/// Check if a memref uses AMDGPU fat_raw_buffer address space
+static bool usesBufferAddressSpace(Value memref) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto memorySpace = memrefType.getMemorySpace();
 
-static LogicalResult lowerStoreOp(Operation *op, IRRewriter &rewriter) {
-  return lowerVectorStore(cast<vector::StoreOp>(op), rewriter);
-}
+  if (!memorySpace)
+    return false;
 
-/// Operation lowering handler entry
-struct OpLoweringHandler {
-  TypeID typeID;
-  LogicalResult (*lowerFn)(Operation *, IRRewriter &);
-};
+  // Check for #amdgpu.address_space<fat_raw_buffer> attribute
+  if (auto enumAttr = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace)) {
+    return enumAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer;
+  }
 
-/// Table of lowering handlers for different operation types
-static const OpLoweringHandler loweringHandlers[] = {
-    {TypeID::get<vector::LoadOp>(), lowerLoadOp},
-    {TypeID::get<vector::StoreOp>(), lowerStoreOp},
-};
+  return false;
+}
 
-/// Pass that lowers high-level memory operations to LLVM inline assembly
-/// for AMDGPU global memory instructions.
+/// Pass that lowers high-level memory operations to AMDGPU memory instructions.
+/// Uses buffer operations for memrefs with
+/// #amdgpu.address_space<fat_raw_buffer>, and global operations for all other
+/// memrefs.
 class WaterLowerMemoryOpsPass
     : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
 public:
@@ -199,13 +322,21 @@ class WaterLowerMemoryOpsPass
     IRRewriter rewriter(&getContext());
 
     auto walkFn = [&](Operation *op) {
-      TypeID opTypeID = op->getName().getTypeID();
-      for (const auto &handler : loweringHandlers) {
-        if (handler.typeID == opTypeID) {
-          if (failed(handler.lowerFn(op, rewriter)))
-            return WalkResult::interrupt();
-          return WalkResult::advance();
-        }
+      if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+        LogicalResult result = usesBufferAddressSpace(loadOp.getBase())
+                                   ? lowerVectorLoadBuffer(loadOp, rewriter)
+                                   : lowerVectorLoadGlobal(loadOp, rewriter);
+        if (failed(result))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      }
+      if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+        LogicalResult result = usesBufferAddressSpace(storeOp.getBase())
+                                   ? lowerVectorStoreBuffer(storeOp, rewriter)
+                                   : lowerVectorStoreGlobal(storeOp, rewriter);
+        if (failed(result))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
       }
       return WalkResult::advance();
     };
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 76c2b00dd..65cdfe7b6 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -99,3 +99,83 @@ func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) {
   // CHECK: return
   return
 }
+
+// -----
+// Buffer operations tests
+
+// CHECK-LABEL: func.func @buffer_load_b32
+func.func @buffer_load_b32(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) -> vector<1xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s"
+  %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return %result : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @buffer_load_b64
+func.func @buffer_load_b64(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) -> vector<2xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s"
+  %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2xf32>
+  return %result : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @buffer_load_b96
+func.func @buffer_load_b96(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) -> vector<3xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s"
+  %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<3xf32>
+  return %result : vector<3xf32>
+}
+
+// CHECK-LABEL: func.func @buffer_load_b128
+func.func @buffer_load_b128(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) -> vector<4xf32> {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s"
+  %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @buffer_store_b32
+func.func @buffer_store_b32(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index, %data: vector<1xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @buffer_store_b64
+func.func @buffer_store_b64(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index, %data: vector<2xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @buffer_store_b96
+func.func @buffer_store_b96(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index, %data: vector<3xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @buffer_store_b128
+func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index, %data: vector<4xf32>) {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return
+}
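For readers wondering where memrefs in the fat_raw_buffer address space come from: upstream MLIR's AMDGPU dialect provides a cast that rewraps an ordinary memref into a buffer-descriptor-backed one, which this pass then recognizes via usesBufferAddressSpace. A sketch, with the op name and syntax taken from the upstream dialect (treat as an assumption about the exact spelling, not part of this series):

    %buf = amdgpu.fat_raw_buffer_cast %src : memref<1024xf32> to memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>
    // Accesses through %buf now take the buffer_load/buffer_store path below.
    %v = vector.load %buf[%i] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>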
vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @mixed_global_and_buffer +func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) { + // Load from global memory (should use global_load) + // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> + + // Store to buffer memory (should use buffer_store) + // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + vector.store %global_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Load from buffer memory (should use buffer_load) + // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> + + // Store to global memory (should use global_store) + // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + vector.store %buffer_data, %global[%offset] : memref<1024xf32>, vector<4xf32> + + return +} From 662c8c314df10d133f2e95ce809c507951b115b8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 11 Dec 2025 22:55:58 +0100 Subject: [PATCH 049/114] buffer offsets Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 71 ++++++++++++-------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 7e93bfdc0..96ac2a80f 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -54,17 +54,13 @@ static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, /*operand_attrs=*/ArrayAttr{}); } -/// Compute the final address for a memref access with indices -static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, - Value memref, ValueRange indices, - unsigned elementBitWidth) { - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); - - // Extract strided metadata to get base pointer, offset, sizes, and strides +/// Compute byte offset as i64 for a memref access with indices +static Value computeMemrefByteOffsetI64(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + // Extract strided metadata to get offset and strides auto metadataOp = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); - Value basePtr = metadataOp.getBaseBuffer(); Value offset = metadataOp.getOffset(); // Compute linear index from multidimensional indices @@ -78,20 +74,39 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, arith::IntegerOverflowFlags::nsw); } - // Convert base pointer to i64 - Value basePtrInt = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); - basePtrInt = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtrInt); - - // Convert linear index to i64 and scale by element size + // Convert linear index to byte offset unsigned elementBytes = elementBitWidth / 8; Value elementSize = arith::ConstantIndexOp::create(rewriter, loc, elementBytes); Value byteOffset = arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, arith::IntegerOverflowFlags::nsw); - Value byteOffsetI64 = - 
arith::IndexCastOp::create(rewriter, loc, i64Type, byteOffset); + + return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), + byteOffset); +} + +/// Compute the final address for a memref access with indices (for global +/// operations) +static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto i64Type = rewriter.getI64Type(); + + // Extract base pointer + auto metadataOp = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + Value basePtr = metadataOp.getBaseBuffer(); + + // Convert base pointer to i64 + Value basePtrInt = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); + basePtrInt = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtrInt); + + // Compute byte offset + Value byteOffsetI64 = computeMemrefByteOffsetI64(rewriter, loc, memref, + indices, elementBitWidth); // Add byte offset to base pointer Value finalAddr = @@ -168,17 +183,18 @@ static LogicalResult lowerVectorLoadBuffer(vector::LoadOp loadOp, // descriptor (SGPR[4]) StringRef constraints = "=v,v,s"; - // Compute offset in bytes - Value addr = - computeMemrefAddress(rewriter, loc, loadOp.getBase(), loadOp.getIndices(), - vectorType.getElementTypeBitWidth()); + // Compute byte offset as i64 (not full address, since buffer descriptor has + // base) + Value offset = computeMemrefByteOffsetI64( + rewriter, loc, loadOp.getBase(), loadOp.getIndices(), + vectorType.getElementTypeBitWidth()); // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, loadOp.getBase()); // Create inline assembly operation auto asmOp = - createInlineAsm(rewriter, loc, vectorType, ValueRange{addr, bufferDesc}, + createInlineAsm(rewriter, loc, vectorType, ValueRange{offset, bufferDesc}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.replaceOp(loadOp, asmOp.getResult(0)); @@ -242,17 +258,18 @@ static LogicalResult lowerVectorStoreBuffer(vector::StoreOp storeOp, // (SGPR[4]) StringRef constraints = "v,v,s"; - // Compute offset in bytes - Value addr = computeMemrefAddress(rewriter, loc, storeOp.getBase(), - storeOp.getIndices(), - vectorType.getElementTypeBitWidth()); + // Compute byte offset as i64 (not full address, since buffer descriptor has + // base) + Value offset = computeMemrefByteOffsetI64( + rewriter, loc, storeOp.getBase(), storeOp.getIndices(), + vectorType.getElementTypeBitWidth()); // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, storeOp.getBase()); // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{storeOp.getValueToStore(), addr, bufferDesc}, + ValueRange{storeOp.getValueToStore(), offset, bufferDesc}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); From f56a94dcb5e1c621c92a11737db66e00158743ae Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 11 Dec 2025 23:13:26 +0100 Subject: [PATCH 050/114] ds ops Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 1 + water/lib/Transforms/WaterLowerMemoryOps.cpp | 122 +++++++++++++++++-- water/test/Transforms/lower-memory-ops.mlir | 89 ++++++++++++++ 3 files changed, 202 insertions(+), 10 deletions(-) diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index e213fa8e5..92f66c812 100644 --- 
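The split above is the key point of patch 049: global accesses still materialize a full 64-bit pointer, while buffer accesses now pass only a byte offset, since the buffer descriptor already carries the base address. Both paths share the same offset arithmetic, essentially (memref offset + sum of index times stride) times element bytes. A sketch of the IR computeMemrefByteOffsetI64 emits for a rank-1 f32 memref (value names are illustrative, not from the patch):

    %base, %off, %sz, %stride = memref.extract_strided_metadata %m : memref<1024xf32> -> memref<f32>, index, index, index
    %scaled = arith.muli %idx, %stride : index
    %lin = arith.addi %off, %scaled : index
    %c4 = arith.constant 4 : index       // f32 element size in bytes
    %bytes = arith.muli %lin, %c4 : index
    %bytes64 = arith.index_cast %bytes : index to i64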
From f56a94dcb5e1c621c92a11737db66e00158743ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 11 Dec 2025 23:13:26 +0100
Subject: [PATCH 050/114] ds ops

Signed-off-by: Ivan Butygin
---
 water/include/water/Transforms/Passes.td     |   1 +
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 122 +++++++++++++++--
 water/test/Transforms/lower-memory-ops.mlir  |  89 ++++++++++
 3 files changed, 202 insertions(+), 10 deletions(-)

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index e213fa8e5..92f66c812 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -186,6 +186,7 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> {
   }];
   let dependentDialects = [
     "::mlir::amdgpu::AMDGPUDialect",
+    "::mlir::gpu::GPUDialect",
     "::mlir::LLVM::LLVMDialect",
     "::mlir::memref::MemRefDialect",
     "::mlir::vector::VectorDialect",
diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 96ac2a80f..b964827d4 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -312,6 +313,83 @@ static LogicalResult lowerVectorStoreGlobal(vector::StoreOp storeOp,
   return success();
 }
 
+/// Lower vector.load to AMDGPU DS load inline assembly
+static LogicalResult lowerVectorLoadDS(vector::LoadOp loadOp,
+                                       IRRewriter &rewriter) {
+  auto vectorType = loadOp.getVectorType();
+  unsigned bitWidth =
+      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+  FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
+  if (failed(suffix))
+    return loadOp.emitError("unsupported vector DS load bit width: ")
+           << bitWidth;
+
+  Location loc = loadOp.getLoc();
+  rewriter.setInsertionPoint(loadOp);
+
+  // Build inline assembly: "ds_read_b32 $0, $1"
+  std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str();
+
+  // Constraints: "=v" for output (VGPR), "v" for address (VGPR)
+  StringRef constraints = "=v,v";
+
+  // Compute byte offset as i64
+  Value offset = computeMemrefByteOffsetI64(
+      rewriter, loc, loadOp.getBase(), loadOp.getIndices(),
+      vectorType.getElementTypeBitWidth());
+
+  // DS operations use 32-bit addresses
+  Value offset32 =
+      arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset);
+
+  // Create inline assembly operation
+  auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{offset32},
+                               asmStr, constraints, /*hasSideEffects=*/true);
+
+  rewriter.replaceOp(loadOp, asmOp.getResult(0));
+  return success();
+}
+
+/// Lower vector.store to AMDGPU DS store inline assembly
+static LogicalResult lowerVectorStoreDS(vector::StoreOp storeOp,
+                                        IRRewriter &rewriter) {
+  auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
+  unsigned bitWidth =
+      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+
+  FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
+  if (failed(suffix))
+    return storeOp.emitError("unsupported vector DS store bit width: ")
+           << bitWidth;
+
+  Location loc = storeOp.getLoc();
+  rewriter.setInsertionPoint(storeOp);
+
+  // Build inline assembly: "ds_write_b32 $0, $1"
+  std::string asmStr = ("ds_write_" + *suffix + " $0, $1").str();
+
+  // Constraints: "v" for address (VGPR), "v" for data (VGPR)
+  StringRef constraints = "v,v";
+
+  // Compute byte offset as i64
+  Value offset = computeMemrefByteOffsetI64(
+      rewriter, loc, storeOp.getBase(), storeOp.getIndices(),
+      vectorType.getElementTypeBitWidth());
+
+  // DS operations use 32-bit addresses
+  Value offset32 =
+      arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset);
+
+  // Create inline assembly operation (no result for store)
+  createInlineAsm(rewriter, loc, TypeRange{},
+                  ValueRange{offset32, storeOp.getValueToStore()}, asmStr,
+                  constraints, /*hasSideEffects=*/true);
+
+  rewriter.eraseOp(storeOp);
+  return success();
+}
+
 /// Check if a memref uses AMDGPU fat_raw_buffer address space
 static bool usesBufferAddressSpace(Value memref) {
   auto memrefType = cast<MemRefType>(memref.getType());
@@ -321,17 +399,31 @@ static bool usesBufferAddressSpace(Value memref) {
     return false;
 
   // Check for #amdgpu.address_space<fat_raw_buffer> attribute
-  if (auto enumAttr = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace)) {
+  if (auto enumAttr = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
     return enumAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer;
-  }
+
+  return false;
+}
+
+/// Check if a memref uses workgroup (LDS) address space
+static bool usesWorkgroupAddressSpace(Value memref) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto memorySpace = memrefType.getMemorySpace();
+
+  if (!memorySpace)
+    return false;
+
+  // Check for #gpu.address_space<workgroup> attribute
+  if (auto enumAttr = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return enumAttr.getValue() == gpu::AddressSpace::Workgroup;
 
   return false;
 }
 
 /// Pass that lowers high-level memory operations to AMDGPU memory instructions.
 /// Uses buffer operations for memrefs with
-/// #amdgpu.address_space<fat_raw_buffer>, and global operations for all other
-/// memrefs.
+/// #amdgpu.address_space<fat_raw_buffer>, DS operations for memrefs with
+/// #gpu.address_space<workgroup>, and global operations for all other memrefs.
 class WaterLowerMemoryOpsPass
     : public water::impl::WaterLowerMemoryOpsBase<WaterLowerMemoryOpsPass> {
 public:
@@ -340,17 +432,27 @@ class WaterLowerMemoryOpsPass
 
     auto walkFn = [&](Operation *op) {
       if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
-        LogicalResult result = usesBufferAddressSpace(loadOp.getBase())
-                                   ? lowerVectorLoadBuffer(loadOp, rewriter)
-                                   : lowerVectorLoadGlobal(loadOp, rewriter);
+        LogicalResult result = success();
+        if (usesBufferAddressSpace(loadOp.getBase())) {
+          result = lowerVectorLoadBuffer(loadOp, rewriter);
+        } else if (usesWorkgroupAddressSpace(loadOp.getBase())) {
+          result = lowerVectorLoadDS(loadOp, rewriter);
+        } else {
+          result = lowerVectorLoadGlobal(loadOp, rewriter);
+        }
         if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();
      }
      if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
-        LogicalResult result = usesBufferAddressSpace(storeOp.getBase())
-                                   ? lowerVectorStoreBuffer(storeOp, rewriter)
-                                   : lowerVectorStoreGlobal(storeOp, rewriter);
+        LogicalResult result = success();
+        if (usesBufferAddressSpace(storeOp.getBase())) {
+          result = lowerVectorStoreBuffer(storeOp, rewriter);
+        } else if (usesWorkgroupAddressSpace(storeOp.getBase())) {
+          result = lowerVectorStoreDS(storeOp, rewriter);
+        } else {
+          result = lowerVectorStoreGlobal(storeOp, rewriter);
+        }
        if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 65cdfe7b6..a23b043a0 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -179,3 +179,92 @@ func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) {
 
   return
 }
+// -----
+// DS operations tests
+
+// CHECK-LABEL: func.func @ds_load_b32
+func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> vector<1xf32> {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v"
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+  return %result : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @ds_load_b64
+func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> vector<2xf32> {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b64 $0, $1", "=v,v"
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  return %result : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @ds_load_b96
+func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> vector<3xf32> {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b96 $0, $1", "=v,v"
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<3xf32>
+  return %result : vector<3xf32>
+}
+
+// CHECK-LABEL: func.func @ds_load_b128
+func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> vector<4xf32> {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v"
+  %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @ds_store_b32
+func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<1xf32>) {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b64
+func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<2xf32>) {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b64 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b96
+func.func @ds_store_b96(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<3xf32>) {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b96 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @ds_store_b128
+func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: vector<4xf32>) {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v"
+  vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @mixed_global_buffer_and_ds
+func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) {
+  // Load from global (should use global_load)
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
+  %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Store to LDS (should use ds_write)
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v"
+  vector.store %global_data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Load from LDS (should use ds_read)
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v"
+  %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+
+  // Store to buffer (should use buffer_store)
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s"
+  vector.store %lds_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+  return
+}
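One detail worth calling out in patch 050: LDS is a 32-bit address space, so the shared byte-offset computation (which yields i64) is truncated to i32 before feeding the ds_read/ds_write inline assembly; that is the arith.trunci the DS tests check for. A sketch of input IR that takes this path, mirroring the tests above:

    // Allocation in workgroup (LDS) memory routes accesses to the DS lowering.
    %lds = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
    %v = vector.load %lds[%i] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>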
From 39ecaa028b89b222637b67ec9580b055a332685c Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 11 Dec 2025 23:32:35 +0100
Subject: [PATCH 051/114] nicer

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 37 +++++++++++---------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index b964827d4..d7a839853 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -430,29 +430,34 @@ class WaterLowerMemoryOpsPass
   void runOnOperation() override {
     IRRewriter rewriter(&getContext());
 
+    // Helper to dispatch to the appropriate lowering function based on address
+    // space
+    auto lowerMemoryOp = [&](Value base, auto lowerBuffer, auto lowerWorkgroup,
+                             auto lowerGlobal) -> LogicalResult {
+      if (usesBufferAddressSpace(base))
+        return lowerBuffer();
+      if (usesWorkgroupAddressSpace(base))
+        return lowerWorkgroup();
+      return lowerGlobal();
+    };
+
     auto walkFn = [&](Operation *op) {
       if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
-        LogicalResult result = success();
-        if (usesBufferAddressSpace(loadOp.getBase())) {
-          result = lowerVectorLoadBuffer(loadOp, rewriter);
-        } else if (usesWorkgroupAddressSpace(loadOp.getBase())) {
-          result = lowerVectorLoadDS(loadOp, rewriter);
-        } else {
-          result = lowerVectorLoadGlobal(loadOp, rewriter);
-        }
+        LogicalResult result = lowerMemoryOp(
+            loadOp.getBase(),
+            [&]() { return lowerVectorLoadBuffer(loadOp, rewriter); },
+            [&]() { return lowerVectorLoadDS(loadOp, rewriter); },
+            [&]() { return lowerVectorLoadGlobal(loadOp, rewriter); });
         if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();
      }
      if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
-        LogicalResult result = success();
-        if (usesBufferAddressSpace(storeOp.getBase())) {
-          result = lowerVectorStoreBuffer(storeOp, rewriter);
-        } else if (usesWorkgroupAddressSpace(storeOp.getBase())) {
-          result = lowerVectorStoreDS(storeOp, rewriter);
-        } else {
-          result = lowerVectorStoreGlobal(storeOp, rewriter);
-        }
+        LogicalResult result = lowerMemoryOp(
+            storeOp.getBase(),
+            [&]() { return lowerVectorStoreBuffer(storeOp, rewriter); },
+            [&]() { return lowerVectorStoreDS(storeOp, rewriter); },
+            [&]() { return lowerVectorStoreGlobal(storeOp, rewriter); });
        if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();

From 351f78f0c5a5a379f3c5db49dce57e8e2769751c Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 12 Dec 2025 15:38:28 +0100
Subject: [PATCH 052/114] memref.load/store

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 194 ++++++++++++-------
 water/test/Transforms/lower-memory-ops.mlir  |  82 ++++++++
 2 files changed, 205 insertions(+), 71 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index d7a839853..3c02b2845 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -161,17 +161,44 @@ static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc,
   return memrefDesc.alignedPtr(rewriter, loc);
 }
 
-/// Lower vector.load to AMDGPU buffer load inline assembly
-static LogicalResult lowerVectorLoadBuffer(vector::LoadOp loadOp,
-                                           IRRewriter &rewriter) {
-  auto vectorType = loadOp.getVectorType();
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Helper to get memref, result type, and bit width from load operation
+template <typename LoadOpTy>
+static std::tuple<Value, Type, unsigned> getLoadOpInfo(LoadOpTy loadOp) {
+  if constexpr (std::is_same_v<LoadOpTy, vector::LoadOp>) {
+    auto vectorType = loadOp.getVectorType();
+    unsigned bitWidth =
+        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+    return {loadOp.getBase(), vectorType, bitWidth};
+  } else {
+    auto elementType = loadOp.getResult().getType();
+    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    return {loadOp.getMemRef(), elementType, bitWidth};
+  }
+}
+
+/// Helper to get memref, value type, and bit width from store operation
+template <typename StoreOpTy>
+static std::tuple<Value, Type, unsigned> getStoreOpInfo(StoreOpTy storeOp) {
+  if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>) {
+    auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
+    unsigned bitWidth =
+        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+    return {storeOp.getBase(), vectorType, bitWidth};
+  } else {
+    auto elementType = storeOp.getValueToStore().getType();
+    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    return {storeOp.getMemRef(), elementType, bitWidth};
+  }
+}
+
+/// Lower vector/scalar load to AMDGPU buffer load inline assembly
+template <typename LoadOpTy>
+static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) {
+  auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp);
 
   FailureOr<StringRef> suffix = getBufferSuffix(bitWidth);
   if (failed(suffix))
-    return loadOp.emitError("unsupported vector buffer load bit width: ")
-           << bitWidth;
+    return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth;
 
   Location loc = loadOp.getLoc();
   rewriter.setInsertionPoint(loadOp);
@@ -186,32 +213,33 @@ static LogicalResult lowerVectorLoadBuffer(vector::LoadOp loadOp,
 
   // Compute byte offset as i64 (not full address, since buffer descriptor has
   // base)
+  unsigned elementBitWidth =
+      std::is_same_v<LoadOpTy, vector::LoadOp>
+          ? cast<VectorType>(resultType).getElementTypeBitWidth()
+          : bitWidth;
   Value offset = computeMemrefByteOffsetI64(
-      rewriter, loc, loadOp.getBase(), loadOp.getIndices(),
-      vectorType.getElementTypeBitWidth());
+      rewriter, loc, memref, loadOp.getIndices(), elementBitWidth);
 
   // Extract buffer descriptor pointer from memref
-  Value bufferDesc = extractBufferDescriptor(rewriter, loc, loadOp.getBase());
+  Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref);
 
   // Create inline assembly operation
   auto asmOp =
-      createInlineAsm(rewriter, loc, vectorType, ValueRange{offset, bufferDesc},
+      createInlineAsm(rewriter, loc, resultType, ValueRange{offset, bufferDesc},
                       asmStr, constraints, /*hasSideEffects=*/true);
 
   rewriter.replaceOp(loadOp, asmOp.getResult(0));
   return success();
 }
 
-/// Lower vector.load to LLVM inline assembly (global_load_*)
-static LogicalResult lowerVectorLoadGlobal(vector::LoadOp loadOp,
-                                           IRRewriter &rewriter) {
-  auto vectorType = loadOp.getVectorType();
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Lower vector/scalar load to LLVM inline assembly (global_load_*)
+template <typename LoadOpTy>
+static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) {
+  auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp);
 
   FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
   if (failed(suffix))
-    return loadOp.emitError("unsupported vector load bit width: ") << bitWidth;
+    return loadOp.emitError("unsupported load bit width: ") << bitWidth;
 
   Location loc = loadOp.getLoc();
 
@@ -224,28 +252,29 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) {
   rewriter.setInsertionPoint(loadOp);
 
   // Compute the final address
-  Value addr =
-      computeMemrefAddress(rewriter, loc, loadOp.getBase(), loadOp.getIndices(),
-                           vectorType.getElementTypeBitWidth());
+  unsigned elementBitWidth =
+      std::is_same_v<LoadOpTy, vector::LoadOp>
+          ? cast<VectorType>(resultType).getElementTypeBitWidth()
+          : bitWidth;
+  Value addr = computeMemrefAddress(rewriter, loc, memref, loadOp.getIndices(),
+                                    elementBitWidth);
 
   // Create the inline assembly operation
-  auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{addr},
+  auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr},
                                asmStr, constraints, /*hasSideEffects=*/true);
 
   rewriter.replaceOp(loadOp, asmOp.getResult(0));
   return success();
 }
 
-/// Lower vector.store to AMDGPU buffer store inline assembly
-static LogicalResult lowerVectorStoreBuffer(vector::StoreOp storeOp,
-                                            IRRewriter &rewriter) {
-  auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Lower vector/scalar store to AMDGPU buffer store inline assembly
+template <typename StoreOpTy>
+static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) {
+  auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp);
 
   FailureOr<StringRef> suffix = getBufferSuffix(bitWidth);
   if (failed(suffix))
-    return storeOp.emitError("unsupported vector buffer store bit width: ")
+    return storeOp.emitError("unsupported buffer store bit width: ")
            << bitWidth;
 
   Location loc = storeOp.getLoc();
@@ -261,12 +290,15 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) {
 
   // Compute byte offset as i64 (not full address, since buffer descriptor has
   // base)
+  unsigned elementBitWidth =
+      std::is_same_v<StoreOpTy, vector::StoreOp>
+          ? cast<VectorType>(valueType).getElementTypeBitWidth()
+          : bitWidth;
   Value offset = computeMemrefByteOffsetI64(
-      rewriter, loc, storeOp.getBase(), storeOp.getIndices(),
-      vectorType.getElementTypeBitWidth());
+      rewriter, loc, memref, storeOp.getIndices(), elementBitWidth);
 
   // Extract buffer descriptor pointer from memref
-  Value bufferDesc = extractBufferDescriptor(rewriter, loc, storeOp.getBase());
+  Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref);
 
   // Create inline assembly operation (no result for store)
   createInlineAsm(rewriter, loc, TypeRange{},
@@ -277,17 +309,14 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) {
   return success();
 }
 
-/// Lower vector.store to LLVM inline assembly (global_store_*)
-static LogicalResult lowerVectorStoreGlobal(vector::StoreOp storeOp,
-                                            IRRewriter &rewriter) {
-  auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Lower vector/scalar store to LLVM inline assembly (global_store_*)
+template <typename StoreOpTy>
+static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) {
+  auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp);
 
   FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
   if (failed(suffix))
-    return storeOp.emitError("unsupported vector store bit width: ")
-           << bitWidth;
+    return storeOp.emitError("unsupported store bit width: ") << bitWidth;
 
   Location loc = storeOp.getLoc();
 
@@ -300,9 +329,12 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) {
   rewriter.setInsertionPoint(storeOp);
 
   // Compute the final address
-  Value addr = computeMemrefAddress(rewriter, loc, storeOp.getBase(),
-                                    storeOp.getIndices(),
-                                    vectorType.getElementTypeBitWidth());
+  unsigned elementBitWidth =
+      std::is_same_v<StoreOpTy, vector::StoreOp>
+          ? cast<VectorType>(valueType).getElementTypeBitWidth()
+          : bitWidth;
+  Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(),
+                                    elementBitWidth);
 
   // Create the inline assembly operation (no result for store)
   createInlineAsm(rewriter, loc, TypeRange{},
@@ -313,17 +345,14 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) {
   return success();
 }
 
-/// Lower vector.load to AMDGPU DS load inline assembly
-static LogicalResult lowerVectorLoadDS(vector::LoadOp loadOp,
-                                       IRRewriter &rewriter) {
-  auto vectorType = loadOp.getVectorType();
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Lower vector/scalar load to AMDGPU DS load inline assembly
+template <typename LoadOpTy>
+static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) {
+  auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp);
 
   FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
   if (failed(suffix))
-    return loadOp.emitError("unsupported vector DS load bit width: ")
-           << bitWidth;
+    return loadOp.emitError("unsupported DS load bit width: ") << bitWidth;
 
   Location loc = loadOp.getLoc();
   rewriter.setInsertionPoint(loadOp);
@@ -335,33 +364,33 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) {
   StringRef constraints = "=v,v";
 
   // Compute byte offset as i64
+  unsigned elementBitWidth =
+      std::is_same_v<LoadOpTy, vector::LoadOp>
+          ? cast<VectorType>(resultType).getElementTypeBitWidth()
+          : bitWidth;
   Value offset = computeMemrefByteOffsetI64(
-      rewriter, loc, loadOp.getBase(), loadOp.getIndices(),
-      vectorType.getElementTypeBitWidth());
+      rewriter, loc, memref, loadOp.getIndices(), elementBitWidth);
 
   // DS operations use 32-bit addresses
   Value offset32 =
      arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset);
 
   // Create inline assembly operation
-  auto asmOp = createInlineAsm(rewriter, loc, vectorType, ValueRange{offset32},
+  auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset32},
                                asmStr, constraints, /*hasSideEffects=*/true);
 
   rewriter.replaceOp(loadOp, asmOp.getResult(0));
   return success();
 }
 
-/// Lower vector.store to AMDGPU DS store inline assembly
-static LogicalResult lowerVectorStoreDS(vector::StoreOp storeOp,
-                                        IRRewriter &rewriter) {
-  auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
-  unsigned bitWidth =
-      vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+/// Lower vector/scalar store to AMDGPU DS store inline assembly
+template <typename StoreOpTy>
+static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) {
+  auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp);
 
   FailureOr<StringRef> suffix = getSizeSuffix(bitWidth);
   if (failed(suffix))
-    return storeOp.emitError("unsupported vector DS store bit width: ")
-           << bitWidth;
+    return storeOp.emitError("unsupported DS store bit width: ") << bitWidth;
 
   Location loc = storeOp.getLoc();
   rewriter.setInsertionPoint(storeOp);
@@ -373,9 +402,12 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) {
   StringRef constraints = "v,v";
 
   // Compute byte offset as i64
+  unsigned elementBitWidth =
+      std::is_same_v<StoreOpTy, vector::StoreOp>
+          ? cast<VectorType>(valueType).getElementTypeBitWidth()
+          : bitWidth;
   Value offset = computeMemrefByteOffsetI64(
-      rewriter, loc, storeOp.getBase(), storeOp.getIndices(),
-      vectorType.getElementTypeBitWidth());
+      rewriter, loc, memref, storeOp.getIndices(), elementBitWidth);
 
   // DS operations use 32-bit addresses
   Value offset32 =
@@ -445,9 +477,9 @@ class WaterLowerMemoryOpsPass
       if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
         LogicalResult result = lowerMemoryOp(
             loadOp.getBase(),
-            [&]() { return lowerVectorLoadBuffer(loadOp, rewriter); },
-            [&]() { return lowerVectorLoadDS(loadOp, rewriter); },
-            [&]() { return lowerVectorLoadGlobal(loadOp, rewriter); });
+            [&]() { return lowerLoadBuffer(loadOp, rewriter); },
+            [&]() { return lowerLoadDS(loadOp, rewriter); },
+            [&]() { return lowerLoadGlobal(loadOp, rewriter); });
         if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();
@@ -455,9 +487,29 @@ class WaterLowerMemoryOpsPass
      if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
        LogicalResult result = lowerMemoryOp(
            storeOp.getBase(),
-            [&]() { return lowerVectorStoreBuffer(storeOp, rewriter); },
-            [&]() { return lowerVectorStoreDS(storeOp, rewriter); },
-            [&]() { return lowerVectorStoreGlobal(storeOp, rewriter); });
+            [&]() { return lowerStoreBuffer(storeOp, rewriter); },
+            [&]() { return lowerStoreDS(storeOp, rewriter); },
+            [&]() { return lowerStoreGlobal(storeOp, rewriter); });
+        if (failed(result))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      }
+      if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+        LogicalResult result = lowerMemoryOp(
+            loadOp.getMemRef(),
+            [&]() { return lowerLoadBuffer(loadOp, rewriter); },
+            [&]() { return lowerLoadDS(loadOp, rewriter); },
+            [&]() { return lowerLoadGlobal(loadOp, rewriter); });
+        if (failed(result))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      }
+      if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
+        LogicalResult result = lowerMemoryOp(
+            storeOp.getMemRef(),
+            [&]() { return lowerStoreBuffer(storeOp, rewriter); },
+            [&]() { return lowerStoreDS(storeOp, rewriter); },
+            [&]() { return lowerStoreGlobal(storeOp, rewriter); });
        if (failed(result))
          return WalkResult::interrupt();
        return WalkResult::advance();
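With the helpers above, scalar memref.load/memref.store accesses go down exactly the same three-way dispatch as the vector ops; only getLoadOpInfo/getStoreOpInfo differ in how they derive the bit width. A minimal usage sketch in the style of the tests that follow (the water-opt driver name is an assumption about the repo's test tooling, not shown in this series; the pass flag matches the Passes.td registration):

    // RUN: water-opt %s --water-lower-memory-ops | FileCheck %s
    func.func @scalar(%m: memref<1024xf32>, %i: index) -> f32 {
      // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
      %v = memref.load %m[%i] : memref<1024xf32>
      return %v : f32
    }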
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index a23b043a0..f5359142a 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -268,3 +268,85 @@ func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) {
 
   return
 }
+
+// -----
+// Scalar (memref) operations tests
+
+// CHECK-LABEL: func.func @scalar_load_global_f32
+func.func @scalar_load_global_f32(%memref: memref<1024xf32>, %offset: index) -> f32 {
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
+  %result = memref.load %memref[%offset] : memref<1024xf32>
+  return %result : f32
+}
+
+// CHECK-LABEL: func.func @scalar_load_global_f64
+func.func @scalar_load_global_f64(%memref: memref<1024xf64>, %offset: index) -> f64 {
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v"
+  %result = memref.load %memref[%offset] : memref<1024xf64>
+  return %result : f64
+}
+
+// CHECK-LABEL: func.func @scalar_store_global_f32
+func.func @scalar_store_global_f32(%memref: memref<1024xf32>, %offset: index, %data: f32) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  memref.store %data, %memref[%offset] : memref<1024xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @scalar_store_global_f64
+func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %data: f64) {
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v"
+  memref.store %data, %memref[%offset] : memref<1024xf64>
+  return
+}
+
+// CHECK-LABEL: func.func @scalar_load_buffer_f32
+func.func @scalar_load_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index) -> f32 {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s"
+  %result = memref.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>
+  return %result : f32
+}
+
+// CHECK-LABEL: func.func @scalar_store_buffer_f32
+func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>, %offset: index, %data: f32) {
+  // CHECK: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s"
+  memref.store %data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space<fat_raw_buffer>>
+  return
+}
+
+// CHECK-LABEL: func.func @scalar_load_ds_f32
+func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index) -> f32 {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v"
+  %result = memref.load %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>
+  return %result : f32
+}
+
+// CHECK-LABEL: func.func @scalar_store_ds_f32
+func.func @scalar_store_ds_f32(%lds: memref<1024xf32, #gpu.address_space<workgroup>>, %offset: index, %data: f32) {
+  // CHECK: arith.trunci
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v"
+  memref.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @mixed_scalar_and_vector
+func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) {
+  // Scalar load
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v"
+  %scalar = memref.load %memref[%offset] : memref<1024xf32>
+
+  // Vector load
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v"
+  %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  // Scalar store
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  memref.store %scalar, %memref[%offset] : memref<1024xf32>
+
+  // Vector store
+  // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %vector, %memref[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
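Patch 053 below extends the supported widths to 8 and 16 bits. Loads use the unsigned-extending mnemonics (global_load_u8/u16, buffer_load_ubyte/ushort, ds_read_u8/u16) and land in a full 32-bit VGPR, so scalar sub-32-bit results are truncated after the asm; stores are zero-extended to i32 before the asm. Roughly, for an i8 global load (a sketch based on the lowering code in the patch; register placement is left to the "v" constraint):

    %r32 = llvm.inline_asm has_side_effects "global_load_u8 $0, $1, off", "=v,v" %addr : (!llvm.ptr) -> i32
    %r = arith.trunci %r32 : i32 to i8

The store direction mirrors this with arith.extui from i8/i16 to i32 ahead of the inline asm.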
loadOp, IRRewriter &rewriter) { // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); + // For sub-32-bit loads, hardware loads into i32 and we need to truncate + Type asmResultType = resultType; + if (bitWidth < 32 && !isa(resultType)) { + asmResultType = rewriter.getI32Type(); + } + // Create inline assembly operation - auto asmOp = - createInlineAsm(rewriter, loc, resultType, ValueRange{offset, bufferDesc}, - asmStr, constraints, /*hasSideEffects=*/true); + auto asmOp = createInlineAsm(rewriter, loc, asmResultType, + ValueRange{offset, bufferDesc}, asmStr, + constraints, /*hasSideEffects=*/true); + + Value result = asmOp.getResult(0); + // Truncate if needed for sub-32-bit scalar types + if (bitWidth < 32 && !isa(resultType)) { + result = arith::TruncIOp::create(rewriter, loc, resultType, result); + } - rewriter.replaceOp(loadOp, asmOp.getResult(0)); + rewriter.replaceOp(loadOp, result); return success(); } @@ -237,7 +297,7 @@ template static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - FailureOr suffix = getSizeSuffix(bitWidth); + FailureOr suffix = getSizeSuffixLoad(bitWidth); if (failed(suffix)) return loadOp.emitError("unsupported load bit width: ") << bitWidth; @@ -259,11 +319,23 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + // For sub-32-bit loads, hardware loads into i32 and we need to truncate + Type asmResultType = resultType; + if (bitWidth < 32 && !isa(resultType)) { + asmResultType = rewriter.getI32Type(); + } + // Create the inline assembly operation - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, + auto asmOp = createInlineAsm(rewriter, loc, asmResultType, ValueRange{addr}, asmStr, constraints, /*hasSideEffects=*/true); - rewriter.replaceOp(loadOp, asmOp.getResult(0)); + Value result = asmOp.getResult(0); + // Truncate if needed for sub-32-bit scalar types + if (bitWidth < 32 && !isa(resultType)) { + result = arith::TruncIOp::create(rewriter, loc, resultType, result); + } + + rewriter.replaceOp(loadOp, result); return success(); } @@ -272,7 +344,7 @@ template static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr suffix = getBufferSuffix(bitWidth); + FailureOr suffix = getBufferSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported buffer store bit width: ") << bitWidth; @@ -300,10 +372,17 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); + // For sub-32-bit stores, extend to i32 first + Value valueToStore = storeOp.getValueToStore(); + if (bitWidth < 32 && !isa(valueType)) { + valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), + valueToStore); + } + // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{storeOp.getValueToStore(), offset, bufferDesc}, - asmStr, constraints, /*hasSideEffects=*/true); + ValueRange{valueToStore, offset, bufferDesc}, asmStr, + constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); return success(); @@ -314,7 +393,7 @@ template static LogicalResult 
lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr suffix = getSizeSuffix(bitWidth); + FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported store bit width: ") << bitWidth; @@ -336,10 +415,17 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); + // For sub-32-bit stores, extend to i32 first + Value valueToStore = storeOp.getValueToStore(); + if (bitWidth < 32 && !isa(valueType)) { + valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), + valueToStore); + } + // Create the inline assembly operation (no result for store) - createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{addr, storeOp.getValueToStore()}, asmStr, - constraints, /*hasSideEffects=*/true); + createInlineAsm(rewriter, loc, TypeRange{}, ValueRange{addr, valueToStore}, + asmStr, constraints, + /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); return success(); @@ -350,7 +436,7 @@ template static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - FailureOr suffix = getSizeSuffix(bitWidth); + FailureOr suffix = getSizeSuffixLoad(bitWidth); if (failed(suffix)) return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; @@ -375,11 +461,24 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { Value offset32 = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); + // For sub-32-bit loads, hardware loads into i32 and we need to truncate + Type asmResultType = resultType; + if (bitWidth < 32 && !isa(resultType)) { + asmResultType = rewriter.getI32Type(); + } + // Create inline assembly operation - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset32}, - asmStr, constraints, /*hasSideEffects=*/true); + auto asmOp = + createInlineAsm(rewriter, loc, asmResultType, ValueRange{offset32}, + asmStr, constraints, /*hasSideEffects=*/true); + + Value result = asmOp.getResult(0); + // Truncate if needed for sub-32-bit scalar types + if (bitWidth < 32 && !isa(resultType)) { + result = arith::TruncIOp::create(rewriter, loc, resultType, result); + } - rewriter.replaceOp(loadOp, asmOp.getResult(0)); + rewriter.replaceOp(loadOp, result); return success(); } @@ -388,7 +487,7 @@ template static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr suffix = getSizeSuffix(bitWidth); + FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; @@ -413,10 +512,17 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { Value offset32 = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); + // For sub-32-bit stores, extend to i32 first + Value valueToStore = storeOp.getValueToStore(); + if (bitWidth < 32 && !isa(valueType)) { + valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), + valueToStore); + } + // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{offset32, storeOp.getValueToStore()}, asmStr, - constraints, /*hasSideEffects=*/true); + ValueRange{offset32, valueToStore}, asmStr, constraints, + 
/*hasSideEffects=*/true); rewriter.eraseOp(storeOp); return success(); diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index f5359142a..1c76c199e 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -350,3 +350,102 @@ func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { return } + +// ----- +// 8-bit and 16-bit operations tests + +// CHECK-LABEL: func.func @scalar_load_global_i8 +func.func @scalar_load_global_i8(%memref: memref<1024xi8>, %offset: index) -> i8 { + // CHECK: llvm.inline_asm has_side_effects "global_load_u8 $0, $1, off", "=v,v" + // CHECK: arith.trunci + %result = memref.load %memref[%offset] : memref<1024xi8> + return %result : i8 +} + +// CHECK-LABEL: func.func @scalar_load_global_i16 +func.func @scalar_load_global_i16(%memref: memref<1024xi16>, %offset: index) -> i16 { + // CHECK: llvm.inline_asm has_side_effects "global_load_u16 $0, $1, off", "=v,v" + // CHECK: arith.trunci + %result = memref.load %memref[%offset] : memref<1024xi16> + return %result : i16 +} + +// CHECK-LABEL: func.func @scalar_store_global_i8 +func.func @scalar_store_global_i8(%memref: memref<1024xi8>, %offset: index, %data: i8) { + // CHECK: arith.extui + // CHECK: llvm.inline_asm has_side_effects "global_store_b8 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xi8> + return +} + +// CHECK-LABEL: func.func @scalar_store_global_i16 +func.func @scalar_store_global_i16(%memref: memref<1024xi16>, %offset: index, %data: i16) { + // CHECK: arith.extui + // CHECK: llvm.inline_asm has_side_effects "global_store_b16 $0, $1, off", "v,v" + memref.store %data, %memref[%offset] : memref<1024xi16> + return +} + +// CHECK-LABEL: func.func @scalar_load_buffer_i8 +func.func @scalar_load_buffer_i8(%buffer: memref<1024xi8, #amdgpu.address_space>, %offset: index) -> i8 { + // CHECK: llvm.inline_asm has_side_effects "buffer_load_ubyte $0, $1, $2, 0 offen", "=v,v,s" + // CHECK: arith.trunci + %result = memref.load %buffer[%offset] : memref<1024xi8, #amdgpu.address_space> + return %result : i8 +} + +// CHECK-LABEL: func.func @scalar_load_buffer_i16 +func.func @scalar_load_buffer_i16(%buffer: memref<1024xi16, #amdgpu.address_space>, %offset: index) -> i16 { + // CHECK: llvm.inline_asm has_side_effects "buffer_load_ushort $0, $1, $2, 0 offen", "=v,v,s" + // CHECK: arith.trunci + %result = memref.load %buffer[%offset] : memref<1024xi16, #amdgpu.address_space> + return %result : i16 +} + +// CHECK-LABEL: func.func @scalar_store_buffer_i8 +func.func @scalar_store_buffer_i8(%buffer: memref<1024xi8, #amdgpu.address_space>, %offset: index, %data: i8) { + // CHECK: arith.extui + // CHECK: llvm.inline_asm has_side_effects "buffer_store_byte $0, $1, $2, 0 offen", "v,v,s" + memref.store %data, %buffer[%offset] : memref<1024xi8, #amdgpu.address_space> + return +} + +// CHECK-LABEL: func.func @scalar_store_buffer_i16 +func.func @scalar_store_buffer_i16(%buffer: memref<1024xi16, #amdgpu.address_space>, %offset: index, %data: i16) { + // CHECK: arith.extui + // CHECK: llvm.inline_asm has_side_effects "buffer_store_short $0, $1, $2, 0 offen", "v,v,s" + memref.store %data, %buffer[%offset] : memref<1024xi16, #amdgpu.address_space> + return +} + +// CHECK-LABEL: func.func @scalar_load_ds_i8 +func.func @scalar_load_ds_i8(%lds: memref<1024xi8, #gpu.address_space>, %offset: index) -> i8 { + // CHECK: llvm.inline_asm has_side_effects "ds_read_u8 $0, $1", "=v,v" + // CHECK: arith.trunci + 
%result = memref.load %lds[%offset] : memref<1024xi8, #gpu.address_space<workgroup>>
+  return %result : i8
+}
+
+// CHECK-LABEL: func.func @scalar_load_ds_i16
+func.func @scalar_load_ds_i16(%lds: memref<1024xi16, #gpu.address_space<workgroup>>, %offset: index) -> i16 {
+  // CHECK: llvm.inline_asm has_side_effects "ds_read_u16 $0, $1", "=v,v"
+  // CHECK: arith.trunci
+  %result = memref.load %lds[%offset] : memref<1024xi16, #gpu.address_space<workgroup>>
+  return %result : i16
+}
+
+// CHECK-LABEL: func.func @scalar_store_ds_i8
+func.func @scalar_store_ds_i8(%lds: memref<1024xi8, #gpu.address_space<workgroup>>, %offset: index, %data: i8) {
+  // CHECK: arith.extui
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b8 $0, $1", "v,v"
+  memref.store %data, %lds[%offset] : memref<1024xi8, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @scalar_store_ds_i16
+func.func @scalar_store_ds_i16(%lds: memref<1024xi16, #gpu.address_space<workgroup>>, %offset: index, %data: i16) {
+  // CHECK: arith.extui
+  // CHECK: llvm.inline_asm has_side_effects "ds_write_b16 $0, $1", "v,v"
+  memref.store %data, %lds[%offset] : memref<1024xi16, #gpu.address_space<workgroup>>
+  return
+}

From 0bbccc7ebf7025b70014b723455e21fec9485dc1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 12 Dec 2025 16:19:51 +0100
Subject: [PATCH 054/114] small value fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 103 +++++++++++++------
 1 file changed, 70 insertions(+), 33 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index d9dd889b4..bfd77c480 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -272,20 +272,28 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) {
   Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref);
 
   // For sub-32-bit loads, hardware loads into i32 and we need to truncate
-  Type asmResultType = resultType;
-  if (bitWidth < 32 && !isa<VectorType>(resultType)) {
-    asmResultType = rewriter.getI32Type();
-  }
-
-  // Create inline assembly operation
-  auto asmOp =
-      createInlineAsm(rewriter, loc, asmResultType,
-                      ValueRange{offset, bufferDesc}, asmStr,
-                      constraints, /*hasSideEffects=*/true);
-
-  Value result = asmOp.getResult(0);
-  // Truncate if needed for sub-32-bit scalar types
-  if (bitWidth < 32 && !isa<VectorType>(resultType)) {
-    result = arith::TruncIOp::create(rewriter, loc, resultType, result);
+  Value result;
+  if (bitWidth < 32) {
+    // Create inline asm returning i32
+    auto asmOp = createInlineAsm(rewriter, loc, rewriter.getI32Type(),
+                                 ValueRange{offset, bufferDesc}, asmStr,
+                                 constraints, /*hasSideEffects=*/true);
+
+    // Truncate to appropriate integer width
+    Type intType = rewriter.getIntegerType(bitWidth);
+    result =
+        arith::TruncIOp::create(rewriter, loc, intType, asmOp.getResult(0));
+
+    // Bitcast to actual result type (handles floats and vectors)
+    if (resultType != intType) {
+      result = LLVM::BitcastOp::create(rewriter, loc, resultType, result);
+    }
+  } else {
+    // Create inline assembly operation with result type directly
+    auto asmOp = createInlineAsm(rewriter, loc, resultType,
+                                 ValueRange{offset, bufferDesc}, asmStr,
+                                 constraints, /*hasSideEffects=*/true);
+    result = asmOp.getResult(0);
   }
 
   rewriter.replaceOp(loadOp, result);
@@ -320,19 +328,27 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) {
                                        elementBitWidth);
 
   // For sub-32-bit loads, hardware loads into i32 and we need to truncate
-  Type asmResultType = resultType;
-  if (bitWidth < 32
&& !isa(resultType)) { - asmResultType = rewriter.getI32Type(); - } - - // Create the inline assembly operation - auto asmOp = createInlineAsm(rewriter, loc, asmResultType, ValueRange{addr}, - asmStr, constraints, /*hasSideEffects=*/true); - - Value result = asmOp.getResult(0); - // Truncate if needed for sub-32-bit scalar types - if (bitWidth < 32 && !isa(resultType)) { - result = arith::TruncIOp::create(rewriter, loc, resultType, result); + Value result; + if (bitWidth < 32) { + // Create inline asm returning i32 + auto asmOp = createInlineAsm(rewriter, loc, rewriter.getI32Type(), + ValueRange{addr}, asmStr, constraints, + /*hasSideEffects=*/true); + + // Truncate to appropriate integer width + Type intType = rewriter.getIntegerType(bitWidth); + result = + arith::TruncIOp::create(rewriter, loc, intType, asmOp.getResult(0)); + + // Bitcast to actual result type (handles floats and vectors) + if (resultType != intType) { + result = LLVM::BitcastOp::create(rewriter, loc, resultType, result); + } + } else { + // Create the inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, + asmStr, constraints, /*hasSideEffects=*/true); + result = asmOp.getResult(0); } rewriter.replaceOp(loadOp, result); @@ -372,9 +388,16 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); - // For sub-32-bit stores, extend to i32 first + // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32 && !isa(valueType)) { + if (bitWidth < 32) { + // Bitcast to integer type (handles floats and vectors) + Type intType = rewriter.getIntegerType(bitWidth); + if (valueType != intType) { + valueToStore = + LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); + } + // Extend to i32 valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), valueToStore); } @@ -415,9 +438,16 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - // For sub-32-bit stores, extend to i32 first + // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32 && !isa(valueType)) { + if (bitWidth < 32) { + // Bitcast to integer type (handles floats and vectors) + Type intType = rewriter.getIntegerType(bitWidth); + if (valueType != intType) { + valueToStore = + LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); + } + // Extend to i32 valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), valueToStore); } @@ -512,9 +542,16 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { Value offset32 = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); - // For sub-32-bit stores, extend to i32 first + // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32 && !isa(valueType)) { + if (bitWidth < 32) { + // Bitcast to integer type (handles floats and vectors) + Type intType = rewriter.getIntegerType(bitWidth); + if (valueType != intType) { + valueToStore = + LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); + } + // Extend to i32 valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), 
valueToStore); } From 2b19267f4d9965dd0616ebf88033a8b513df43f4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 17:50:34 +0100 Subject: [PATCH 055/114] revert 32 bit Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 131 +++++-------------- water/test/Transforms/lower-memory-ops.mlir | 99 -------------- 2 files changed, 30 insertions(+), 200 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index bfd77c480..7fe0db9e5 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -244,6 +244,9 @@ template static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getBufferSuffixLoad(bitWidth); if (failed(suffix)) return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; @@ -271,32 +274,12 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) { // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); - // For sub-32-bit loads, hardware loads into i32 and we need to truncate - Value result; - if (bitWidth < 32) { - // Create inline asm returning i32 - auto asmOp = createInlineAsm(rewriter, loc, rewriter.getI32Type(), - ValueRange{offset, bufferDesc}, asmStr, - constraints, /*hasSideEffects=*/true); - - // Truncate to appropriate integer width - Type intType = rewriter.getIntegerType(bitWidth); - result = - arith::TruncIOp::create(rewriter, loc, intType, asmOp.getResult(0)); - - // Bitcast to actual result type (handles floats and vectors) - if (resultType != intType) { - result = LLVM::BitcastOp::create(rewriter, loc, resultType, result); - } - } else { - // Create inline assembly operation with result type directly - auto asmOp = createInlineAsm(rewriter, loc, resultType, - ValueRange{offset, bufferDesc}, asmStr, - constraints, /*hasSideEffects=*/true); - result = asmOp.getResult(0); - } + // Create inline assembly operation with result type directly + auto asmOp = + createInlineAsm(rewriter, loc, resultType, ValueRange{offset, bufferDesc}, + asmStr, constraints, /*hasSideEffects=*/true); - rewriter.replaceOp(loadOp, result); + rewriter.replaceOp(loadOp, asmOp.getResult(0)); return success(); } @@ -305,6 +288,9 @@ template static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getSizeSuffixLoad(bitWidth); if (failed(suffix)) return loadOp.emitError("unsupported load bit width: ") << bitWidth; @@ -327,31 +313,11 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - // For sub-32-bit loads, hardware loads into i32 and we need to truncate - Value result; - if (bitWidth < 32) { - // Create inline asm returning i32 - auto asmOp = createInlineAsm(rewriter, loc, rewriter.getI32Type(), - ValueRange{addr}, asmStr, constraints, - /*hasSideEffects=*/true); - - // Truncate to appropriate integer width - Type intType = rewriter.getIntegerType(bitWidth); - result = - arith::TruncIOp::create(rewriter, loc, intType, asmOp.getResult(0)); - - // Bitcast to actual result type (handles floats and vectors) - if (resultType != intType) { - result = 
LLVM::BitcastOp::create(rewriter, loc, resultType, result); - } - } else { - // Create the inline assembly operation with result type directly - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, - asmStr, constraints, /*hasSideEffects=*/true); - result = asmOp.getResult(0); - } + // Create the inline assembly operation with result type directly + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, + asmStr, constraints, /*hasSideEffects=*/true); - rewriter.replaceOp(loadOp, result); + rewriter.replaceOp(loadOp, asmOp.getResult(0)); return success(); } @@ -360,6 +326,9 @@ template static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getBufferSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported buffer store bit width: ") @@ -388,19 +357,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { // Extract buffer descriptor pointer from memref Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); - // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32) { - // Bitcast to integer type (handles floats and vectors) - Type intType = rewriter.getIntegerType(bitWidth); - if (valueType != intType) { - valueToStore = - LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); - } - // Extend to i32 - valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), - valueToStore); - } // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, @@ -416,6 +373,9 @@ template static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported store bit width: ") << bitWidth; @@ -438,19 +398,7 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32) { - // Bitcast to integer type (handles floats and vectors) - Type intType = rewriter.getIntegerType(bitWidth); - if (valueType != intType) { - valueToStore = - LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); - } - // Extend to i32 - valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), - valueToStore); - } // Create the inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, ValueRange{addr, valueToStore}, @@ -466,6 +414,9 @@ template static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getSizeSuffixLoad(bitWidth); if (failed(suffix)) return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; @@ -491,24 +442,11 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { Value offset32 = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); - // For sub-32-bit loads, hardware loads into i32 and we need to truncate - Type asmResultType = 
resultType; - if (bitWidth < 32 && !isa(resultType)) { - asmResultType = rewriter.getI32Type(); - } - // Create inline assembly operation - auto asmOp = - createInlineAsm(rewriter, loc, asmResultType, ValueRange{offset32}, - asmStr, constraints, /*hasSideEffects=*/true); - - Value result = asmOp.getResult(0); - // Truncate if needed for sub-32-bit scalar types - if (bitWidth < 32 && !isa(resultType)) { - result = arith::TruncIOp::create(rewriter, loc, resultType, result); - } + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset32}, + asmStr, constraints, /*hasSideEffects=*/true); - rewriter.replaceOp(loadOp, result); + rewriter.replaceOp(loadOp, asmOp.getResult(0)); return success(); } @@ -517,6 +455,9 @@ template static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); + if (bitWidth < 32) + return success(); + FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; @@ -542,19 +483,7 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { Value offset32 = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); - // For sub-32-bit stores, bitcast to int and extend to i32 Value valueToStore = storeOp.getValueToStore(); - if (bitWidth < 32) { - // Bitcast to integer type (handles floats and vectors) - Type intType = rewriter.getIntegerType(bitWidth); - if (valueType != intType) { - valueToStore = - LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore); - } - // Extend to i32 - valueToStore = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), - valueToStore); - } // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index 1c76c199e..f5359142a 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -350,102 +350,3 @@ func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { return } - -// ----- -// 8-bit and 16-bit operations tests - -// CHECK-LABEL: func.func @scalar_load_global_i8 -func.func @scalar_load_global_i8(%memref: memref<1024xi8>, %offset: index) -> i8 { - // CHECK: llvm.inline_asm has_side_effects "global_load_u8 $0, $1, off", "=v,v" - // CHECK: arith.trunci - %result = memref.load %memref[%offset] : memref<1024xi8> - return %result : i8 -} - -// CHECK-LABEL: func.func @scalar_load_global_i16 -func.func @scalar_load_global_i16(%memref: memref<1024xi16>, %offset: index) -> i16 { - // CHECK: llvm.inline_asm has_side_effects "global_load_u16 $0, $1, off", "=v,v" - // CHECK: arith.trunci - %result = memref.load %memref[%offset] : memref<1024xi16> - return %result : i16 -} - -// CHECK-LABEL: func.func @scalar_store_global_i8 -func.func @scalar_store_global_i8(%memref: memref<1024xi8>, %offset: index, %data: i8) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "global_store_b8 $0, $1, off", "v,v" - memref.store %data, %memref[%offset] : memref<1024xi8> - return -} - -// CHECK-LABEL: func.func @scalar_store_global_i16 -func.func @scalar_store_global_i16(%memref: memref<1024xi16>, %offset: index, %data: i16) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "global_store_b16 $0, $1, off", "v,v" - memref.store %data, %memref[%offset] : memref<1024xi16> - return -} - -// 
CHECK-LABEL: func.func @scalar_load_buffer_i8 -func.func @scalar_load_buffer_i8(%buffer: memref<1024xi8, #amdgpu.address_space>, %offset: index) -> i8 { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_ubyte $0, $1, $2, 0 offen", "=v,v,s" - // CHECK: arith.trunci - %result = memref.load %buffer[%offset] : memref<1024xi8, #amdgpu.address_space> - return %result : i8 -} - -// CHECK-LABEL: func.func @scalar_load_buffer_i16 -func.func @scalar_load_buffer_i16(%buffer: memref<1024xi16, #amdgpu.address_space>, %offset: index) -> i16 { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_ushort $0, $1, $2, 0 offen", "=v,v,s" - // CHECK: arith.trunci - %result = memref.load %buffer[%offset] : memref<1024xi16, #amdgpu.address_space> - return %result : i16 -} - -// CHECK-LABEL: func.func @scalar_store_buffer_i8 -func.func @scalar_store_buffer_i8(%buffer: memref<1024xi8, #amdgpu.address_space>, %offset: index, %data: i8) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "buffer_store_byte $0, $1, $2, 0 offen", "v,v,s" - memref.store %data, %buffer[%offset] : memref<1024xi8, #amdgpu.address_space> - return -} - -// CHECK-LABEL: func.func @scalar_store_buffer_i16 -func.func @scalar_store_buffer_i16(%buffer: memref<1024xi16, #amdgpu.address_space>, %offset: index, %data: i16) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "buffer_store_short $0, $1, $2, 0 offen", "v,v,s" - memref.store %data, %buffer[%offset] : memref<1024xi16, #amdgpu.address_space> - return -} - -// CHECK-LABEL: func.func @scalar_load_ds_i8 -func.func @scalar_load_ds_i8(%lds: memref<1024xi8, #gpu.address_space>, %offset: index) -> i8 { - // CHECK: llvm.inline_asm has_side_effects "ds_read_u8 $0, $1", "=v,v" - // CHECK: arith.trunci - %result = memref.load %lds[%offset] : memref<1024xi8, #gpu.address_space> - return %result : i8 -} - -// CHECK-LABEL: func.func @scalar_load_ds_i16 -func.func @scalar_load_ds_i16(%lds: memref<1024xi16, #gpu.address_space>, %offset: index) -> i16 { - // CHECK: llvm.inline_asm has_side_effects "ds_read_u16 $0, $1", "=v,v" - // CHECK: arith.trunci - %result = memref.load %lds[%offset] : memref<1024xi16, #gpu.address_space> - return %result : i16 -} - -// CHECK-LABEL: func.func @scalar_store_ds_i8 -func.func @scalar_store_ds_i8(%lds: memref<1024xi8, #gpu.address_space>, %offset: index, %data: i8) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "ds_write_b8 $0, $1", "v,v" - memref.store %data, %lds[%offset] : memref<1024xi8, #gpu.address_space> - return -} - -// CHECK-LABEL: func.func @scalar_store_ds_i16 -func.func @scalar_store_ds_i16(%lds: memref<1024xi16, #gpu.address_space>, %offset: index, %data: i16) { - // CHECK: arith.extui - // CHECK: llvm.inline_asm has_side_effects "ds_write_b16 $0, $1", "v,v" - memref.store %data, %lds[%offset] : memref<1024xi16, #gpu.address_space> - return -} From b8d5b8d8dfb75dadb0687f7f3a1309f6c77344be Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 17:55:05 +0100 Subject: [PATCH 056/114] chipset options Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 4 ++++ water/lib/Transforms/WaterLowerMemoryOps.cpp | 2 ++ 2 files changed, 6 insertions(+) diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index 92f66c812..a4db14d23 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -191,6 +191,10 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> { 
"::mlir::memref::MemRefDialect", "::mlir::vector::VectorDialect", ]; + let options = [ + Option<"chipset", "chipset", "std::string", [{""}], + "Target chipset (e.g., gfx942, gfx1100)"> + ]; } #endif // WATER_PASSES diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 7fe0db9e5..e0e81cfa2 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -531,6 +531,8 @@ static bool usesWorkgroupAddressSpace(Value memref) { class WaterLowerMemoryOpsPass : public water::impl::WaterLowerMemoryOpsBase { public: + using Base::Base; + void runOnOperation() override { IRRewriter rewriter(&getContext()); From cbb6bdb5a25edfcca69b9e244bb33aa29cc5d663 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 21:00:12 +0100 Subject: [PATCH 057/114] buffer ops fixes Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 1 + water/lib/Transforms/WaterLowerMemoryOps.cpp | 183 ++++++++++++------- 2 files changed, 119 insertions(+), 65 deletions(-) diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index a4db14d23..f8d30313b 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -189,6 +189,7 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> { "::mlir::gpu::GPUDialect", "::mlir::LLVM::LLVMDialect", "::mlir::memref::MemRefDialect", + "::mlir::ROCDL::ROCDLDialect", "::mlir::vector::VectorDialect", ]; let options = [ diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index e0e81cfa2..1e6717161 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -79,10 +80,16 @@ static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, /*operand_attrs=*/ArrayAttr{}); } -/// Compute byte offset as i64 for a memref access with indices -static Value computeMemrefByteOffsetI64(IRRewriter &rewriter, Location loc, - Value memref, ValueRange indices, - unsigned elementBitWidth) { +/// Detect if chipset is RDNA architecture +static bool isRDNA(StringRef chipset) { + return chipset.starts_with("gfx11") || chipset.starts_with("gfx12"); +} + +/// Compute byte offset as iX for a memref access with indices +template +static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc, + Value memref, ValueRange indices, + unsigned elementBitWidth) { // Extract strided metadata to get offset and strides auto metadataOp = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); @@ -107,8 +114,8 @@ static Value computeMemrefByteOffsetI64(IRRewriter &rewriter, Location loc, arith::MulIOp::create(rewriter, loc, linearIndex, elementSize, arith::IntegerOverflowFlags::nsw); - return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), - byteOffset); + Type indexType = IntegerType::get(rewriter.getContext(), Bits); + return arith::IndexCastOp::create(rewriter, loc, indexType, byteOffset); } /// Compute the final address for a memref access with indices (for global @@ -130,8 +137,8 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location 
loc, basePtrInt = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtrInt); // Compute byte offset - Value byteOffsetI64 = computeMemrefByteOffsetI64(rewriter, loc, memref, - indices, elementBitWidth); + Value byteOffsetI64 = computeMemrefByteOffset<64>(rewriter, loc, memref, + indices, elementBitWidth); // Add byte offset to base pointer Value finalAddr = @@ -141,42 +148,78 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, } /// Get buffer instruction suffix based on bit width (for loads - unsigned) -static FailureOr getBufferSuffixLoad(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("ubyte"); - case 16: - return StringRef("ushort"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); +static FailureOr getBufferSuffixLoad(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + // RDNA uses b32, b64, etc. + switch (bitWidth) { + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } + } else { + // CDNA uses dword, dwordx2, etc. + switch (bitWidth) { + case 8: + return StringRef("ubyte"); + case 16: + return StringRef("ushort"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } } } /// Get buffer instruction suffix based on bit width (for stores) -static FailureOr getBufferSuffixStore(unsigned bitWidth) { - switch (bitWidth) { - case 8: - return StringRef("byte"); - case 16: - return StringRef("short"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); +static FailureOr getBufferSuffixStore(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + // RDNA uses b32, b64, etc. + switch (bitWidth) { + case 32: + return StringRef("b32"); + case 64: + return StringRef("b64"); + case 96: + return StringRef("b96"); + case 128: + return StringRef("b128"); + default: + return failure(); + } + } else { + // CDNA uses dword, dwordx2, etc. 
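+    // (For example, a 128-bit store is "buffer_store_dwordx4" on CDNA
+    // targets such as gfx942, while RDNA gfx11+ emits "buffer_store_b128";
+    // CDNA load mnemonics also encode zero-extension, e.g. "ubyte".)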
+ switch (bitWidth) { + case 8: + return StringRef("byte"); + case 16: + return StringRef("short"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } } } @@ -186,16 +229,10 @@ static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc, // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes..., // strides...} auto memrefType = cast(memref.getType()); - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); auto i64Type = rewriter.getI64Type(); - SmallVector descriptorFields{ptrType, ptrType, - i64Type}; // allocated, aligned, offset - // Add sizes and strides for each dimension - for (int64_t i = 0; i < memrefType.getRank(); ++i) - descriptorFields.push_back(i64Type); // size - - for (int64_t i = 0; i < memrefType.getRank(); ++i) - descriptorFields.push_back(i64Type); // stride + auto arrayType = LLVM::LLVMArrayType::get(i64Type, memrefType.getRank()); + Type descriptorFields[] = {ptrType, ptrType, i64Type, arrayType, arrayType}; auto memrefDescType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), descriptorFields); @@ -204,9 +241,20 @@ static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc, UnrealizedConversionCastOp::create(rewriter, loc, memrefDescType, memref) .getResult(0); - // Use MemRefDescriptor to extract aligned pointer MemRefDescriptor memrefDesc(memrefDescVal); - return memrefDesc.alignedPtr(rewriter, loc); + Value result = memrefDesc.alignedPtr(rewriter, loc); + + auto i128Type = IntegerType::get(rewriter.getContext(), 128); + result = LLVM::PtrToIntOp::create(rewriter, loc, i128Type, result); + + // Bitcast to vector<4xi32> for buffer descriptor + auto vec4i32Type = VectorType::get({4}, rewriter.getI32Type()); + result = LLVM::BitcastOp::create(rewriter, loc, vec4i32Type, result); + + // Use readfirstlane to move buffer descriptor to SGPR + result = + ROCDL::ReadfirstlaneOp::create(rewriter, loc, result.getType(), result); + return result; } /// Helper to get memref, result type, and bit width from load operation @@ -241,20 +289,21 @@ static std::tuple getStoreOpInfo(StoreOpTy storeOp) { /// Lower vector/scalar load to AMDGPU buffer load inline assembly template -static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) { +static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); if (bitWidth < 32) return success(); - FailureOr suffix = getBufferSuffixLoad(bitWidth); + FailureOr suffix = getBufferSuffixLoad(bitWidth, isRDNAArch); if (failed(suffix)) return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; Location loc = loadOp.getLoc(); rewriter.setInsertionPoint(loadOp); - // Build inline assembly: "buffer_load_dwordx4 $0, $1, $2, 0 offen" + // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" std::string asmStr = ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str(); @@ -268,7 +317,7 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter) { std::is_same_v ? 
cast(resultType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffsetI64( + Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); // Extract buffer descriptor pointer from memref @@ -323,13 +372,14 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { /// Lower vector/scalar store to AMDGPU buffer store inline assembly template -static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { +static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); if (bitWidth < 32) return success(); - FailureOr suffix = getBufferSuffixStore(bitWidth); + FailureOr suffix = getBufferSuffixStore(bitWidth, isRDNAArch); if (failed(suffix)) return storeOp.emitError("unsupported buffer store bit width: ") << bitWidth; @@ -337,7 +387,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { Location loc = storeOp.getLoc(); rewriter.setInsertionPoint(storeOp); - // Build inline assembly: "buffer_store_dwordx4 $0, $1, $2, 0 offen" + // Build inline assembly: "buffer_store_ $0, $1, $2, 0 offen" std::string asmStr = ("buffer_store_" + *suffix + " $0, $1, $2, 0 offen").str(); @@ -351,7 +401,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter) { std::is_same_v ? cast(valueType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffsetI64( + Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); // Extract buffer descriptor pointer from memref @@ -435,7 +485,7 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { std::is_same_v ? cast(resultType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffsetI64( + Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); // DS operations use 32-bit addresses @@ -476,7 +526,7 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { std::is_same_v ? 
cast(valueType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffsetI64( + Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); // DS operations use 32-bit addresses @@ -536,6 +586,9 @@ class WaterLowerMemoryOpsPass void runOnOperation() override { IRRewriter rewriter(&getContext()); + // Determine if we're targeting RDNA architecture + bool isRDNAArch = isRDNA(chipset); + // Helper to dispatch to the appropriate lowering function based on address // space auto lowerMemoryOp = [&](Value base, auto lowerBuffer, auto lowerWorkgroup, @@ -551,7 +604,7 @@ class WaterLowerMemoryOpsPass if (auto loadOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( loadOp.getBase(), - [&]() { return lowerLoadBuffer(loadOp, rewriter); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadDS(loadOp, rewriter); }, [&]() { return lowerLoadGlobal(loadOp, rewriter); }); if (failed(result)) @@ -561,7 +614,7 @@ class WaterLowerMemoryOpsPass if (auto storeOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( storeOp.getBase(), - [&]() { return lowerStoreBuffer(storeOp, rewriter); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, [&]() { return lowerStoreDS(storeOp, rewriter); }, [&]() { return lowerStoreGlobal(storeOp, rewriter); }); if (failed(result)) @@ -571,7 +624,7 @@ class WaterLowerMemoryOpsPass if (auto loadOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( loadOp.getMemRef(), - [&]() { return lowerLoadBuffer(loadOp, rewriter); }, + [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadDS(loadOp, rewriter); }, [&]() { return lowerLoadGlobal(loadOp, rewriter); }); if (failed(result)) @@ -581,7 +634,7 @@ class WaterLowerMemoryOpsPass if (auto storeOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( storeOp.getMemRef(), - [&]() { return lowerStoreBuffer(storeOp, rewriter); }, + [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, [&]() { return lowerStoreDS(storeOp, rewriter); }, [&]() { return lowerStoreGlobal(storeOp, rewriter); }); if (failed(result)) From f7a3ab0339a85be94038962ef6a35dd99dbfd5e2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 22:58:33 +0100 Subject: [PATCH 058/114] fix ds Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 19 ++++++------------- water/test/Transforms/lower-memory-ops.mlir | 12 ------------ 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 1e6717161..457eb7cd3 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -488,12 +488,8 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - // DS operations use 32-bit addresses - Value offset32 = - arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); - - // Create inline assembly operation - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset32}, + // Create inline assembly operation (DS operations use 32-bit addresses) + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.replaceOp(loadOp, asmOp.getResult(0)); @@ -529,15 +525,12 @@ static LogicalResult 
lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - // DS operations use 32-bit addresses - Value offset32 = - arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), offset); - Value valueToStore = storeOp.getValueToStore(); - // Create inline assembly operation (no result for store) - createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{offset32, valueToStore}, asmStr, constraints, + // Create inline assembly operation (no result for store, DS uses 32-bit + // addresses) + createInlineAsm(rewriter, loc, TypeRange{}, ValueRange{offset, valueToStore}, + asmStr, constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index f5359142a..92fc8e246 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -184,7 +184,6 @@ func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<10 // CHECK-LABEL: func.func @ds_load_b32 func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<1xf32> { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> return %result : vector<1xf32> @@ -192,7 +191,6 @@ func.func @ds_load_b32(%lds: memref<1024xf32, #gpu.address_space>, %o // CHECK-LABEL: func.func @ds_load_b64 func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<2xf32> { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b64 $0, $1", "=v,v" %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> return %result : vector<2xf32> @@ -200,7 +198,6 @@ func.func @ds_load_b64(%lds: memref<1024xf32, #gpu.address_space>, %o // CHECK-LABEL: func.func @ds_load_b96 func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<3xf32> { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b96 $0, $1", "=v,v" %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> return %result : vector<3xf32> @@ -208,7 +205,6 @@ func.func @ds_load_b96(%lds: memref<1024xf32, #gpu.address_space>, %o // CHECK-LABEL: func.func @ds_load_b128 func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> vector<4xf32> { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" %result = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> return %result : vector<4xf32> @@ -216,7 +212,6 @@ func.func @ds_load_b128(%lds: memref<1024xf32, #gpu.address_space>, % // CHECK-LABEL: func.func @ds_store_b32 func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<1xf32>) { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<1xf32> return @@ -224,7 +219,6 @@ func.func @ds_store_b32(%lds: memref<1024xf32, #gpu.address_space>, % // CHECK-LABEL: func.func @ds_store_b64 func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<2xf32>) { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b64 
$0, $1", "v,v" vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<2xf32> return @@ -232,7 +226,6 @@ func.func @ds_store_b64(%lds: memref<1024xf32, #gpu.address_space>, % // CHECK-LABEL: func.func @ds_store_b96 func.func @ds_store_b96(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<3xf32>) { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b96 $0, $1", "v,v" vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<3xf32> return @@ -240,7 +233,6 @@ func.func @ds_store_b96(%lds: memref<1024xf32, #gpu.address_space>, % // CHECK-LABEL: func.func @ds_store_b128 func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: vector<4xf32>) { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" vector.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> return @@ -253,12 +245,10 @@ func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> // Store to LDS (should use ds_write) - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b128 $0, $1", "v,v" vector.store %global_data, %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> // Load from LDS (should use ds_read) - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "=v,v" %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> @@ -316,7 +306,6 @@ func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_spa // CHECK-LABEL: func.func @scalar_load_ds_f32 func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index) -> f32 { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_read_b32 $0, $1", "=v,v" %result = memref.load %lds[%offset] : memref<1024xf32, #gpu.address_space> return %result : f32 @@ -324,7 +313,6 @@ func.func @scalar_load_ds_f32(%lds: memref<1024xf32, #gpu.address_space>, %offset: index, %data: f32) { - // CHECK: arith.trunci // CHECK: llvm.inline_asm has_side_effects "ds_write_b32 $0, $1", "v,v" memref.store %data, %lds[%offset] : memref<1024xf32, #gpu.address_space> return From 7d8738ff6fed9fc42de1d8550d83e1721b0a8ee8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 23:08:42 +0100 Subject: [PATCH 059/114] buffer stuff Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 69 +++++++++++++------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 457eb7cd3..9845e158c 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -223,9 +223,11 @@ static FailureOr getBufferSuffixStore(unsigned bitWidth, } } -/// Extract buffer descriptor pointer from a fat_raw_buffer memref -static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc, - Value memref) { +/// Extract buffer descriptor and base offset from a fat_raw_buffer memref +/// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) +/// Returns: {resource descriptor (vec<4xi32> in SGPR), base offset (i32)} +static std::pair +extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes..., // 
strides...} auto memrefType = cast(memref.getType()); @@ -242,19 +244,32 @@ static Value extractBufferDescriptor(IRRewriter &rewriter, Location loc, .getResult(0); MemRefDescriptor memrefDesc(memrefDescVal); - Value result = memrefDesc.alignedPtr(rewriter, loc); + Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc); + // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32 + // offset} + auto i160Type = IntegerType::get(rewriter.getContext(), 160); + Value fullDesc = LLVM::PtrToIntOp::create(rewriter, loc, i160Type, bufferPtr); + + // Extract lower 32 bits for base offset + Value baseOffset = + arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), fullDesc); + + // Extract upper 128 bits for resource descriptor + auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32); + Value rsrcBits160 = arith::ShRUIOp::create(rewriter, loc, fullDesc, c32); auto i128Type = IntegerType::get(rewriter.getContext(), 128); - result = LLVM::PtrToIntOp::create(rewriter, loc, i128Type, result); + Value rsrcBits = + arith::TruncIOp::create(rewriter, loc, i128Type, rsrcBits160); - // Bitcast to vector<4xi32> for buffer descriptor + // Bitcast to vector<4xi32> for buffer resource descriptor auto vec4i32Type = VectorType::get({4}, rewriter.getI32Type()); - result = LLVM::BitcastOp::create(rewriter, loc, vec4i32Type, result); + Value rsrc = LLVM::BitcastOp::create(rewriter, loc, vec4i32Type, rsrcBits); + + // Use readfirstlane to move resource descriptor to SGPR + rsrc = ROCDL::ReadfirstlaneOp::create(rewriter, loc, rsrc.getType(), rsrc); - // Use readfirstlane to move buffer descriptor to SGPR - result = - ROCDL::ReadfirstlaneOp::create(rewriter, loc, result.getType(), result); - return result; + return {rsrc, baseOffset}; } /// Helper to get memref, result type, and bit width from load operation @@ -311,8 +326,7 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, // descriptor (SGPR[4]) StringRef constraints = "=v,v,s"; - // Compute byte offset as i64 (not full address, since buffer descriptor has - // base) + // Compute byte offset from indices unsigned elementBitWidth = std::is_same_v ? 
cast(resultType).getElementTypeBitWidth() @@ -320,13 +334,18 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - // Extract buffer descriptor pointer from memref - Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); // Create inline assembly operation with result type directly - auto asmOp = - createInlineAsm(rewriter, loc, resultType, ValueRange{offset, bufferDesc}, - asmStr, constraints, /*hasSideEffects=*/true); + auto asmOp = createInlineAsm(rewriter, loc, resultType, + ValueRange{finalOffset, bufferDesc}, asmStr, + constraints, /*hasSideEffects=*/true); rewriter.replaceOp(loadOp, asmOp.getResult(0)); return success(); @@ -395,8 +414,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, // (SGPR[4]) StringRef constraints = "v,v,s"; - // Compute byte offset as i64 (not full address, since buffer descriptor has - // base) + // Compute byte offset from indices unsigned elementBitWidth = std::is_same_v ? cast(valueType).getElementTypeBitWidth() @@ -404,14 +422,19 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - // Extract buffer descriptor pointer from memref - Value bufferDesc = extractBufferDescriptor(rewriter, loc, memref); + // Extract buffer descriptor and base offset from memref + auto [bufferDesc, baseOffset] = + extractBufferDescriptor(rewriter, loc, memref); + + // Add base offset to computed offset + Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, + arith::IntegerOverflowFlags::nsw); Value valueToStore = storeOp.getValueToStore(); // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{valueToStore, offset, bufferDesc}, asmStr, + ValueRange{valueToStore, finalOffset, bufferDesc}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); From 0944a7af1b57f79283f63dfb179a13efc0efdf30 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 12 Dec 2025 23:29:03 +0100 Subject: [PATCH 060/114] buffer fixes Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 9845e158c..b2fff246c 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -232,6 +232,7 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { // strides...} auto memrefType = cast(memref.getType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); + auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); auto arrayType = LLVM::LLVMArrayType::get(i64Type, memrefType.getRank()); Type descriptorFields[] = {ptrType, ptrType, i64Type, arrayType, arrayType}; @@ -245,6 +246,8 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { MemRefDescriptor memrefDesc(memrefDescVal); 
Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc); + Value bufferOffset = memrefDesc.offset(rewriter, loc); + bufferOffset = arith::TruncIOp::create(rewriter, loc, i32Type, bufferOffset); // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32 // offset} @@ -252,8 +255,10 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { Value fullDesc = LLVM::PtrToIntOp::create(rewriter, loc, i160Type, bufferPtr); // Extract lower 32 bits for base offset - Value baseOffset = - arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), fullDesc); + Value baseOffset = arith::TruncIOp::create(rewriter, loc, i32Type, fullDesc); + + baseOffset = arith::AddIOp::create(rewriter, loc, baseOffset, bufferOffset, + arith::IntegerOverflowFlags::nsw); // Extract upper 128 bits for resource descriptor auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32); @@ -262,14 +267,7 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { Value rsrcBits = arith::TruncIOp::create(rewriter, loc, i128Type, rsrcBits160); - // Bitcast to vector<4xi32> for buffer resource descriptor - auto vec4i32Type = VectorType::get({4}, rewriter.getI32Type()); - Value rsrc = LLVM::BitcastOp::create(rewriter, loc, vec4i32Type, rsrcBits); - - // Use readfirstlane to move resource descriptor to SGPR - rsrc = ROCDL::ReadfirstlaneOp::create(rewriter, loc, rsrc.getType(), rsrc); - - return {rsrc, baseOffset}; + return {rsrcBits, baseOffset}; } /// Helper to get memref, result type, and bit width from load operation From db89a400d232109ac72e0c51137c401b7882b00a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 21:56:12 +0100 Subject: [PATCH 061/114] doc Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index b2fff246c..d1329bb91 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -225,11 +225,11 @@ static FailureOr getBufferSuffixStore(unsigned bitWidth, /// Extract buffer descriptor and base offset from a fat_raw_buffer memref /// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) -/// Returns: {resource descriptor (vec<4xi32> in SGPR), base offset (i32)} +/// Returns: {resource descriptor (i128), base offset (i32)} static std::pair extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { - // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes..., - // strides...} + // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes[rank], + // strides[rank]} auto memrefType = cast(memref.getType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); auto i32Type = rewriter.getI32Type(); From 70839198c4d0e2e2a7e84b62da59bbc9e1b9d708 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 22:03:24 +0100 Subject: [PATCH 062/114] reg2mem pass init Signed-off-by: Ivan Butygin --- water/include/water/Transforms/Passes.td | 20 ++++ water/lib/Transforms/CMakeLists.txt | 1 + .../Transforms/WaterMaterializeRegCopy.cpp | 99 +++++++++++++++++++ .../test/Transforms/materialize-reg-copy.mlir | 61 ++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 water/lib/Transforms/WaterMaterializeRegCopy.cpp create mode 100644 water/test/Transforms/materialize-reg-copy.mlir diff --git 
a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index f8d30313b..961ec60a9 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -198,4 +198,24 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> { ]; } +def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> { + let summary = "Materialize register copies for memref loads"; + let description = [{ + This pass materializes explicit register copies by transforming memref.load + operations to route through a temporary buffer in the virtual register + memory space (memspace 128). For each load: + 1. Creates a subview of the source memref at the load indices + 2. Allocates a temporary buffer in memory space 128 (virtual register space) + 3. Copies from the subview to the temporary register buffer + 4. Loads from the temporary register buffer + + This transformation makes register traffic explicit in the IR, enabling + better analysis and optimization of register usage patterns. + }]; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::memref::MemRefDialect", + ]; +} + #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index 900539ad5..ced8d9a44 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRWaterTransforms SLPVectorizer.cpp WaterInsertWaitcnt.cpp WaterLowerMemoryOps.cpp + WaterMaterializeRegCopy.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/water diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp new file mode 100644 index 000000000..54e05323d --- /dev/null +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -0,0 +1,99 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERMATERIALIZEREGCOPY +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { + +/// Materialize register copies by routing memref.load through temporary +/// buffers in virtual register space (memspace 128). +class WaterMaterializeRegCopyPass + : public water::impl::WaterMaterializeRegCopyBase< + WaterMaterializeRegCopyPass> { +public: + void runOnOperation() override { + IRRewriter rewriter(&getContext()); + + // Collect all load operations to transform + SmallVector loadsToTransform; + getOperation()->walk( + [&](memref::LoadOp loadOp) { loadsToTransform.push_back(loadOp); }); + + for (memref::LoadOp loadOp : loadsToTransform) { + if (failed(materializeRegCopy(rewriter, loadOp))) + return signalPassFailure(); + } + } + +private: + /// Transform a single load operation to use register space copy. 
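+  /// For illustration only (mirroring the FileCheck expectations in
+  /// materialize-reg-copy.mlir), a load such as
+  ///   %0 = memref.load %arg0[%i, %j] : memref<10x20xf32>
+  /// becomes
+  ///   %sv = memref.subview %arg0[%i, %j] [1, 1] [1, 1]
+  ///   %tmp = memref.alloca() : memref<1x1xf32, 128 : i32>
+  ///   memref.copy %sv, %tmp
+  ///   %0 = memref.load %tmp[%c0, %c0]
+  /// (%sv, %tmp and %c0 are illustrative names, not ones the rewriter
+  /// produces).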
+ LogicalResult materializeRegCopy(IRRewriter &rewriter, + memref::LoadOp loadOp) { + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Get the source memref and indices + Value memref = loadOp.getMemRef(); + ValueRange indices = loadOp.getIndices(); + auto memrefType = cast(memref.getType()); + Type elementType = memrefType.getElementType(); + + // Create constants for subview + SmallVector offsets, sizes, strides; + for (Value index : indices) { + offsets.push_back(index); + sizes.push_back(rewriter.getIndexAttr(1)); + strides.push_back(rewriter.getIndexAttr(1)); + } + + // Create subview of size [1, 1, ..., 1] at the load indices + auto subviewType = + memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides); + auto subviewMemRefType = cast(subviewType); + Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, + memref, offsets, sizes, strides); + + // Create temporary buffer in virtual register space (memspace 128) + auto regMemSpace = rewriter.getI32IntegerAttr(128); + auto tempType = + MemRefType::get(subviewMemRefType.getShape(), elementType, + /*layout=*/MemRefLayoutAttrInterface{}, regMemSpace); + Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType, + /*dynamicSizes=*/ValueRange{}, + /*alignment=*/IntegerAttr()); + + // Copy from subview to temp register buffer + memref::CopyOp::create(rewriter, loc, subview, tempAlloca); + + // Create zero indices for loading from temp buffer + SmallVector zeroIndices; + for (unsigned i = 0; i < indices.size(); ++i) + zeroIndices.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); + + // Load from the temporary register buffer + Value result = + memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + + // Replace the original load with the new one + rewriter.replaceOp(loadOp, result); + + return success(); + } +}; + +} // namespace diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir new file mode 100644 index 000000000..f8fdd15a0 --- /dev/null +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -0,0 +1,61 @@ +// RUN: water-opt %s --water-materialize-reg-copy | FileCheck %s + +// CHECK-LABEL: func @test_simple_load +func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f32 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] + // CHECK-SAME: memref<10x20xf32> to memref<1x1xf32, strided<[20, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0_1]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i, %j] : memref<10x20xf32> + return %0 : f32 +} + +// CHECK-LABEL: func @test_1d_load +func.func @test_1d_load(%arg0: memref<100xf16>, %i: index) -> f16 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1] [1] [1] + // CHECK-SAME: memref<100xf16> to memref<1xf16, strided<[1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf16, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i] : memref<100xf16> + return %0 : f16 +} + +// CHECK-LABEL: func @test_3d_load +func.func @test_3d_load(%arg0: 
memref<8x16x32xi32>, %i: index, %j: index, %k: index) -> i32 { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 1, 1] [1, 1, 1] + // CHECK-SAME: memref<8x16x32xi32> to memref<1x1x1xi32, strided<[512, 32, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1x1xi32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[C0_2:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0_1]], %[[C0_2]]] + // CHECK: return %[[RESULT]] + %0 = memref.load %arg0[%i, %j, %k] : memref<8x16x32xi32> + return %0 : i32 +} + +// CHECK-LABEL: func @test_multiple_loads +func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) -> f32 { + // CHECK: memref.subview + // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy + // CHECK: memref.load + %0 = memref.load %arg0[%i, %j] : memref<10x10xf32> + + // CHECK: memref.subview + // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> + // CHECK: memref.copy + // CHECK: memref.load + %1 = memref.load %arg0[%j, %i] : memref<10x10xf32> + + %2 = arith.addf %0, %1 : f32 + return %2 : f32 +} From f6bb568c32a81cbe776b74670445dd7e0793c0d0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 22:03:42 +0100 Subject: [PATCH 063/114] doc Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index d1329bb91..67915a3d8 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -228,8 +228,8 @@ static FailureOr getBufferSuffixStore(unsigned bitWidth, /// Returns: {resource descriptor (i128), base offset (i32)} static std::pair extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) { - // Create proper memref descriptor struct type: {ptr, ptr, offset, sizes[rank], - // strides[rank]} + // Create proper memref descriptor struct type: {ptr, ptr, offset, + // sizes[rank], strides[rank]} auto memrefType = cast(memref.getType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 7); auto i32Type = rewriter.getI32Type(); From 17a1bf3f792315e7388fe2ead37a316ccd87707c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 22:08:16 +0100 Subject: [PATCH 064/114] skip existing Signed-off-by: Ivan Butygin --- .../Transforms/WaterMaterializeRegCopy.cpp | 12 +++++++++-- .../test/Transforms/materialize-reg-copy.mlir | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp index 54e05323d..2b50a44a6 100644 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -31,8 +31,16 @@ class WaterMaterializeRegCopyPass // Collect all load operations to transform SmallVector loadsToTransform; - getOperation()->walk( - [&](memref::LoadOp loadOp) { loadsToTransform.push_back(loadOp); }); + getOperation()->walk([&](memref::LoadOp loadOp) { + auto memrefType = cast(loadOp.getMemRef().getType()); + // Skip loads already from virtual register space (memspace 128) + if (auto memSpace = + dyn_cast_or_null(memrefType.getMemorySpace())) { + if (memSpace.getInt() == 128) + return; + } + 
loadsToTransform.push_back(loadOp); + }); for (memref::LoadOp loadOp : loadsToTransform) { if (failed(materializeRegCopy(rewriter, loadOp))) diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index f8fdd15a0..96d723929 100644 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -59,3 +59,23 @@ func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) - %2 = arith.addf %0, %1 : f32 return %2 : f32 } + +// CHECK-LABEL: func @test_skip_memspace_128 +func.func @test_skip_memspace_128(%arg0: memref<10xf32>, %arg1: memref<5xf32, 128 : i32>, %i: index) -> f32 { + // This load should be transformed (from default memspace) + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]] + %0 = memref.load %arg0[%i] : memref<10xf32> + + // This load should NOT be transformed (already from memspace 128) + // CHECK: %[[VAL1:.*]] = memref.load %arg1[%arg2] : memref<5xf32, 128 : i32> + %1 = memref.load %arg1[%i] : memref<5xf32, 128 : i32> + + // CHECK: %[[RESULT:.*]] = arith.addf %[[VAL0]], %[[VAL1]] + %result = arith.addf %0, %1 : f32 + // CHECK: return %[[RESULT]] + return %result : f32 +} From c5c09ff4a0b98df418517f421518c95109ca0b0c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 22:38:05 +0100 Subject: [PATCH 065/114] block tests Signed-off-by: Ivan Butygin --- .../Transforms/WaterMaterializeRegCopy.cpp | 35 ++++++++++-- .../test/Transforms/materialize-reg-copy.mlir | 56 +++++++++++++++++-- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp index 2b50a44a6..ed5a440b3 100644 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -88,17 +88,40 @@ class WaterMaterializeRegCopyPass // Copy from subview to temp register buffer memref::CopyOp::create(rewriter, loc, subview, tempAlloca); - // Create zero indices for loading from temp buffer + // Group uses by block and find the first use in each block + Value loadResult = loadOp.getResult(); + DenseMap blockToFirstUse; + + for (OpOperand &use : loadResult.getUses()) { + Operation *userOp = use.getOwner(); + Block *userBlock = userOp->getBlock(); + + auto it = blockToFirstUse.find(userBlock); + if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) + blockToFirstUse[userBlock] = userOp; + } + + // Create one load per block, right before the first use in that block + DenseMap blockToLoad; SmallVector zeroIndices; for (unsigned i = 0; i < indices.size(); ++i) zeroIndices.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); - // Load from the temporary register buffer - Value result = - memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + for (auto &[block, firstUse] : blockToFirstUse) { + rewriter.setInsertionPoint(firstUse); + Value load = + memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + blockToLoad[block] = load; + } + + // Replace uses with the appropriate load for their block + for (OpOperand &use : llvm::make_early_inc_range(loadResult.getUses())) { + Block *userBlock = use.getOwner()->getBlock(); + use.set(blockToLoad[userBlock]); + } - // Replace the 
original load with the new one - rewriter.replaceOp(loadOp, result); + // Erase the original load + rewriter.eraseOp(loadOp); return success(); } diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index 96d723929..eb5286178 100644 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -44,18 +44,22 @@ func.func @test_3d_load(%arg0: memref<8x16x32xi32>, %i: index, %j: index, %k: in // CHECK-LABEL: func @test_multiple_loads func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) -> f32 { + // First load: subview, alloca, copy // CHECK: memref.subview // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> // CHECK: memref.copy - // CHECK: memref.load %0 = memref.load %arg0[%i, %j] : memref<10x10xf32> + // Second load: subview, alloca, copy // CHECK: memref.subview // CHECK: memref.alloca() : memref<1x1xf32, 128 : i32> // CHECK: memref.copy - // CHECK: memref.load %1 = memref.load %arg0[%j, %i] : memref<10x10xf32> + // Now the actual loads happen right before the addf (late as possible) + // CHECK: memref.load + // CHECK: memref.load + // CHECK: arith.addf %2 = arith.addf %0, %1 : f32 return %2 : f32 } @@ -63,19 +67,63 @@ func.func @test_multiple_loads(%arg0: memref<10x10xf32>, %i: index, %j: index) - // CHECK-LABEL: func @test_skip_memspace_128 func.func @test_skip_memspace_128(%arg0: memref<10xf32>, %arg1: memref<5xf32, 128 : i32>, %i: index) -> f32 { // This load should be transformed (from default memspace) + // First: subview, alloca, copy // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]] %0 = memref.load %arg0[%i] : memref<10xf32> // This load should NOT be transformed (already from memspace 128) + // It stays in place // CHECK: %[[VAL1:.*]] = memref.load %arg1[%arg2] : memref<5xf32, 128 : i32> %1 = memref.load %arg1[%i] : memref<5xf32, 128 : i32> - // CHECK: %[[RESULT:.*]] = arith.addf %[[VAL0]], %[[VAL1]] + // The load from temp happens late (right before addf) + // CHECK: %[[VAL0:.*]] = memref.load %[[TEMP]][%[[C0]]] + // Note: operands may be reordered + // CHECK: arith.addf %result = arith.addf %0, %1 : f32 + // CHECK: return + return %result : f32 +} + +// CHECK-LABEL: func @test_control_flow +func.func @test_control_flow(%arg0: memref<10xf32>, %cond: i1, %i: index) -> f32 { + // Load happens once, but value is used in multiple blocks + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg2] [1] [1] + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %val = memref.load %arg0[%i] : memref<10xf32> + + // CHECK: cf.cond_br + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // First block: load happens here before the addf + // CHECK: ^bb1: + // CHECK: %[[CONST1:.*]] = arith.constant 1.0 + // CHECK: %[[LOAD1:.*]] = memref.load %[[TEMP]][%[[C0]]] + // CHECK: %[[ADD1:.*]] = arith.addf %[[LOAD1]], %[[CONST1]] + %c1 = arith.constant 1.0 : f32 + %sum1 = arith.addf %val, %c1 : f32 + // CHECK: cf.br ^bb3(%[[ADD1]] + cf.br ^bb3(%sum1 : f32) + +^bb2: + // Second block: another load happens here before the mulf + // CHECK: ^bb2: + // CHECK: %[[CONST2:.*]] = arith.constant 2.0 + // CHECK: %[[LOAD2:.*]] = memref.load 
%[[TEMP]][%[[C0]]] + // CHECK: %[[MUL:.*]] = arith.mulf %[[LOAD2]], %[[CONST2]] + %c2 = arith.constant 2.0 : f32 + %prod = arith.mulf %val, %c2 : f32 + // CHECK: cf.br ^bb3(%[[MUL]] + cf.br ^bb3(%prod : f32) + +^bb3(%result: f32): + // CHECK: ^bb3(%[[RESULT:.*]]: f32): // CHECK: return %[[RESULT]] return %result : f32 } From ef0abf9dbd81bead0e9bd15b4d21e65d38a6d2bd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 23:01:36 +0100 Subject: [PATCH 066/114] vector op support Signed-off-by: Ivan Butygin --- .../Transforms/WaterMaterializeRegCopy.cpp | 97 ++++++++++++------- .../test/Transforms/materialize-reg-copy.mlir | 14 +++ 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp index ed5a440b3..69779c401 100644 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -30,43 +31,70 @@ class WaterMaterializeRegCopyPass IRRewriter rewriter(&getContext()); // Collect all load operations to transform - SmallVector loadsToTransform; - getOperation()->walk([&](memref::LoadOp loadOp) { - auto memrefType = cast(loadOp.getMemRef().getType()); - // Skip loads already from virtual register space (memspace 128) - if (auto memSpace = - dyn_cast_or_null(memrefType.getMemorySpace())) { - if (memSpace.getInt() == 128) - return; + SmallVector loadsToTransform; + getOperation()->walk([&](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + if (!isInRegisterSpace(cast(loadOp.getMemRef().getType()))) + loadsToTransform.push_back(op); + } else if (auto loadOp = dyn_cast(op)) { + if (!isInRegisterSpace(cast(loadOp.getBase().getType()))) + loadsToTransform.push_back(op); } - loadsToTransform.push_back(loadOp); }); - for (memref::LoadOp loadOp : loadsToTransform) { - if (failed(materializeRegCopy(rewriter, loadOp))) + for (Operation *op : loadsToTransform) { + if (failed(materializeRegCopy(rewriter, op))) return signalPassFailure(); } } private: + /// Check if a memref type is in virtual register space (memspace 128). + static bool isInRegisterSpace(MemRefType memrefType) { + if (auto memSpace = + dyn_cast_or_null(memrefType.getMemorySpace())) + return memSpace.getInt() == 128; + return false; + } + /// Transform a single load operation to use register space copy. 
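+  /// Accepts either a memref.load or a vector.load; for a vector.load the
+  /// temporary buffer keeps the vector shape in its trailing dimensions
+  /// (see test_simple_vector_load, which round-trips a vector<4xf32> through
+  /// a memref<1x4xf32, 128 : i32> buffer).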
- LogicalResult materializeRegCopy(IRRewriter &rewriter, - memref::LoadOp loadOp) { - Location loc = loadOp.getLoc(); - rewriter.setInsertionPoint(loadOp); - - // Get the source memref and indices - Value memref = loadOp.getMemRef(); - ValueRange indices = loadOp.getIndices(); + LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { + Location loc = op->getLoc(); + rewriter.setInsertionPoint(op); + + // Extract memref, indices, and element type from either load type + Value memref, loadResult; + ValueRange indices; + Type elementType; + SmallVector loadShape; + + if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getMemRef(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + elementType = loadOp.getType(); + loadShape.resize(indices.size(), 1); + } else if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getBase(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + VectorType vecType = loadOp.getVectorType(); + elementType = vecType.getElementType(); + loadShape.resize(indices.size() - vecType.getRank(), 1); + llvm::append_range(loadShape, vecType.getShape()); + } else { + return op->emitError("unsupported load operation"); + } + auto memrefType = cast(memref.getType()); - Type elementType = memrefType.getElementType(); - // Create constants for subview + // Create subview parameters + Attribute one = rewriter.getIndexAttr(1); SmallVector offsets, sizes, strides; - for (Value index : indices) { + for (auto [index, shape] : llvm::zip(indices, loadShape)) { offsets.push_back(index); - sizes.push_back(rewriter.getIndexAttr(1)); - strides.push_back(rewriter.getIndexAttr(1)); + sizes.push_back(rewriter.getIndexAttr(shape)); + strides.push_back(one); } // Create subview of size [1, 1, ..., 1] at the load indices @@ -89,28 +117,30 @@ class WaterMaterializeRegCopyPass memref::CopyOp::create(rewriter, loc, subview, tempAlloca); // Group uses by block and find the first use in each block - Value loadResult = loadOp.getResult(); DenseMap blockToFirstUse; - for (OpOperand &use : loadResult.getUses()) { Operation *userOp = use.getOwner(); Block *userBlock = userOp->getBlock(); - auto it = blockToFirstUse.find(userBlock); if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) blockToFirstUse[userBlock] = userOp; } - // Create one load per block, right before the first use in that block - DenseMap blockToLoad; + // Create zero indices for loading from temp buffer SmallVector zeroIndices; - for (unsigned i = 0; i < indices.size(); ++i) + for (size_t i = 0; i < loadShape.size(); ++i) zeroIndices.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); + // Create one load per block, right before the first use in that block + DenseMap blockToLoad; for (auto &[block, firstUse] : blockToFirstUse) { rewriter.setInsertionPoint(firstUse); - Value load = - memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + Value load; + if (isa(op)) + load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + else if (auto vecLoadOp = dyn_cast(op)) + load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), + tempAlloca, zeroIndices); blockToLoad[block] = load; } @@ -121,8 +151,7 @@ class WaterMaterializeRegCopyPass } // Erase the original load - rewriter.eraseOp(loadOp); - + rewriter.eraseOp(op); return success(); } }; diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index eb5286178..52c278b9e 100644 --- 
a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -14,6 +14,20 @@ func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f return %0 : f32 } +// CHECK-LABEL: func @test_simple_vector_load +func.func @test_simple_vector_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> vector<4xf32> { + // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1, %arg2] [1, 4] [1, 1] + // CHECK-SAME: memref<10x20xf32> to memref<1x4xf32, strided<[20, 1], offset: ?>> + // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x4xf32, 128 : i32> + // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0_1]]] + // CHECK: return %[[RESULT]] + %0 = vector.load %arg0[%i, %j] : memref<10x20xf32>, vector<4xf32> + return %0 : vector<4xf32> +} + // CHECK-LABEL: func @test_1d_load func.func @test_1d_load(%arg0: memref<100xf16>, %i: index) -> f16 { // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[%arg1] [1] [1] From bef1c6f770b9f834207916a0a12b70358e3e154e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Dec 2025 23:07:31 +0100 Subject: [PATCH 067/114] propagate views Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 24 ++++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index 499d99035..6c79a4c6e 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -30,17 +30,25 @@ namespace mlir::water { namespace { +/// Try to propagate view operations to the base memref. +static Value propagateViewOps(Value value) { + while (auto view = value.getDefiningOp()) + value = view.getViewSource(); + + return value; +} + /// Check if the operation is a load operation and return the base memref. static std::optional isLoadOp(Operation *op) { // TODO: replace with the interface when available. if (auto load = dyn_cast(op)) - return load.getBase(); + return propagateViewOps(load.getBase()); if (auto load = dyn_cast(op)) - return load.getMemRef(); + return propagateViewOps(load.getMemRef()); if (auto copy = dyn_cast(op)) - return copy.getSource(); + return propagateViewOps(copy.getSource()); if (auto gather = dyn_cast(op)) - return gather.getSrc(); + return propagateViewOps(gather.getSrc()); return std::nullopt; } @@ -49,13 +57,13 @@ static std::optional isLoadOp(Operation *op) { static std::optional isStoreOp(Operation *op) { // TODO: replace with the interface when available. 
  if (auto store = dyn_cast<vector::StoreOp>(op))
-    return store.getBase();
+    return propagateViewOps(store.getBase());
  if (auto store = dyn_cast<memref::StoreOp>(op))
-    return store.getMemRef();
+    return propagateViewOps(store.getMemRef());
  if (auto copy = dyn_cast<memref::CopyOp>(op))
-    return copy.getTarget();
+    return propagateViewOps(copy.getTarget());
  if (auto gather = dyn_cast<amdgpu::GatherToLDSOp>(op))
-    return gather.getDst();
+    return propagateViewOps(gather.getDst());
  return std::nullopt;
}

From 5423479629ea9c1e23c77667160a383fb5a0e3b6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 00:57:40 +0100
Subject: [PATCH 068/114] update desc

Signed-off-by: Ivan Butygin
---
 water/include/water/Transforms/Passes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index 961ec60a9..4503f72b5 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -199,9 +199,9 @@ def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> {
 }
 
 def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> {
-  let summary = "Materialize register copies for memref loads";
+  let summary = "Materialize register copies for loads";
   let description = [{
-    This pass materializes explicit register copies by transforming memref.load
+    This pass materializes explicit register copies by transforming load
     operations to route through a temporary buffer in the virtual register
     memory space (memspace 128). For each load:
     1. Creates a subview of the source memref at the load indices

From d47022be5d5843b6ff74faf513cf24053522f55e Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 02:01:14 +0100
Subject: [PATCH 069/114] loop stuff

Signed-off-by: Ivan Butygin
---
 .../Transforms/WaterMaterializeRegCopy.cpp    | 100 +++++++++++++++++-
 .../test/Transforms/materialize-reg-copy.mlir |  10 +-
 2 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp
index 69779c401..8462b093a 100644
--- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp
+++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp
@@ -8,7 +8,9 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -46,6 +48,10 @@ class WaterMaterializeRegCopyPass
       if (failed(materializeRegCopy(rewriter, op)))
         return signalPassFailure();
     }
+
+    // Hoist allocas out of loops when their loads are yielded
+    getOperation()->walk(
+        [&](scf::ForOp forOp) { (void)hoistAllocasFromLoop(rewriter, forOp); });
   }
 
 private:
@@ -127,9 +133,8 @@ class WaterMaterializeRegCopyPass
     }
 
     // Create zero indices for loading from temp buffer
-    SmallVector<Value> zeroIndices;
-    for (size_t i = 0; i < loadShape.size(); ++i)
-      zeroIndices.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0));
+    SmallVector<Value> zeroIndices(
+        loadShape.size(), arith::ConstantIndexOp::create(rewriter, loc, 0));
 
     // Create one load per block, right before the first use in that block
     DenseMap<Block *, Value> blockToLoad;
@@ -154,6 +159,95 @@ class WaterMaterializeRegCopyPass
     rewriter.eraseOp(op);
     return success();
   }
+
+  /// Hoist allocas from loops when their loads are yielded.
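+  /// Illustrative sketch, mirroring test_loop_hoist in
+  /// materialize-reg-copy.mlir: for a loop that yields
+  ///   %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32>
+  /// where %alloca is a memspace-128 alloca created in the loop body, the
+  /// alloca is moved in front of the loop, the init value is stored into it,
+  /// uses of the iter arg are replaced by loads from it, and the loop result
+  /// is replaced by a load placed after the loop.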
+ void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) { + auto yieldedValues = loop.getYieldedValuesMutable(); + if (!yieldedValues) + return; + + auto loopResults = loop.getLoopResults(); + if (!loopResults) + return; + + auto loopInits = loop.getInitsMutable(); + + Block *body = loop.getBody(); + Location loc = loop.getLoc(); + + DominanceInfo dom; + + // Find yielded values that come from loads of memspace 128 allocas + for (auto [idx, yieldedValue, iterArg, init, result] : + llvm::enumerate(*yieldedValues, loop.getRegionIterArgs(), loopInits, + *loopResults)) { + // Check if this is a load from memspace 128 + Operation *defOp = yieldedValue.get().getDefiningOp(); + if (!defOp) + continue; + + Value alloca; + ValueRange loadIndices; + if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getMemRef(); + loadIndices = loadOp.getIndices(); + } else if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getBase(); + loadIndices = loadOp.getIndices(); + } else { + continue; + } + + if (!loadIndices.empty()) + continue; + + // Check if loading from memspace 128 alloca defined in this loop + auto allocaOp = alloca.getDefiningOp(); + if (!allocaOp) + continue; + if (!isInRegisterSpace(cast(alloca.getType()))) + continue; + if (!body->findAncestorOpInBlock(*allocaOp)) + continue; + + // If load dominates any use of the iter arg, we can't hoist the alloca + // because the load would be invalidated by the store. + bool dominates = false; + for (Operation *user : iterArg.getUsers()) { + if (dom.dominates(defOp, user)) { + dominates = true; + break; + } + } + if (dominates) + continue; + + // Hoist the alloca before the loop + allocaOp->moveBefore(loop); + rewriter.setInsertionPointAfter(allocaOp); + + // Store the iter arg into the alloca + if (isa(defOp)) { + memref::StoreOp::create(rewriter, loc, init.get(), alloca, loadIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + vector::StoreOp::create(rewriter, loc, init.get(), alloca, loadIndices); + } + + // Create a load after the loop + rewriter.setInsertionPointAfter(loop); + Value loadAfterLoop; + if (isa(defOp)) { + loadAfterLoop = + memref::LoadOp::create(rewriter, loc, alloca, loadIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + loadAfterLoop = vector::LoadOp::create( + rewriter, loc, vectorLoad.getVectorType(), alloca, loadIndices); + } + + // Replace uses of the loop result with the new load + result.replaceAllUsesWith(loadAfterLoop); + } + } }; } // namespace diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index 52c278b9e..e3b24d018 100644 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -7,8 +7,7 @@ func.func @test_simple_load(%arg0: memref<10x20xf32>, %i: index, %j: index) -> f // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1xf32, 128 : i32> // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0_1]]] + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]]] // CHECK: return %[[RESULT]] %0 = memref.load %arg0[%i, %j] : memref<10x20xf32> return %0 : f32 @@ -21,8 +20,7 @@ func.func @test_simple_vector_load(%arg0: memref<10x20xf32>, %i: index, %j: inde // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x4xf32, 128 : i32> // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - 
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0_1]]] + // CHECK: %[[RESULT:.*]] = vector.load %[[TEMP]][%[[C0]], %[[C0]]] // CHECK: return %[[RESULT]] %0 = vector.load %arg0[%i, %j] : memref<10x20xf32>, vector<4xf32> return %0 : vector<4xf32> @@ -48,9 +46,7 @@ func.func @test_3d_load(%arg0: memref<8x16x32xi32>, %i: index, %j: index, %k: in // CHECK: %[[TEMP:.*]] = memref.alloca() : memref<1x1x1xi32, 128 : i32> // CHECK: memref.copy %[[SUBVIEW]], %[[TEMP]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[C0_2:.*]] = arith.constant 0 : index - // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0_1]], %[[C0_2]]] + // CHECK: %[[RESULT:.*]] = memref.load %[[TEMP]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: return %[[RESULT]] %0 = memref.load %arg0[%i, %j, %k] : memref<8x16x32xi32> return %0 : i32 From 4a4d13bd45703909835f714dbb30b1cef38463cd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 14 Dec 2025 02:24:14 +0100 Subject: [PATCH 070/114] loop update Signed-off-by: Ivan Butygin --- .../Transforms/WaterMaterializeRegCopy.cpp | 91 +++++++++++-------- .../test/Transforms/materialize-reg-copy.mlir | 19 ++++ 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp index 8462b093a..3cf0a317a 100644 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -63,6 +63,45 @@ class WaterMaterializeRegCopyPass return false; } + static SmallVector getZeroIndices(IRRewriter &rewriter, Location loc, + unsigned rank) { + return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)}; + } + + static void createLoads(IRRewriter &rewriter, Location loc, Value value, + unsigned rank, Value tempAlloca, Operation *op) { + // Group uses by block and find the first use in each block + DenseMap blockToFirstUse; + for (OpOperand &use : value.getUses()) { + Operation *userOp = use.getOwner(); + Block *userBlock = userOp->getBlock(); + auto it = blockToFirstUse.find(userBlock); + if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) + blockToFirstUse[userBlock] = userOp; + } + + SmallVector zeroIndices = getZeroIndices(rewriter, loc, rank); + + // Create one load per block, right before the first use in that block + DenseMap blockToLoad; + for (auto &[block, firstUse] : blockToFirstUse) { + rewriter.setInsertionPoint(firstUse); + Value load; + if (isa(op)) + load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + else if (auto vecLoadOp = dyn_cast(op)) + load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), + tempAlloca, zeroIndices); + blockToLoad[block] = load; + } + + // Replace uses with the appropriate load for their block + for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) { + Block *userBlock = use.getOwner()->getBlock(); + use.set(blockToLoad[userBlock]); + } + } + /// Transform a single load operation to use register space copy. 
LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { Location loc = op->getLoc(); @@ -122,38 +161,7 @@ class WaterMaterializeRegCopyPass // Copy from subview to temp register buffer memref::CopyOp::create(rewriter, loc, subview, tempAlloca); - // Group uses by block and find the first use in each block - DenseMap blockToFirstUse; - for (OpOperand &use : loadResult.getUses()) { - Operation *userOp = use.getOwner(); - Block *userBlock = userOp->getBlock(); - auto it = blockToFirstUse.find(userBlock); - if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) - blockToFirstUse[userBlock] = userOp; - } - - // Create zero indices for loading from temp buffer - SmallVector zeroIndices( - loadShape.size(), arith::ConstantIndexOp::create(rewriter, loc, 0)); - - // Create one load per block, right before the first use in that block - DenseMap blockToLoad; - for (auto &[block, firstUse] : blockToFirstUse) { - rewriter.setInsertionPoint(firstUse); - Value load; - if (isa(op)) - load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); - else if (auto vecLoadOp = dyn_cast(op)) - load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), - tempAlloca, zeroIndices); - blockToLoad[block] = load; - } - - // Replace uses with the appropriate load for their block - for (OpOperand &use : llvm::make_early_inc_range(loadResult.getUses())) { - Block *userBlock = use.getOwner()->getBlock(); - use.set(blockToLoad[userBlock]); - } + createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op); // Erase the original load rewriter.eraseOp(op); @@ -198,7 +206,9 @@ class WaterMaterializeRegCopyPass continue; } - if (!loadIndices.empty()) + // Check all indices are zero + if (llvm::any_of(loadIndices, + [](Value idx) { return getConstantIntValue(idx) != 0; })) continue; // Check if loading from memspace 128 alloca defined in this loop @@ -226,22 +236,29 @@ class WaterMaterializeRegCopyPass allocaOp->moveBefore(loop); rewriter.setInsertionPointAfter(allocaOp); + SmallVector zeroIndices = + getZeroIndices(rewriter, loc, loadIndices.size()); + // Store the iter arg into the alloca if (isa(defOp)) { - memref::StoreOp::create(rewriter, loc, init.get(), alloca, loadIndices); + memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); } else if (auto vectorLoad = dyn_cast(defOp)) { - vector::StoreOp::create(rewriter, loc, init.get(), alloca, loadIndices); + vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); } + // Create iter arg loads + createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp); + // Create a load after the loop rewriter.setInsertionPointAfter(loop); + zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size()); Value loadAfterLoop; if (isa(defOp)) { loadAfterLoop = - memref::LoadOp::create(rewriter, loc, alloca, loadIndices); + memref::LoadOp::create(rewriter, loc, alloca, zeroIndices); } else if (auto vectorLoad = dyn_cast(defOp)) { loadAfterLoop = vector::LoadOp::create( - rewriter, loc, vectorLoad.getVectorType(), alloca, loadIndices); + rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices); } // Replace uses of the loop result with the new load diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index e3b24d018..a76e61b62 100644 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -137,3 +137,22 @@ func.func @test_control_flow(%arg0: memref<10xf32>, %cond: 
i1, %i: index) -> f32 // CHECK: return %[[RESULT]] return %result : f32 } + +// CHECK-LABEL: func @test_loop_hoist +func.func @test_loop_hoist(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index, %init: f32) -> f32 { + %c0 = arith.constant 0 : index + // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<1xf32, 128 : i32> + // CHECK: memref.store %arg4, %[[ALLOCA]][%c0] + // CHECK: scf.for + %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) { + memref.store %arg, %arg0[%c0] : memref<100xf32> + %alloca = memref.alloca() : memref<1xf32, 128 : i32> + %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + memref.copy %subview, %alloca : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32> + scf.yield %val : f32 + } + // CHECK: %[[FINAL:.*]] = memref.load %[[ALLOCA]][%c0] + // CHECK: return %[[FINAL]] + return %result : f32 +} From 97809d124d6b2be846fc18f8389ccbc27938d64e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 14 Dec 2025 02:28:22 +0100 Subject: [PATCH 071/114] test Signed-off-by: Ivan Butygin --- water/test/Transforms/materialize-reg-copy.mlir | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/water/test/Transforms/materialize-reg-copy.mlir b/water/test/Transforms/materialize-reg-copy.mlir index a76e61b62..7f789e8c9 100644 --- a/water/test/Transforms/materialize-reg-copy.mlir +++ b/water/test/Transforms/materialize-reg-copy.mlir @@ -142,17 +142,24 @@ func.func @test_control_flow(%arg0: memref<10xf32>, %cond: i1, %i: index) -> f32 func.func @test_loop_hoist(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index, %init: f32) -> f32 { %c0 = arith.constant 0 : index // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<1xf32, 128 : i32> - // CHECK: memref.store %arg4, %[[ALLOCA]][%c0] - // CHECK: scf.for + // CHECK: arith.constant 0 : index + // CHECK: memref.store %arg4, %[[ALLOCA]] + // CHECK: scf.for %[[IV:.*]] = %arg1 to %arg2 step %arg3 iter_args(%[[ITER_ARG:.*]] = %arg4) %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (f32) { + // CHECK: memref.load %[[ALLOCA]] + // CHECK: memref.store %{{.*}}, %arg0[%c0] memref.store %arg, %arg0[%c0] : memref<100xf32> %alloca = memref.alloca() : memref<1xf32, 128 : i32> %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> memref.copy %subview, %alloca : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> %val = memref.load %alloca[%c0] : memref<1xf32, 128 : i32> + // CHECK: memref.subview + // CHECK: memref.copy + // CHECK: memref.load %[[ALLOCA]] + // CHECK: scf.yield scf.yield %val : f32 } - // CHECK: %[[FINAL:.*]] = memref.load %[[ALLOCA]][%c0] - // CHECK: return %[[FINAL]] + // CHECK: memref.load %[[ALLOCA]] + // CHECK: return return %result : f32 } From b6bf5d3125e5e0ddf5440c7fed7152d084029824 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 14 Dec 2025 02:31:39 +0100 Subject: [PATCH 072/114] refac Signed-off-by: Ivan Butygin --- .../Transforms/WaterMaterializeRegCopy.cpp | 424 +++++++++--------- 1 file changed, 211 insertions(+), 213 deletions(-) diff --git a/water/lib/Transforms/WaterMaterializeRegCopy.cpp b/water/lib/Transforms/WaterMaterializeRegCopy.cpp index 3cf0a317a..505fd31a0 100644 --- a/water/lib/Transforms/WaterMaterializeRegCopy.cpp +++ b/water/lib/Transforms/WaterMaterializeRegCopy.cpp @@ -23,6 +23,216 @@ namespace mlir::water { 
namespace { +/// Check if a memref type is in virtual register space (memspace 128). +static bool isInRegisterSpace(MemRefType memrefType) { + if (auto memSpace = + dyn_cast_or_null(memrefType.getMemorySpace())) + return memSpace.getInt() == 128; + return false; +} + +static SmallVector getZeroIndices(IRRewriter &rewriter, Location loc, + unsigned rank) { + return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)}; +} + +static void createLoads(IRRewriter &rewriter, Location loc, Value value, + unsigned rank, Value tempAlloca, Operation *op) { + // Group uses by block and find the first use in each block + DenseMap blockToFirstUse; + for (OpOperand &use : value.getUses()) { + Operation *userOp = use.getOwner(); + Block *userBlock = userOp->getBlock(); + auto it = blockToFirstUse.find(userBlock); + if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) + blockToFirstUse[userBlock] = userOp; + } + + SmallVector zeroIndices = getZeroIndices(rewriter, loc, rank); + + // Create one load per block, right before the first use in that block + DenseMap blockToLoad; + for (auto &[block, firstUse] : blockToFirstUse) { + rewriter.setInsertionPoint(firstUse); + Value load; + if (isa(op)) + load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); + else if (auto vecLoadOp = dyn_cast(op)) + load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), + tempAlloca, zeroIndices); + blockToLoad[block] = load; + } + + // Replace uses with the appropriate load for their block + for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) { + Block *userBlock = use.getOwner()->getBlock(); + use.set(blockToLoad[userBlock]); + } +} + +/// Transform a single load operation to use register space copy. +static LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { + Location loc = op->getLoc(); + rewriter.setInsertionPoint(op); + + // Extract memref, indices, and element type from either load type + Value memref, loadResult; + ValueRange indices; + Type elementType; + SmallVector loadShape; + + if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getMemRef(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + elementType = loadOp.getType(); + loadShape.resize(indices.size(), 1); + } else if (auto loadOp = dyn_cast(op)) { + memref = loadOp.getBase(); + indices = loadOp.getIndices(); + loadResult = loadOp.getResult(); + VectorType vecType = loadOp.getVectorType(); + elementType = vecType.getElementType(); + loadShape.resize(indices.size() - vecType.getRank(), 1); + llvm::append_range(loadShape, vecType.getShape()); + } else { + return op->emitError("unsupported load operation"); + } + + auto memrefType = cast(memref.getType()); + + // Create subview parameters + Attribute one = rewriter.getIndexAttr(1); + SmallVector offsets, sizes, strides; + for (auto [index, shape] : llvm::zip(indices, loadShape)) { + offsets.push_back(index); + sizes.push_back(rewriter.getIndexAttr(shape)); + strides.push_back(one); + } + + // Create subview of size [1, 1, ..., 1] at the load indices + auto subviewType = + memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides); + auto subviewMemRefType = cast(subviewType); + Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, + memref, offsets, sizes, strides); + + // Create temporary buffer in virtual register space (memspace 128) + auto regMemSpace = rewriter.getI32IntegerAttr(128); + auto tempType = + MemRefType::get(subviewMemRefType.getShape(), elementType, + 
/*layout=*/MemRefLayoutAttrInterface{}, regMemSpace); + Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType, + /*dynamicSizes=*/ValueRange{}, + /*alignment=*/IntegerAttr()); + + // Copy from subview to temp register buffer + memref::CopyOp::create(rewriter, loc, subview, tempAlloca); + + createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op); + + // Erase the original load + rewriter.eraseOp(op); + return success(); +} + +/// Hoist allocas from loops when their loads are yielded. +static void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) { + auto yieldedValues = loop.getYieldedValuesMutable(); + if (!yieldedValues) + return; + + auto loopResults = loop.getLoopResults(); + if (!loopResults) + return; + + auto loopInits = loop.getInitsMutable(); + + Block *body = loop.getBody(); + Location loc = loop.getLoc(); + + DominanceInfo dom; + + // Find yielded values that come from loads of memspace 128 allocas + for (auto [idx, yieldedValue, iterArg, init, result] : llvm::enumerate( + *yieldedValues, loop.getRegionIterArgs(), loopInits, *loopResults)) { + // Check if this is a load from memspace 128 + Operation *defOp = yieldedValue.get().getDefiningOp(); + if (!defOp) + continue; + + Value alloca; + ValueRange loadIndices; + if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getMemRef(); + loadIndices = loadOp.getIndices(); + } else if (auto loadOp = dyn_cast(defOp)) { + alloca = loadOp.getBase(); + loadIndices = loadOp.getIndices(); + } else { + continue; + } + + // Check all indices are zero + if (llvm::any_of(loadIndices, + [](Value idx) { return getConstantIntValue(idx) != 0; })) + continue; + + // Check if loading from memspace 128 alloca defined in this loop + auto allocaOp = alloca.getDefiningOp(); + if (!allocaOp) + continue; + if (!isInRegisterSpace(cast(alloca.getType()))) + continue; + if (!body->findAncestorOpInBlock(*allocaOp)) + continue; + + // If load dominates any use of the iter arg, we can't hoist the alloca + // because the load would be invalidated by the store. + bool dominates = false; + for (Operation *user : iterArg.getUsers()) { + if (dom.dominates(defOp, user)) { + dominates = true; + break; + } + } + if (dominates) + continue; + + // Hoist the alloca before the loop + allocaOp->moveBefore(loop); + rewriter.setInsertionPointAfter(allocaOp); + + SmallVector zeroIndices = + getZeroIndices(rewriter, loc, loadIndices.size()); + + // Store the iter arg into the alloca + if (isa(defOp)) { + memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices); + } + + // Create iter arg loads + createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp); + + // Create a load after the loop + rewriter.setInsertionPointAfter(loop); + zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size()); + Value loadAfterLoop; + if (isa(defOp)) { + loadAfterLoop = + memref::LoadOp::create(rewriter, loc, alloca, zeroIndices); + } else if (auto vectorLoad = dyn_cast(defOp)) { + loadAfterLoop = vector::LoadOp::create( + rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices); + } + + // Replace uses of the loop result with the new load + result.replaceAllUsesWith(loadAfterLoop); + } +} + /// Materialize register copies by routing memref.load through temporary /// buffers in virtual register space (memspace 128). 
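+/// Typically exercised as water-opt --water-materialize-reg-copy (see the
+/// RUN lines in materialize-reg-copy.mlir).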
class WaterMaterializeRegCopyPass @@ -51,219 +261,7 @@ class WaterMaterializeRegCopyPass // Hoist allocas out of loops when their loads are yielded getOperation()->walk( - [&](scf::ForOp forOp) { (void)hoistAllocasFromLoop(rewriter, forOp); }); - } - -private: - /// Check if a memref type is in virtual register space (memspace 128). - static bool isInRegisterSpace(MemRefType memrefType) { - if (auto memSpace = - dyn_cast_or_null(memrefType.getMemorySpace())) - return memSpace.getInt() == 128; - return false; - } - - static SmallVector getZeroIndices(IRRewriter &rewriter, Location loc, - unsigned rank) { - return {rank, arith::ConstantIndexOp::create(rewriter, loc, 0)}; - } - - static void createLoads(IRRewriter &rewriter, Location loc, Value value, - unsigned rank, Value tempAlloca, Operation *op) { - // Group uses by block and find the first use in each block - DenseMap blockToFirstUse; - for (OpOperand &use : value.getUses()) { - Operation *userOp = use.getOwner(); - Block *userBlock = userOp->getBlock(); - auto it = blockToFirstUse.find(userBlock); - if (it == blockToFirstUse.end() || userOp->isBeforeInBlock(it->second)) - blockToFirstUse[userBlock] = userOp; - } - - SmallVector zeroIndices = getZeroIndices(rewriter, loc, rank); - - // Create one load per block, right before the first use in that block - DenseMap blockToLoad; - for (auto &[block, firstUse] : blockToFirstUse) { - rewriter.setInsertionPoint(firstUse); - Value load; - if (isa(op)) - load = memref::LoadOp::create(rewriter, loc, tempAlloca, zeroIndices); - else if (auto vecLoadOp = dyn_cast(op)) - load = vector::LoadOp::create(rewriter, loc, vecLoadOp.getVectorType(), - tempAlloca, zeroIndices); - blockToLoad[block] = load; - } - - // Replace uses with the appropriate load for their block - for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) { - Block *userBlock = use.getOwner()->getBlock(); - use.set(blockToLoad[userBlock]); - } - } - - /// Transform a single load operation to use register space copy. 
- LogicalResult materializeRegCopy(IRRewriter &rewriter, Operation *op) { - Location loc = op->getLoc(); - rewriter.setInsertionPoint(op); - - // Extract memref, indices, and element type from either load type - Value memref, loadResult; - ValueRange indices; - Type elementType; - SmallVector loadShape; - - if (auto loadOp = dyn_cast(op)) { - memref = loadOp.getMemRef(); - indices = loadOp.getIndices(); - loadResult = loadOp.getResult(); - elementType = loadOp.getType(); - loadShape.resize(indices.size(), 1); - } else if (auto loadOp = dyn_cast(op)) { - memref = loadOp.getBase(); - indices = loadOp.getIndices(); - loadResult = loadOp.getResult(); - VectorType vecType = loadOp.getVectorType(); - elementType = vecType.getElementType(); - loadShape.resize(indices.size() - vecType.getRank(), 1); - llvm::append_range(loadShape, vecType.getShape()); - } else { - return op->emitError("unsupported load operation"); - } - - auto memrefType = cast(memref.getType()); - - // Create subview parameters - Attribute one = rewriter.getIndexAttr(1); - SmallVector offsets, sizes, strides; - for (auto [index, shape] : llvm::zip(indices, loadShape)) { - offsets.push_back(index); - sizes.push_back(rewriter.getIndexAttr(shape)); - strides.push_back(one); - } - - // Create subview of size [1, 1, ..., 1] at the load indices - auto subviewType = - memref::SubViewOp::inferResultType(memrefType, offsets, sizes, strides); - auto subviewMemRefType = cast(subviewType); - Value subview = memref::SubViewOp::create(rewriter, loc, subviewMemRefType, - memref, offsets, sizes, strides); - - // Create temporary buffer in virtual register space (memspace 128) - auto regMemSpace = rewriter.getI32IntegerAttr(128); - auto tempType = - MemRefType::get(subviewMemRefType.getShape(), elementType, - /*layout=*/MemRefLayoutAttrInterface{}, regMemSpace); - Value tempAlloca = memref::AllocaOp::create(rewriter, loc, tempType, - /*dynamicSizes=*/ValueRange{}, - /*alignment=*/IntegerAttr()); - - // Copy from subview to temp register buffer - memref::CopyOp::create(rewriter, loc, subview, tempAlloca); - - createLoads(rewriter, loc, loadResult, loadShape.size(), tempAlloca, op); - - // Erase the original load - rewriter.eraseOp(op); - return success(); - } - - /// Hoist allocas from loops when their loads are yielded. 
-  void hoistAllocasFromLoop(IRRewriter &rewriter, scf::ForOp loop) {
-    auto yieldedValues = loop.getYieldedValuesMutable();
-    if (!yieldedValues)
-      return;
-
-    auto loopResults = loop.getLoopResults();
-    if (!loopResults)
-      return;
-
-    auto loopInits = loop.getInitsMutable();
-
-    Block *body = loop.getBody();
-    Location loc = loop.getLoc();
-
-    DominanceInfo dom;
-
-    // Find yielded values that come from loads of memspace 128 allocas
-    for (auto [idx, yieldedValue, iterArg, init, result] :
-         llvm::enumerate(*yieldedValues, loop.getRegionIterArgs(), loopInits,
-                         *loopResults)) {
-      // Check if this is a load from memspace 128
-      Operation *defOp = yieldedValue.get().getDefiningOp();
-      if (!defOp)
-        continue;
-
-      Value alloca;
-      ValueRange loadIndices;
-      if (auto loadOp = dyn_cast<memref::LoadOp>(defOp)) {
-        alloca = loadOp.getMemRef();
-        loadIndices = loadOp.getIndices();
-      } else if (auto loadOp = dyn_cast<vector::LoadOp>(defOp)) {
-        alloca = loadOp.getBase();
-        loadIndices = loadOp.getIndices();
-      } else {
-        continue;
-      }
-
-      // Check all indices are zero
-      if (llvm::any_of(loadIndices,
-                       [](Value idx) { return getConstantIntValue(idx) != 0; }))
-        continue;
-
-      // Check if loading from memspace 128 alloca defined in this loop
-      auto allocaOp = alloca.getDefiningOp<memref::AllocaOp>();
-      if (!allocaOp)
-        continue;
-      if (!isInRegisterSpace(cast<MemRefType>(alloca.getType())))
-        continue;
-      if (!body->findAncestorOpInBlock(*allocaOp))
-        continue;
-
-      // If load dominates any use of the iter arg, we can't hoist the alloca
-      // because the load would be invalidated by the store.
-      bool dominates = false;
-      for (Operation *user : iterArg.getUsers()) {
-        if (dom.dominates(defOp, user)) {
-          dominates = true;
-          break;
-        }
-      }
-      if (dominates)
-        continue;
-
-      // Hoist the alloca before the loop
-      allocaOp->moveBefore(loop);
-      rewriter.setInsertionPointAfter(allocaOp);
-
-      SmallVector<Value> zeroIndices =
-          getZeroIndices(rewriter, loc, loadIndices.size());
-
-      // Store the iter arg into the alloca
-      if (isa<memref::LoadOp>(defOp)) {
-        memref::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices);
-      } else if (auto vectorLoad = dyn_cast<vector::LoadOp>(defOp)) {
-        vector::StoreOp::create(rewriter, loc, init.get(), alloca, zeroIndices);
-      }
-
-      // Create iter arg loads
-      createLoads(rewriter, loc, iterArg, loadIndices.size(), alloca, defOp);
-
-      // Create a load after the loop
-      rewriter.setInsertionPointAfter(loop);
-      zeroIndices = getZeroIndices(rewriter, loc, loadIndices.size());
-      Value loadAfterLoop;
-      if (isa<memref::LoadOp>(defOp)) {
-        loadAfterLoop =
-            memref::LoadOp::create(rewriter, loc, alloca, zeroIndices);
-      } else if (auto vectorLoad = dyn_cast<vector::LoadOp>(defOp)) {
-        loadAfterLoop = vector::LoadOp::create(
-            rewriter, loc, vectorLoad.getVectorType(), alloca, zeroIndices);
-      }
-
-      // Replace uses of the loop result with the new load
-      result.replaceAllUsesWith(loadAfterLoop);
-    }
+        [&](scf::ForOp forOp) { hoistAllocasFromLoop(rewriter, forOp); });
   }
 };

From af64522392da1cc623c606763294e1ea2dcddca2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 02:35:39 +0100
Subject: [PATCH 073/114] add reg check

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 6c79a4c6e..26898e8d4 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -29,9 +29,17 @@ namespace mlir::water {
 } // namespace mlir::water
 
 namespace {
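+// Memory space 128 marks memrefs held in virtual registers, e.g.
+//   %reg = memref.alloca() : memref<4xf32, 128 : i32>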
+static bool isRegisterAddressSpace(MemRefType type) {
+  auto attr = dyn_cast_or_null<IntegerAttr>(type.getMemorySpace());
+  return attr && attr.getInt() == 128;
+}
 
 /// Try to propagate view operations to the base memref.
-static Value propagateViewOps(Value value) {
+static std::optional<Value> propagateViewOps(Value value) {
+  auto memrefType = cast<MemRefType>(value.getType());
+  if (isRegisterAddressSpace(memrefType))
+    return {};
+
   while (auto view = value.getDefiningOp<ViewLikeOpInterface>())
     value = view.getViewSource();

From b1eaf1f93fc07c73faae0207e9188e4e71961fc4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 17:51:51 +0100
Subject: [PATCH 074/114] prevent list grow

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 26898e8d4..93d998483 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -108,6 +108,10 @@ struct PendingOperations {
       : ops(std::move(ops)), opsTokens(std::move(opsTokens)) {}
 
   TokenContainer &addOp(Operation *op) {
+    // Failsafe to prevent infinite list growth.
+    if (size() >= 256)
+      llvm::report_fatal_error("Pending operations list is too long");
+
     ops.push_back(op);
     auto &back = opsTokens.emplace_back();
     if (auto memref = isStoreOp(op))

From a602979bc72305e7497dffa8841e8a7caba Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 19:33:56 +0100
Subject: [PATCH 075/114] register space handling in waitcnt insertion

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 59 +++++++++++----------
 water/test/Transforms/insert-waitcnt.mlir   | 52 ++++++++++++++++++
 2 files changed, 82 insertions(+), 29 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 93d998483..e1866acc9 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -34,12 +34,30 @@ static bool isRegisterAddressSpace(MemRefType type) {
   return attr && attr.getInt() == 128;
 }
 
+static bool isWorkgroupAddressSpace(MemRefType type) {
+  auto attr = dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
+  return attr && attr.getValue() == gpu::AddressSpace::Workgroup;
+}
+
+static bool isWorkgroupAddressSpace(std::optional<Value> value) {
+  if (!value)
+    return false;
+
+  auto memrefType = cast<MemRefType>(value->getType());
+  return isWorkgroupAddressSpace(memrefType);
+}
+
+static bool isGlobalAddressSpace(std::optional<Value> value) {
+  if (!value)
+    return false;
+
+  auto memrefType = cast<MemRefType>(value->getType());
+  return !isWorkgroupAddressSpace(memrefType) &&
+         !isRegisterAddressSpace(memrefType);
+}
+
 /// Try to propagate view operations to the base memref.
-static std::optional<Value> propagateViewOps(Value value) {
-  auto memrefType = cast<MemRefType>(value.getType());
-  if (isRegisterAddressSpace(memrefType))
-    return {};
-
+static Value propagateViewOps(Value value) {
   while (auto view = value.getDefiningOp<ViewLikeOpInterface>())
     value = view.getViewSource();
 
@@ -76,22 +94,6 @@ static std::optional<Value> isStoreOp(Operation *op) {
   return std::nullopt;
 }
 
-/// Check if the operation is a load or store operation and return the base
-/// memref.
-static std::optional<Value> isLoadOrStoreOp(Operation *op) {
-  if (auto store = isStoreOp(op))
-    return store;
-  if (auto load = isLoadOp(op))
-    return load;
-
-  return std::nullopt;
-}
-
-static bool isWorkgroupAddressSpace(MemRefType type) {
-  auto attr = dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
-  return attr && attr.getValue() == gpu::AddressSpace::Workgroup;
-}
-
 template <typename T>
 static raw_ostream &print_range(raw_ostream &os, T &&range) {
   llvm::interleaveComma(range, os, [&](const auto &item) { os << item; });
@@ -219,13 +221,14 @@ struct WaitcntRequirement {
 static WaitcntRequirement getOperationRequirement(Operation *op, bool zero) {
   WaitcntRequirement req;
-  if (std::optional<Value> base = isLoadOrStoreOp(op)) {
-    auto memrefType = cast<MemRefType>(base->getType());
-    if (isWorkgroupAddressSpace(memrefType)) {
-      req.ds_cnt = zero ? 0 : 1;
-    } else {
-      req.load_cnt = zero ? 0 : 1;
-    }
+  std::optional<Value> loadBase = isLoadOp(op);
+  std::optional<Value> storeBase = isStoreOp(op);
+  if (isWorkgroupAddressSpace(loadBase) ||
+      isWorkgroupAddressSpace(storeBase)) {
+    req.ds_cnt = zero ? 0 : 1;
+  } else if (isGlobalAddressSpace(loadBase) ||
+             isGlobalAddressSpace(storeBase)) {
+    req.load_cnt = zero ? 0 : 1;
   }
   return req;
 }
@@ -452,8 +455,6 @@ class WaitcntState : public AbstractDenseLattice {
 
   /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait
   WaitcntRequirement checkMemoryDependency(Operation *op) const {
-
-    // std::optional<Value> currentBase = isLoadOrStoreOp(op);
     auto checkMemref = [&](Value memref, bool isCurrentLoad,
                            bool isCurrentStore) -> WaitcntRequirement {
       WaitcntRequirement result;
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 81fcf6777..84f645269 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -510,3 +510,55 @@ func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %ste
   // CHECK: return
   return
 }
+
+
+// CHECK-LABEL: func.func @triple_buffering_reg_space
+func.func @triple_buffering_reg_space(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
+  %c0 = arith.constant 0 : index
+  %buff0 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %buff1 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %buff2 = memref.alloc() : memref<1024xf32, #gpu.address_space<workgroup>>
+  %reg = memref.alloca() : memref<4xf32, 128 : i32>
+
+  %out = memref.alloc() : memref<1024xf32>
+
+  // CHECK-NOT: amdgpu.memory_counter_wait
+  // CHECK: memref.copy
+  memref.copy %src, %buff0 : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // CHECK-NOT: amdgpu.memory_counter_wait
+  // CHECK: memref.copy
+  memref.copy %src, %buff1 : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+  // CHECK: scf.for
+  scf.for %i = %lb to %ub step %step iter_args(%current = %buff0, %next = %buff1, %next_next = %buff2) -> (memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>) {
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: memref.copy
+    memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>
+
+    // Skip the prev copy
+    // CHECK: amdgpu.memory_counter_wait ds(1)
+    // CHECK: vector.load
+    %data = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: vector.store
+    vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32>
+
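+    // %data is read straight from registers, so the store above needs no
+    // wait of its own.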
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: memref.subview
+    %subview = memref.subview %current[%offset] [4] [1] : memref<1024xf32, #gpu.address_space<workgroup>> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>>
+
+    // This copy only depends on buffer 2 iterations ago
+    // CHECK: amdgpu.memory_counter_wait ds(2)
+    // CHECK: memref.copy
+    memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
+
+    // CHECK-NOT: amdgpu.memory_counter_wait
+    // CHECK: scf.yield
+    scf.yield %next, %next_next, %current : memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>, memref<1024xf32, #gpu.address_space<workgroup>>
+  }
+
+  // CHECK: return
+  return
+}

From c54e97ce9f64cc2e14be24df6f50d7145d6cef0e Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 21:16:02 +0100
Subject: [PATCH 076/114] number registers

Signed-off-by: Ivan Butygin
---
 water/include/water/Transforms/Passes.td      |  11 ++
 water/lib/Transforms/CMakeLists.txt           |   1 +
 water/lib/Transforms/WaterNumberRegisters.cpp | 100 ++++++++++++++++++
 .../Transforms/number-registers-error.mlir    |   7 ++
 water/test/Transforms/number-registers.mlir   |  80 ++++++++++++++
 5 files changed, 199 insertions(+)
 create mode 100644 water/lib/Transforms/WaterNumberRegisters.cpp
 create mode 100644 water/test/Transforms/number-registers-error.mlir
 create mode 100644 water/test/Transforms/number-registers.mlir

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index 4503f72b5..748bc7e8b 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -218,4 +218,15 @@ def WaterMaterializeRegCopy : Pass<"water-materialize-reg-copy"> {
   ];
 }
 
+def WaterNumberRegisters : InterfacePass<"water-number-registers", "::mlir::FunctionOpInterface"> {
+  let summary = "Assign physical registers to register space allocas";
+  let description = [{
+    This pass performs register allocation by assigning physical register numbers
+    to memref.alloca operations in memory space 128 (virtual register space).
+  }];
+  let dependentDialects = [
+    "::mlir::memref::MemRefDialect",
+  ];
+}
+
 #endif // WATER_PASSES
diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt
index ced8d9a44..89583aa30 100644
--- a/water/lib/Transforms/CMakeLists.txt
+++ b/water/lib/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRWaterTransforms
   WaterInsertWaitcnt.cpp
   WaterLowerMemoryOps.cpp
   WaterMaterializeRegCopy.cpp
+  WaterNumberRegisters.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/water
diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp
new file mode 100644
index 000000000..dc29251e0
--- /dev/null
+++ b/water/lib/Transforms/WaterNumberRegisters.cpp
@@ -0,0 +1,100 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "water/Transforms/Passes.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace mlir::water {
+#define GEN_PASS_DEF_WATERNUMBERREGISTERS
+#include "water/Transforms/Passes.h.inc"
+} // namespace mlir::water
+
+namespace {
+
+/// Check if a memref type is in virtual register space (memspace 128).
+static bool isInRegisterSpace(MemRefType memrefType) {
+  if (auto memSpace =
+          dyn_cast_or_null<IntegerAttr>(memrefType.getMemorySpace()))
+    return memSpace.getInt() == 128;
+  return false;
+}
+
+/// Calculate the number of 32-bit registers needed for a memref type.
+static FailureOr<unsigned> getRegisterCount(MemRefType memrefType) {
+  // Calculate total size in bytes
+  unsigned elementSizeBytes = memrefType.getElementTypeBitWidth() / 8;
+  unsigned numElements = 1;
+  for (int64_t dim : memrefType.getShape()) {
+    if (dim == ShapedType::kDynamic)
+      return failure(); // Can't allocate dynamic sizes in registers.
+
+    numElements *= dim;
+  }
+
+  unsigned totalBytes = elementSizeBytes * numElements;
+
+  // Each register is 32 bits = 4 bytes
+  // Round up to next register boundary.
+  return (totalBytes + 3) / 4;
+}
+
+/// Assign physical registers to register space allocas.
+class WaterNumberRegistersPass
+    : public water::impl::WaterNumberRegistersBase<WaterNumberRegistersPass> {
+public:
+  void runOnOperation() override {
+    auto func = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    // TODO: for now, just assign registers sequentially. In the future,
+    // we need a liveness analysis to assign registers.
+    unsigned nextRegister = 0;
+
+    WalkResult result = func->walk([&](memref::AllocaOp allocaOp) {
+      auto memrefType = allocaOp.getType();
+      if (!isInRegisterSpace(memrefType))
+        return WalkResult::advance();
+
+      auto regCountOr = getRegisterCount(memrefType);
+      if (failed(regCountOr)) {
+        allocaOp->emitError(
+            "Cannot allocate dynamic-sized memref in register space");
+        return WalkResult::interrupt();
+      }
+
+      unsigned regCount = *regCountOr;
+
+      // Assign starting register number.
+      allocaOp->setAttr(
+          "water.register_number",
+          IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister));
+
+      // Track how many registers this alloca uses.
+      allocaOp->setAttr("water.register_count",
+                        IntegerAttr::get(IntegerType::get(ctx, 32), regCount));
+
+      // Advance to next available register.
+      nextRegister += regCount;
+
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return signalPassFailure();
+
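+    // nextRegister now equals the total number of 32-bit VGPRs in use, e.g.
+    // three memref<4xf32, 128 : i32> allocas leave it at 12.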
+    // Attach metadata to function with total register count.
+    func->setAttr("water.total_registers",
+                  IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister));
+  }
+};
+
+} // namespace
diff --git a/water/test/Transforms/number-registers-error.mlir b/water/test/Transforms/number-registers-error.mlir
new file mode 100644
index 000000000..4d9f77435
--- /dev/null
+++ b/water/test/Transforms/number-registers-error.mlir
@@ -0,0 +1,7 @@
+// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' --verify-diagnostics
+
+func.func @test_dynamic_size_error(%n: index) {
+  // expected-error @+1 {{Cannot allocate dynamic-sized memref in register space}}
+  %reg = memref.alloca(%n) : memref<?xf32, 128 : i32>
+  return
+}
diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir
new file mode 100644
index 000000000..4e40a63d6
--- /dev/null
+++ b/water/test/Transforms/number-registers.mlir
@@ -0,0 +1,80 @@
+// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s
+
+// CHECK-LABEL: func @test_simple_numbering
+// CHECK-SAME: attributes {water.total_registers = 6 : i32}
+func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+
+  // 1xf32 = 4 bytes = 1 register, starts at reg 0
+  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  %reg0 = memref.alloca() : memref<1xf32, 128 : i32>
+
+  // 4xf32 = 16 bytes = 4 registers, starts at reg 1
+  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 1 : i32}
+  %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // 1xf32 = 4 bytes = 1 register, starts at reg 5 (after reg1)
+  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 5 : i32}
+  %reg2 = memref.alloca() : memref<1xf32, 128 : i32>
+
+  %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+  memref.copy %subview0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+
+  %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32>
+
+  return %val0 : f32
+}
+
+// CHECK-LABEL: func @test_loop_with_registers
+// CHECK-SAME: attributes {water.total_registers = 1 : i32}
+func.func @test_loop_with_registers(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : index
+
+  // Register allocated outside loop
+  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  %reg = memref.alloca() : memref<1xf32, 128 : i32>
+
+  scf.for %iv = %lb to %ub step %step {
+    %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+    memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+    %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
+    memref.store %val, %arg0[%iv] : memref<100xf32>
+  }
+
+  return
+}
+
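+// Register counts round up to whole 32-bit units ((bytes + 3) / 4): a 4xf32
+// buffer (16 bytes) takes four VGPRs, a single f32 takes one.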
+// CHECK-LABEL: func @test_triple_buffering_numbering
+// CHECK-SAME: attributes {water.total_registers = 12 : i32}
+func.func @test_triple_buffering_numbering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
+  %c0 = arith.constant 0 : index
+
+  // Three registers for triple buffering, each 4xf32 = 4 registers
+  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 0 : i32}
+  %reg0 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 4 : i32}
+  %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 8 : i32}
+  %reg2 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  return
+}
+
+// CHECK-LABEL: func @test_mixed_memspaces
+// CHECK-SAME: attributes {water.total_registers = 1 : i32}
+func.func @test_mixed_memspaces(%arg0: memref<100xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // Non-register space alloca - should not be numbered
+  // CHECK: memref.alloca() : memref<10xf32>
+  // CHECK-NOT: water.register_number
+  %local = memref.alloca() : memref<10xf32>
+
+  // Register space alloca - should be numbered
+  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  %reg = memref.alloca() : memref<1xf32, 128 : i32>
+
+  return
+}

From a5a0c7ee0be88c8c99a32c29f2c135302daae1d3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 21:27:26 +0100
Subject: [PATCH 077/114] make func pass

Signed-off-by: Ivan Butygin
---
 water/include/water/Transforms/Passes.td    | 2 +-
 water/test/Transforms/lower-memory-ops.mlir | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index 748bc7e8b..0c43078b5 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -174,7 +174,7 @@ def WaterInsertWaitcnt : Pass<"water-insert-waitcnt"> {
   ];
 }
 
-def WaterLowerMemoryOps : Pass<"water-lower-memory-ops"> {
+def WaterLowerMemoryOps : InterfacePass<"water-lower-memory-ops", "::mlir::FunctionOpInterface"> {
   let summary = "Lower high-level memory operations to AMDGPU dialect";
   let description = [{
     This pass lowers high-level memory operations (vector.load, vector.store,
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 92fc8e246..7dbd51ab4 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: water-opt %s --water-lower-memory-ops | FileCheck %s
+// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops))' | FileCheck %s
 
 // Test lowering of vector memory operations to AMDGPU global_load/store inline assembly

From 3715f26562d4eb953e16af963f8fd392b27a3693 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 21:33:14 +0100
Subject: [PATCH 078/114] pass pipeline fixes

Signed-off-by: Ivan Butygin
---
 wave_lang/kernel/wave/water.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py
index a349d315d..306954e06 100644
--- a/wave_lang/kernel/wave/water.py
+++ b/wave_lang/kernel/wave/water.py
@@ -217,7 +217,7 @@ def make_linear_pass_pipeline(
     """
 
     def make_pass_arguments(
-        name: str, args: dict[str, Any], root_op: str | None = None
+        name: str, args: dict[str, Any], root_op: str | Sequence[str] | None = None
     ) -> str:
         ret = (
             name
@@ -226,7 +226,12 @@ def make_pass_arguments(
             + "}"
         )
         if root_op:
-            ret = root_op + "(" + ret + ")"
+            if isinstance(root_op, str):
+                ret = root_op + "(" + ret + ")"
+            elif isinstance(root_op, Sequence):
+                ret = "(".join(root_op) + "(" + ret + ")" * len(root_op)
+            else:
+                raise ValueError(f"Invalid root op: {root_op}")
         return ret
 
     return (

From a3df2b8ad1c846f3777444414c33dd2f3ebf5fed Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 22:00:54 +0100
Subject: [PATCH 079/114] rename regs

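Switch to VGPR-specific attribute names: water.register_number becomes
water.vgpr_number, water.register_count becomes water.vgpr_count, and
water.total_registers becomes water.total_vgprs.
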
Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterNumberRegisters.cpp |  6 ++---
 water/test/Transforms/number-registers.mlir   | 26 +++++++++----------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp
index dc29251e0..176de9358 100644
--- a/water/lib/Transforms/WaterNumberRegisters.cpp
+++ b/water/lib/Transforms/WaterNumberRegisters.cpp
@@ -75,11 +75,11 @@ class WaterNumberRegistersPass
 
       // Assign starting register number.
       allocaOp->setAttr(
-          "water.register_number",
+          "water.vgpr_number",
           IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister));
 
       // Track how many registers this alloca uses.
-      allocaOp->setAttr("water.register_count",
+      allocaOp->setAttr("water.vgpr_count",
                         IntegerAttr::get(IntegerType::get(ctx, 32), regCount));
 
       // Advance to next available register.
@@ -92,7 +92,7 @@ class WaterNumberRegistersPass
       return signalPassFailure();
 
     // Attach metadata to function with total register count.
-    func->setAttr("water.total_registers",
+    func->setAttr("water.total_vgprs",
                   IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister));
   }
 };
diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir
index 4e40a63d6..44f34cacc 100644
--- a/water/test/Transforms/number-registers.mlir
+++ b/water/test/Transforms/number-registers.mlir
@@ -1,20 +1,20 @@
 // RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s
 
 // CHECK-LABEL: func @test_simple_numbering
-// CHECK-SAME: attributes {water.total_registers = 6 : i32}
+// CHECK-SAME: attributes {water.total_vgprs = 6 : i32}
 func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 {
   %c0 = arith.constant 0 : index
 
   // 1xf32 = 4 bytes = 1 register, starts at reg 0
-  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
   %reg0 = memref.alloca() : memref<1xf32, 128 : i32>
 
   // 4xf32 = 16 bytes = 4 registers, starts at reg 1
-  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 1 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 1 : i32}
   %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
 
   // 1xf32 = 4 bytes = 1 register, starts at reg 5 (after reg1)
-  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 5 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 5 : i32}
   %reg2 = memref.alloca() : memref<1xf32, 128 : i32>
 
   %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
@@ -26,12 +26,12 @@ func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 {
 }
 
 // CHECK-LABEL: func @test_loop_with_registers
-// CHECK-SAME: attributes {water.total_registers = 1 : i32}
+// CHECK-SAME: attributes {water.total_vgprs = 1 : i32}
 func.func @test_loop_with_registers(%arg0: memref<100xf32>, %lb: index, %ub: index, %step: index) {
   %c0 = arith.constant 0 : index
 
   // Register allocated outside loop
-  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
   %reg = memref.alloca() : memref<1xf32, 128 : i32>
 
   scf.for %iv = %lb to %ub step %step {
     %subview = memref.subview %arg0[%iv] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
     memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
     %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
     memref.store %val, %arg0[%iv] : memref<100xf32>
   }
 
   return
 }
 
 // CHECK-LABEL: func @test_triple_buffering_numbering
-// CHECK-SAME: attributes {water.total_registers = 12 : i32}
+// CHECK-SAME: attributes {water.total_vgprs = 12 : i32}
 func.func @test_triple_buffering_numbering(%src: memref<1024xf32>, %lb: index, %ub: index, %step: index, %offset: index) {
   %c0 = arith.constant 0 : index
 
   // Three registers for triple buffering, each 4xf32 = 4 registers
-  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 0 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 0 : i32}
   %reg0 = memref.alloca() : memref<4xf32, 128 : i32>
 
-  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 4 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32}
   %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
 
-  // CHECK: memref.alloca() {water.register_count = 4 : i32, water.register_number = 8 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 8 : i32}
   %reg2 = memref.alloca() : memref<4xf32, 128 : i32>
 
   return
 }
 
 // CHECK-LABEL: func @test_mixed_memspaces
-// CHECK-SAME: attributes {water.total_registers = 1 : i32}
+// CHECK-SAME: attributes {water.total_vgprs = 1 : i32}
 func.func @test_mixed_memspaces(%arg0: memref<100xf32>) {
   %c0 = arith.constant 0 : index
 
   // Non-register space alloca - should not be numbered
   // CHECK: memref.alloca() : memref<10xf32>
-  // CHECK-NOT: water.register_number
+  // CHECK-NOT: water.vgpr_number
   %local = memref.alloca() : memref<10xf32>
 
   // Register space alloca - should be numbered
-  // CHECK: memref.alloca() {water.register_count = 1 : i32, water.register_number = 0 : i32}
+  // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32}
   %reg = memref.alloca() : memref<1xf32, 128 : i32>
 
   return

From 89c683d91b108ee34a3f2f1582951fb4abfb7a57 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 22:29:11 +0100
Subject: [PATCH 080/114] update water-opt

Signed-off-by: Ivan Butygin
---
 water/tools/water-opt/water-opt.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp
index 2ab056b5b..dec98b61f 100644
--- a/water/tools/water-opt/water-opt.cpp
+++ b/water/tools/water-opt/water-opt.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Transforms/Passes.h"
@@ -49,6 +50,7 @@ void registerWaterTestDialect(DialectRegistry &registry);
 
 int main(int argc, char **argv) {
   mlir::arith::registerArithIntRangeOptsPass();
+  mlir::memref::registerExpandStridedMetadataPass();
   mlir::registerCSEPass();
   mlir::registerCanonicalizerPass();
   mlir::registerCompositeFixedPointPass();

From 28f3ee1c02bc247223025a043f979beeddafdf9e Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 22:29:44 +0100
Subject: [PATCH 081/114] some lowering

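Lower memref.copy into register space: the copy becomes a single
buffer/DS/global load issued directly into the VGPRs assigned to the
destination alloca by water-number-registers.
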
Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 220 ++++++++++++++++++-
 1 file changed, 219 insertions(+), 1 deletion(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 67915a3d8..ac2ad80fd 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -588,6 +588,195 @@ static bool usesWorkgroupAddressSpace(Value memref) {
   return false;
 }
 
+/// Check if a memref uses register space (memspace 128)
+static bool usesRegisterSpace(Value memref) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto memorySpace = memrefType.getMemorySpace();
+
+  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == 128;
+
+  return false;
+}
+
+/// Lower memref.copy when destination is in register space - buffer variant
+static LogicalResult lowerCopyToRegisterSpaceBuffer(
+    memref::CopyOp copyOp, IRRewriter &rewriter, bool isRDNAArch,
+    unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) {
+  Value src = copyOp.getSource();
+  auto srcType = cast<MemRefType>(src.getType());
+  unsigned elementBitWidth = srcType.getElementTypeBitWidth();
+
+  if (totalBits < 32)
+    return success();
+
+  FailureOr<StringRef> suffix = getBufferSuffixLoad(totalBits, isRDNAArch);
+  if (failed(suffix))
+    return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits;
+
+  Location loc = copyOp.getLoc();
+  rewriter.setInsertionPoint(copyOp);
+
+  // Compute byte offset (no indices for full copy)
+  Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{},
+                                             elementBitWidth);
+
+  // Extract buffer descriptor and base offset
+  auto [bufferDesc, baseOffset] = extractBufferDescriptor(rewriter, loc, src);
+  Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset,
+                                            arith::IntegerOverflowFlags::nsw);
+
+  // Build constraint with specific VGPR
+  std::string outputConstraint;
+  if (vgprCount == 1)
+    outputConstraint = "={v" + std::to_string(vgprNum) + "}";
+  else
+    outputConstraint = "={v[" + std::to_string(vgprNum) + ":" +
+                       std::to_string(vgprNum + vgprCount - 1) + "]}";
+  std::string constraints = outputConstraint + ",v,s";
+
+  // Build inline assembly: "buffer_load_<suffix> $0, $1, $2, 0 offen"
+  std::string asmStr =
+      ("buffer_load_" + *suffix + " $0, $1, $2, 0 offen").str();
+
+  createInlineAsm(rewriter, loc, resultType,
+                  ValueRange{finalOffset, bufferDesc}, asmStr, constraints,
+                  /*hasSideEffects=*/true);
+
+  rewriter.eraseOp(copyOp);
+  return success();
+}
+
+/// Lower memref.copy when destination is in register space - DS variant
+static LogicalResult
+lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter,
+                           unsigned vgprNum, unsigned vgprCount,
+                           unsigned totalBits, Type resultType) {
+  Value src = copyOp.getSource();
+  auto srcType = cast<MemRefType>(src.getType());
+  unsigned elementBitWidth = srcType.getElementTypeBitWidth();
+
+  if (totalBits < 32)
+    return success();
+
+  FailureOr<StringRef> suffix = getSizeSuffixLoad(totalBits);
+  if (failed(suffix))
+    return copyOp.emitError("unsupported DS copy bit width: ") << totalBits;
+
+  Location loc = copyOp.getLoc();
+  rewriter.setInsertionPoint(copyOp);
+
+  // Compute byte offset
+  Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{},
+                                             elementBitWidth);
+
+  // Build constraint with specific VGPR
+  std::string outputConstraint;
+  if (vgprCount == 1)
+    outputConstraint = "={v" + std::to_string(vgprNum) + "}";
+  else
+    outputConstraint = "={v[" + std::to_string(vgprNum) + ":" +
+                       std::to_string(vgprNum + vgprCount - 1) + "]}";
+  std::string constraints = outputConstraint + ",v";
+
+  // Build inline assembly: "ds_read_b32 $0, $1"
+  std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str();
+
+  createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, asmStr,
+                  constraints, /*hasSideEffects=*/true);
+
+  rewriter.eraseOp(copyOp);
success(); +} + +/// Lower memref.copy when destination is in register space - global variant +static LogicalResult +lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, + unsigned vgprNum, unsigned vgprCount, + unsigned totalBits, Type resultType) { + Value src = copyOp.getSource(); + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + + if (totalBits < 32) + return success(); + + FailureOr suffix = getSizeSuffixLoad(totalBits); + if (failed(suffix)) + return copyOp.emitError("unsupported copy bit width: ") << totalBits; + + Location loc = copyOp.getLoc(); + rewriter.setInsertionPoint(copyOp); + + // Compute source address + Value addr = + computeMemrefAddress(rewriter, loc, src, /*indices=*/{}, elementBitWidth); + + // Build constraint with specific VGPR + std::string outputConstraint; + if (vgprCount == 1) + outputConstraint = "={v" + std::to_string(vgprNum) + "}"; + else + outputConstraint = "={v[" + std::to_string(vgprNum) + ":" + + std::to_string(vgprNum + vgprCount - 1) + "]}"; + std::string constraints = outputConstraint + ",v"; + + // Build inline assembly: "global_load_b128 $0, $1, off" + std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str(); + + createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, asmStr, + constraints, /*hasSideEffects=*/true); + + rewriter.eraseOp(copyOp); + return success(); +} + +/// Lower memref.copy when destination is in register space +static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp, + IRRewriter &rewriter, + bool isRDNAArch) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + // Get destination alloca to find VGPR assignment + auto dstAlloca = dst.getDefiningOp(); + if (!dstAlloca) + return copyOp.emitError("destination must be a memref.alloca"); + + // Get VGPR number from destination alloca + auto vgprNumAttr = dstAlloca->getAttrOfType("water.vgpr_number"); + auto vgprCountAttr = + dstAlloca->getAttrOfType("water.vgpr_count"); + if (!vgprNumAttr || !vgprCountAttr) + return copyOp.emitError("destination alloca missing VGPR attributes"); + + unsigned vgprNum = vgprNumAttr.getInt(); + unsigned vgprCount = vgprCountAttr.getInt(); + + // Get source type info + auto srcType = cast(src.getType()); + unsigned elementBitWidth = srcType.getElementTypeBitWidth(); + unsigned totalBits = elementBitWidth * vgprCount; + + // Get result type from destination + auto dstType = cast(dst.getType()); + Type resultType; + if (dstType.getShape().size() == 1 && dstType.getShape()[0] == 1) + resultType = dstType.getElementType(); + else + resultType = VectorType::get(dstType.getShape(), dstType.getElementType()); + + // Dispatch based on source memory space + if (usesBufferAddressSpace(src)) + return lowerCopyToRegisterSpaceBuffer(copyOp, rewriter, isRDNAArch, vgprNum, + vgprCount, totalBits, resultType); + if (usesWorkgroupAddressSpace(src)) + return lowerCopyToRegisterSpaceDS(copyOp, rewriter, vgprNum, vgprCount, + totalBits, resultType); + return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, vgprNum, vgprCount, + totalBits, resultType); +} + /// Pass that lowers high-level memory operations to AMDGPU memory instructions. 
 /// Pass that lowers high-level memory operations to AMDGPU memory instructions.
 /// Uses buffer operations for memrefs with
 /// #amdgpu.address_space<fat_raw_buffer>, DS operations for memrefs with
@@ -598,9 +787,30 @@ class WaterLowerMemoryOpsPass
   using Base::Base;
 
   void runOnOperation() override {
+    auto func = getOperation();
     IRRewriter rewriter(&getContext());
 
-    // Determine if we're targeting RDNA architecture
+    // Check if function has VGPR allocation and insert inline asm directive.
+    if (auto vgprAttr = func->getAttrOfType<IntegerAttr>("water.total_vgprs")) {
+      unsigned vgprCount = vgprAttr.getInt();
+      if (vgprCount > 0) {
+        unsigned vgprStart = 256 - vgprCount;
+
+        // Insert inline assembly at the beginning of the function.
+        Block &entryBlock = func.getFunctionBody().front();
+        rewriter.setInsertionPointToStart(&entryBlock);
+
+        std::string asmStr = "var vgprCount = " + std::to_string(vgprCount) +
+                             "\n" +
+                             "var vgprStart = " + std::to_string(vgprStart);
+
+        createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{},
+                        /*operands=*/{}, asmStr, /*constraints=*/"",
+                        /*hasSideEffects=*/true);
+      }
+    }
+
+    // Determine if we're targeting RDNA architecture.
     bool isRDNAArch = isRDNA(chipset);
 
     // Helper to dispatch to the appropriate lowering function based on address
@@ -655,6 +865,14 @@ class WaterLowerMemoryOpsPass
           return WalkResult::interrupt();
         return WalkResult::advance();
       }
+      if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
+        // Only lower copy if destination is in register space
+        if (usesRegisterSpace(copyOp.getTarget())) {
+          if (failed(lowerCopyToRegisterSpace(copyOp, rewriter, isRDNAArch)))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        }
+      }
 
       return WalkResult::advance();
     };

From 6bcb1faf8c1fec140b58b862d8d47fc23cd88fec Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 14 Dec 2025 22:35:38 +0100
Subject: [PATCH 082/114] cleanup alloca

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index ac2ad80fd..08f63b81a 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -878,6 +878,23 @@ class WaterLowerMemoryOpsPass
 
     if (getOperation()->walk(walkFn).wasInterrupted())
       signalPassFailure();
+
+    // Clean up register space allocas - they should all be lowered by now
+    WalkResult cleanupResult =
+        getOperation()->walk([&](memref::AllocaOp allocaOp) {
+          if (usesRegisterSpace(allocaOp.getMemref())) {
+            if (!allocaOp->use_empty()) {
+              allocaOp->emitError("register space alloca still has uses after "
+                                  "lowering - not all operations were lowered");
+              return WalkResult::interrupt();
+            }
+            rewriter.eraseOp(allocaOp);
+          }
+          return WalkResult::advance();
+        });
+
+    if (cleanupResult.wasInterrupted())
+      signalPassFailure();
   }
 };

From c6b1316b83fba99310c20b2d99a407fe4753225 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 00:03:06 +0100
Subject: [PATCH 083/114] copy to reg space lowering

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 148 ++++++++++++++++---
 1 file changed, 125 insertions(+), 23 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 08f63b81a..59c69c831 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -354,6 +354,9 @@ template <typename LoadOpTy>
 static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) {
   auto [memref, resultType, bitWidth] =
      getLoadOpInfo(loadOp);
+  // TODO: for bitwidths less than 32, we will need to truncate the value to 32
+  // immediately after the load, breaking the calculated dependencies.
+  // For now, just let llvm handle the loading
   if (bitWidth < 32)
     return success();
 
@@ -607,9 +610,6 @@ static LogicalResult lowerCopyToRegisterSpaceBuffer(
   auto srcType = cast<MemRefType>(src.getType());
   unsigned elementBitWidth = srcType.getElementTypeBitWidth();
 
-  if (totalBits < 32)
-    return success();
-
   FailureOr<StringRef> suffix = getBufferSuffixLoad(totalBits, isRDNAArch);
   if (failed(suffix))
     return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits;
@@ -656,9 +656,6 @@ lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter,
   auto srcType = cast<MemRefType>(src.getType());
   unsigned elementBitWidth = srcType.getElementTypeBitWidth();
 
-  if (totalBits < 32)
-    return success();
-
   FailureOr<StringRef> suffix = getSizeSuffixLoad(totalBits);
   if (failed(suffix))
     return copyOp.emitError("unsupported DS copy bit width: ") << totalBits;
@@ -698,9 +695,6 @@ lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter,
   auto srcType = cast<MemRefType>(src.getType());
   unsigned elementBitWidth = srcType.getElementTypeBitWidth();
 
-  if (totalBits < 32)
-    return success();
-
   FailureOr<StringRef> suffix = getSizeSuffixLoad(totalBits);
   if (failed(suffix))
     return copyOp.emitError("unsupported copy bit width: ") << totalBits;
@@ -777,6 +771,108 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
                                         totalBits, resultType);
 }
 
+/// Lower load from register space to inline assembly
+template <typename LoadOpTy>
+static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp,
+                                                IRRewriter &rewriter) {
+  Value memref;
+  if constexpr (std::is_same_v<LoadOpTy, vector::LoadOp>)
+    memref = loadOp.getBase();
+  else
+    memref = loadOp.getMemRef();
+
+  // Get source alloca to find VGPR assignment
+  auto srcAlloca = memref.getDefiningOp<memref::AllocaOp>();
+  if (!srcAlloca)
+    return loadOp.emitError("source must be a memref.alloca");
+
+  // Get VGPR number from source alloca
+  auto vgprNumAttr = srcAlloca->getAttrOfType<IntegerAttr>("water.vgpr_number");
+  auto vgprCountAttr =
+      srcAlloca->getAttrOfType<IntegerAttr>("water.vgpr_count");
+  if (!vgprNumAttr || !vgprCountAttr)
+    return loadOp.emitError("source alloca missing VGPR attributes");
+
+  unsigned vgprNum = vgprNumAttr.getInt();
+  unsigned vgprCount = vgprCountAttr.getInt();
+
+  Location loc = loadOp.getLoc();
+  rewriter.setInsertionPoint(loadOp);
+
+  // Build constraint for reading from specific VGPR(s)
+  std::string inputConstraint;
+  if (vgprCount == 1)
+    inputConstraint = "{v" + std::to_string(vgprNum) + "}";
+  else
+    inputConstraint = "{v[" + std::to_string(vgprNum) + ":" +
+                      std::to_string(vgprNum + vgprCount - 1) + "]}";
+  std::string constraints = "=" + inputConstraint;
+
+  // Simple v_mov to read from VGPR (compiler will optimize this away)
+  std::string asmStr = "; reg_load";
+
+  Type resultType = loadOp.getResult().getType();
+  auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{}, asmStr,
+                               constraints, /*hasSideEffects=*/false);
+
+  rewriter.replaceOp(loadOp, asmOp.getResult(0));
+  return success();
+}
+
+/// Lower store to register space to inline assembly
+template <typename StoreOpTy>
+static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp,
+                                               IRRewriter &rewriter) {
+  Value memref;
+  if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>)
+    memref = storeOp.getBase();
+  else
+    memref = storeOp.getMemRef();
+
+  // Get destination alloca to find VGPR assignment
+  auto dstAlloca = memref.getDefiningOp<memref::AllocaOp>();
+  if (!dstAlloca)
+    return storeOp.emitError("destination must be a memref.alloca");
+
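+  // water.vgpr_number / water.vgpr_count were attached earlier by the
+  // water-number-registers pass.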
+  // Get VGPR number from destination alloca
+  auto vgprNumAttr = dstAlloca->getAttrOfType<IntegerAttr>("water.vgpr_number");
+  auto vgprCountAttr =
+      dstAlloca->getAttrOfType<IntegerAttr>("water.vgpr_count");
+  if (!vgprNumAttr || !vgprCountAttr)
+    return storeOp.emitError("destination alloca missing VGPR attributes");
+
+  unsigned vgprNum = vgprNumAttr.getInt();
+  unsigned vgprCount = vgprCountAttr.getInt();
+
+  Location loc = storeOp.getLoc();
+  rewriter.setInsertionPoint(storeOp);
+
+  // Build constraint for writing to specific VGPR(s)
+  std::string outputConstraint;
+  if (vgprCount == 1)
+    outputConstraint = "={v" + std::to_string(vgprNum) + "}";
+  else
+    outputConstraint = "={v[" + std::to_string(vgprNum) + ":" +
+                       std::to_string(vgprNum + vgprCount - 1) + "]}";
+  std::string constraints = outputConstraint + ",0";
+
+  // v_mov to write to VGPR (input constraint 0 ties to output)
+  std::string asmStr = "; reg_store";
+
+  Value valueToStore;
+  if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>)
+    valueToStore = storeOp.getValueToStore();
+  else
+    valueToStore = storeOp.getValueToStore();
+
+  createInlineAsm(rewriter, loc, valueToStore.getType(),
+                  ValueRange{valueToStore}, asmStr, constraints,
+                  /*hasSideEffects=*/true);
+
+  rewriter.eraseOp(storeOp);
+  return success();
+}
+
 /// Pass that lowers high-level memory operations to AMDGPU memory instructions.
 /// Uses buffer operations for memrefs with
@@ -815,8 +911,11 @@ class WaterLowerMemoryOpsPass
 
     // Helper to dispatch to the appropriate lowering function based on address
     // space
-    auto lowerMemoryOp = [&](Value base, auto lowerBuffer, auto lowerWorkgroup,
+    auto lowerMemoryOp = [&](Value base, auto lowerRegister, auto lowerBuffer,
+                             auto lowerWorkgroup,
                              auto lowerGlobal) -> LogicalResult {
+      if (usesRegisterSpace(base))
+        return lowerRegister();
       if (usesBufferAddressSpace(base))
        return lowerBuffer();
       if (usesWorkgroupAddressSpace(base))
@@ -828,6 +927,7 @@ class WaterLowerMemoryOpsPass
       if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
        LogicalResult result = lowerMemoryOp(
            loadOp.getBase(),
+            [&]() { return lowerLoadFromRegisterSpace(loadOp, rewriter); },
            [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); },
            [&]() { return lowerLoadDS(loadOp, rewriter); },
            [&]() { return lowerLoadGlobal(loadOp, rewriter); });
@@ -838,6 +938,7 @@ class WaterLowerMemoryOpsPass
       if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
        LogicalResult result = lowerMemoryOp(
            storeOp.getBase(),
+            [&]() { return lowerStoreToRegisterSpace(storeOp, rewriter); },
            [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); },
            [&]() { return lowerStoreDS(storeOp, rewriter); },
            [&]() { return lowerStoreGlobal(storeOp, rewriter); });
@@ -848,6 +949,7 @@ class WaterLowerMemoryOpsPass
       if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
        LogicalResult result = lowerMemoryOp(
            loadOp.getMemRef(),
+            [&]() { return lowerLoadFromRegisterSpace(loadOp, rewriter); },
            [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); },
            [&]() { return lowerLoadDS(loadOp, rewriter); },
            [&]() { return lowerLoadGlobal(loadOp, rewriter); });
@@ -858,6 +960,7 @@ class WaterLowerMemoryOpsPass
       if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
        LogicalResult result = lowerMemoryOp(
            storeOp.getMemRef(),
+            [&]() { return lowerStoreToRegisterSpace(storeOp, rewriter); },
            [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); },
            [&]() { return lowerStoreDS(storeOp, rewriter); },
            [&]() { return lowerStoreGlobal(storeOp, rewriter); });
@@ -876,22 +979,21 @@ class WaterLowerMemoryOpsPass
       return WalkResult::advance();
     };
 
-    if (getOperation()->walk(walkFn).wasInterrupted())
+    if (func.walk(walkFn).wasInterrupted())
       signalPassFailure();
 
     // Clean up register space allocas - they should all be lowered by now
-    WalkResult cleanupResult =
-        getOperation()->walk([&](memref::AllocaOp allocaOp) {
-          if (usesRegisterSpace(allocaOp.getMemref())) {
-            if (!allocaOp->use_empty()) {
-              allocaOp->emitError("register space alloca still has uses after "
-                                  "lowering - not all operations were lowered");
-              return WalkResult::interrupt();
-            }
-            rewriter.eraseOp(allocaOp);
-          }
-          return WalkResult::advance();
-        });
+    WalkResult cleanupResult = func.walk([&](memref::AllocaOp allocaOp) {
+      if (usesRegisterSpace(allocaOp.getMemref())) {
+        if (!allocaOp->use_empty()) {
+          allocaOp->emitError("register space alloca still has uses after "
+                              "lowering - not all operations were lowered");
+          return WalkResult::interrupt();
+        }
+        rewriter.eraseOp(allocaOp);
+      }
+      return WalkResult::advance();
+    });
 
     if (cleanupResult.wasInterrupted())
       signalPassFailure();

From cf72eb2615354054f73e8885082fdbda9f94a9e0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 00:21:54 +0100
Subject: [PATCH 084/114] reg lowering fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 157 ++++++++++---------
 1 file changed, 79 insertions(+), 78 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 59c69c831..16f7cf4fa 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -153,6 +153,10 @@ static FailureOr<StringRef> getBufferSuffixLoad(unsigned bitWidth,
   if (isRDNAArch) {
     // RDNA uses b32, b64, etc.
     switch (bitWidth) {
+    case 8:
+      return StringRef("u8");
+    case 16:
+      return StringRef("u16");
     case 32:
       return StringRef("b32");
     case 64:
@@ -602,10 +606,23 @@ static bool usesRegisterSpace(Value memref) {
   return false;
 }
 
+static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum,
+                                     unsigned vgprCount, bool isOutput) {
+  std::string constraint;
+  if (vgprCount == 1)
+    constraint = "{v" + std::to_string(vgprOffset + vgprNum) + "}";
+  else
+    constraint = "{v[" + std::to_string(vgprOffset + vgprNum) + ":" +
+                 std::to_string(vgprOffset + vgprNum + vgprCount - 1) + "]}";
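+  // e.g. vgprOffset=252, vgprNum=0, vgprCount=4 yields "{v[252:255]}".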
"=" + constraint : constraint; +} + /// Lower memref.copy when destination is in register space - buffer variant -static LogicalResult lowerCopyToRegisterSpaceBuffer( - memref::CopyOp copyOp, IRRewriter &rewriter, bool isRDNAArch, - unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) { +static LogicalResult +lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset, + unsigned vgprNum, unsigned vgprCount, + unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -627,13 +644,8 @@ static LogicalResult lowerCopyToRegisterSpaceBuffer( arith::IntegerOverflowFlags::nsw); // Build constraint with specific VGPR - std::string outputConstraint; - if (vgprCount == 1) - outputConstraint = "={v" + std::to_string(vgprNum) + "}"; - else - outputConstraint = "={v[" + std::to_string(vgprNum) + ":" + - std::to_string(vgprNum + vgprCount - 1) + "]}"; - std::string constraints = outputConstraint + ",v,s"; + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v,s"; // Build inline assembly: "buffer_load_ $0, $1, $2, 0 offen" std::string asmStr = @@ -648,10 +660,9 @@ static LogicalResult lowerCopyToRegisterSpaceBuffer( } /// Lower memref.copy when destination is in register space - DS variant -static LogicalResult -lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, - unsigned vgprNum, unsigned vgprCount, - unsigned totalBits, Type resultType) { +static LogicalResult lowerCopyToRegisterSpaceDS( + memref::CopyOp copyOp, IRRewriter &rewriter, unsigned vgprOffset, + unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -668,13 +679,8 @@ lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, elementBitWidth); // Build constraint with specific VGPR - std::string outputConstraint; - if (vgprCount == 1) - outputConstraint = "={v" + std::to_string(vgprNum) + "}"; - else - outputConstraint = "={v[" + std::to_string(vgprNum) + ":" + - std::to_string(vgprNum + vgprCount - 1) + "]}"; - std::string constraints = outputConstraint + ",v"; + std::string constraints = + getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v"; // Build inline assembly: "ds_read_b32 $0, $1" std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); @@ -687,10 +693,9 @@ lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, } /// Lower memref.copy when destination is in register space - global variant -static LogicalResult -lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, - unsigned vgprNum, unsigned vgprCount, - unsigned totalBits, Type resultType) { +static LogicalResult lowerCopyToRegisterSpaceGlobal( + memref::CopyOp copyOp, IRRewriter &rewriter, unsigned vgprOffset, + unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -707,13 +712,8 @@ lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, computeMemrefAddress(rewriter, loc, src, /*indices=*/{}, elementBitWidth); // Build constraint with specific VGPR - std::string outputConstraint; - if (vgprCount == 1) - outputConstraint = "={v" + 
-    outputConstraint = "={v" + std::to_string(vgprNum) + "}";
-  else
-    outputConstraint = "={v[" + std::to_string(vgprNum) + ":" +
-                       std::to_string(vgprNum + vgprCount - 1) + "]}";
-  std::string constraints = outputConstraint + ",v";
+  std::string constraints =
+      getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",v";
 
   // Build inline assembly: "global_load_b128 $0, $1, off"
   std::string asmStr = ("global_load_" + *suffix + " $0, $1, off").str();
@@ -728,7 +728,8 @@ lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter,
 /// Lower memref.copy when destination is in register space
 static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
                                               IRRewriter &rewriter,
-                                              bool isRDNAArch) {
+                                              bool isRDNAArch,
+                                              unsigned vgprOffset) {
   Value src = copyOp.getSource();
   Value dst = copyOp.getTarget();
 
@@ -762,19 +763,21 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
 
   // Dispatch based on source memory space
   if (usesBufferAddressSpace(src))
-    return lowerCopyToRegisterSpaceBuffer(copyOp, rewriter, isRDNAArch, vgprNum,
-                                          vgprCount, totalBits, resultType);
+    return lowerCopyToRegisterSpaceBuffer(copyOp, rewriter, isRDNAArch,
+                                          vgprOffset, vgprNum, vgprCount,
+                                          totalBits, resultType);
   if (usesWorkgroupAddressSpace(src))
-    return lowerCopyToRegisterSpaceDS(copyOp, rewriter, vgprNum, vgprCount,
-                                      totalBits, resultType);
-  return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, vgprNum, vgprCount,
-                                        totalBits, resultType);
+    return lowerCopyToRegisterSpaceDS(copyOp, rewriter, vgprOffset, vgprNum,
+                                      vgprCount, totalBits, resultType);
+  return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, vgprOffset, vgprNum,
+                                        vgprCount, totalBits, resultType);
 }
 
 /// Lower load from register space to inline assembly
 template <typename LoadOpTy>
 static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp,
-                                                IRRewriter &rewriter) {
+                                                IRRewriter &rewriter,
+                                                unsigned vgprOffset) {
   Value memref;
   if constexpr (std::is_same_v<LoadOpTy, vector::LoadOp>)
     memref = loadOp.getBase();
@@ -800,13 +803,8 @@ static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp,
   rewriter.setInsertionPoint(loadOp);
 
   // Build constraint for reading from specific VGPR(s)
-  std::string inputConstraint;
-  if (vgprCount == 1)
-    inputConstraint = "{v" + std::to_string(vgprNum) + "}";
-  else
-    inputConstraint = "{v[" + std::to_string(vgprNum) + ":" +
-                      std::to_string(vgprNum + vgprCount - 1) + "]}";
-  std::string constraints = "=" + inputConstraint;
+  std::string constraints =
+      getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true);
 
   // Simple v_mov to read from VGPR (compiler will optimize this away)
   std::string asmStr = "; reg_load";
@@ -822,7 +820,8 @@ static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp,
 /// Lower store to register space to inline assembly
 template <typename StoreOpTy>
 static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp,
-                                               IRRewriter &rewriter) {
+                                               IRRewriter &rewriter,
+                                               unsigned vgprOffset) {
   Value memref;
   if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>)
     memref = storeOp.getBase();
@@ -848,13 +847,8 @@ static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp,
   rewriter.setInsertionPoint(storeOp);
 
   // Build constraint for writing to specific VGPR(s)
-  std::string outputConstraint;
-  if (vgprCount == 1)
-    outputConstraint = "={v" + std::to_string(vgprNum) + "}";
-  else
-    outputConstraint = "={v[" + std::to_string(vgprNum) + ":" +
-                       std::to_string(vgprNum + vgprCount - 1) + "]}";
-  std::string constraints = outputConstraint + ",0";
+  std::string constraints =
+      getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",0";
 
constraint 0 ties to output) std::string asmStr = "; reg_store"; @@ -884,26 +878,24 @@ class WaterLowerMemoryOpsPass void runOnOperation() override { auto func = getOperation(); - IRRewriter rewriter(&getContext()); // Check if function has VGPR allocation and insert inline asm directive. - if (auto vgprAttr = func->getAttrOfType("water.total_vgprs")) { - unsigned vgprCount = vgprAttr.getInt(); - if (vgprCount > 0) { - unsigned vgprStart = 256 - vgprCount; - - // Insert inline assembly at the beginning of the function. - Block &entryBlock = func.getFunctionBody().front(); - rewriter.setInsertionPointToStart(&entryBlock); - - std::string asmStr = "var vgprCount = " + std::to_string(vgprCount) + - "\n" + - "var vgprStart = " + std::to_string(vgprStart); - - createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{}, - /*operands=*/{}, asmStr, /*constraints=*/"", - /*hasSideEffects=*/true); - } + auto vgprAttr = func->getAttrOfType("water.total_vgprs"); + unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0; + unsigned vgprStart = 256 - vgprCount; + + // Insert inline assembly at the beginning of the function. + Block &entryBlock = func.getFunctionBody().front(); + IRRewriter rewriter(&getContext()); + rewriter.setInsertionPointToStart(&entryBlock); + + if (vgprCount > 0) { + std::string asmStr = "; vgprCount = " + std::to_string(vgprCount) + + " vgprStart = " + std::to_string(vgprStart); + + createInlineAsm(rewriter, func.getLoc(), /*resultTypes=*/{}, + /*operands=*/{}, asmStr, /*constraints=*/"", + /*hasSideEffects=*/true); } // Determine if we're targeting RDNA architecture. @@ -927,7 +919,9 @@ class WaterLowerMemoryOpsPass if (auto loadOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( loadOp.getBase(), - [&]() { return lowerLoadFromRegisterSpace(loadOp, rewriter); }, + [&]() { + return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart); + }, [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadDS(loadOp, rewriter); }, [&]() { return lowerLoadGlobal(loadOp, rewriter); }); @@ -938,7 +932,9 @@ class WaterLowerMemoryOpsPass if (auto storeOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( storeOp.getBase(), - [&]() { return lowerStoreToRegisterSpace(storeOp, rewriter); }, + [&]() { + return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart); + }, [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, [&]() { return lowerStoreDS(storeOp, rewriter); }, [&]() { return lowerStoreGlobal(storeOp, rewriter); }); @@ -949,7 +945,9 @@ class WaterLowerMemoryOpsPass if (auto loadOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( loadOp.getMemRef(), - [&]() { return lowerLoadFromRegisterSpace(loadOp, rewriter); }, + [&]() { + return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart); + }, [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadDS(loadOp, rewriter); }, [&]() { return lowerLoadGlobal(loadOp, rewriter); }); @@ -960,7 +958,9 @@ class WaterLowerMemoryOpsPass if (auto storeOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( storeOp.getMemRef(), - [&]() { return lowerStoreToRegisterSpace(storeOp, rewriter); }, + [&]() { + return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart); + }, [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, [&]() { return lowerStoreDS(storeOp, rewriter); }, [&]() { return lowerStoreGlobal(storeOp, rewriter); }); @@ -971,7 +971,8 @@ class WaterLowerMemoryOpsPass if (auto copyOp = dyn_cast(op)) { // Only lower copy if 
destination is in register space
      if (usesRegisterSpace(copyOp.getTarget())) {
-        if (failed(lowerCopyToRegisterSpace(copyOp, rewriter, isRDNAArch)))
+        if (failed(lowerCopyToRegisterSpace(copyOp, rewriter, isRDNAArch,
+                                            vgprStart)))
          return WalkResult::interrupt();
        return WalkResult::advance();
      }

From 744935f4f6e941b83eab4385f9472564a1ae3cda Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 02:23:02 +0100
Subject: [PATCH 085/114] buffer fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 16f7cf4fa..6e4c07d4a 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -250,8 +250,6 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) {
   MemRefDescriptor memrefDesc(memrefDescVal);
 
   Value bufferPtr = memrefDesc.alignedPtr(rewriter, loc);
-  Value bufferOffset = memrefDesc.offset(rewriter, loc);
-  bufferOffset = arith::TruncIOp::create(rewriter, loc, i32Type, bufferOffset);
 
   // Convert to i160 to access full buffer descriptor {<4 x i32> rsrc, i32
   // offset}
@@ -261,9 +259,6 @@ extractBufferDescriptor(IRRewriter &rewriter, Location loc, Value memref) {
 
   // Extract lower 32 bits for base offset
   Value baseOffset = arith::TruncIOp::create(rewriter, loc, i32Type, fullDesc);
-  baseOffset = arith::AddIOp::create(rewriter, loc, baseOffset, bufferOffset,
-                                     arith::IntegerOverflowFlags::nsw);
-
   // Extract upper 128 bits for resource descriptor
   auto c32 = arith::ConstantIntOp::create(rewriter, loc, i160Type, 32);
   Value rsrcBits160 = arith::ShRUIOp::create(rewriter, loc, fullDesc, c32);
@@ -750,8 +745,11 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
   // Get source type info
   auto srcType = cast<MemRefType>(src.getType());
+  if (!srcType.hasStaticShape())
+    return copyOp.emitError("source must have static shape");
+
   unsigned elementBitWidth = srcType.getElementTypeBitWidth();
-  unsigned totalBits = elementBitWidth * vgprCount;
+  unsigned totalBits = srcType.getNumElements() * elementBitWidth;
 
   // Get result type from destination
   auto dstType = cast<MemRefType>(dst.getType());

From ee528b8d94d0855d778826f39967417e677a541c Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 02:31:10 +0100
Subject: [PATCH 086/114] fix shapes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 6e4c07d4a..07348487f 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -753,11 +753,15 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
 
   // Get result type from destination
   auto dstType = cast<MemRefType>(dst.getType());
+  if (!dstType.hasStaticShape())
+    return copyOp.emitError("destination must have static shape");
+
   Type resultType;
-  if (dstType.getShape().size() == 1 && dstType.getShape()[0] == 1)
+  if (dstType.getNumElements() == 1)
     resultType = dstType.getElementType();
   else
-    resultType = VectorType::get(dstType.getShape(), dstType.getElementType());
+    resultType =
+        VectorType::get(dstType.getNumElements(), dstType.getElementType());
 
   // Dispatch based on source memory space
   if (usesBufferAddressSpace(src))
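
For reference (an illustrative aside, not part of the patch series): once both
shape guards are in place, the copy size is plain arithmetic over static
shapes. A memref<4xf32> source gives totalBits = 4 * 32 = 128, which maps to
the 128-bit variants (global_load_b128 / buffer_load_dwordx4), while a
memref<1xf16> source gives totalBits = 16, below the 32-bit VGPR granularity;
the next patch reworks the destination result type to whole 32-bit registers
via (resultBitWidth + 31) / 32, e.g. (16 + 31) / 32 = 1 register.
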
From e315e7c454de7d51a4b97972efa0bc19adc6e2ed Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 03:06:43 +0100
Subject: [PATCH 087/114] types fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 71 +++++++++++++-------
 1 file changed, 47 insertions(+), 24 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 07348487f..48be50bc9 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -26,6 +26,18 @@ namespace mlir::water {
 
 namespace {
 
+static unsigned getBitwidth(ShapedType type) {
+  assert(type.hasStaticShape() && "Shaped type must have static shape");
+  return type.getNumElements() * type.getElementTypeBitWidth();
+}
+
+static unsigned getBitwidth(Type type) {
+  if (auto shaped = dyn_cast<ShapedType>(type))
+    return getBitwidth(shaped);
+
+  return type.getIntOrFloatBitWidth();
+}
+
 /// Get the AMDGPU instruction suffix based on bit width (for loads - unsigned)
 static FailureOr<StringRef> getSizeSuffixLoad(unsigned bitWidth) {
   switch (bitWidth) {
@@ -274,12 +286,11 @@ template <typename LoadOpTy>
 static std::tuple<Value, Type, unsigned> getLoadOpInfo(LoadOpTy loadOp) {
   if constexpr (std::is_same_v<LoadOpTy, vector::LoadOp>) {
     auto vectorType = loadOp.getVectorType();
-    unsigned bitWidth =
-        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+    unsigned bitWidth = getBitwidth(vectorType);
     return {loadOp.getBase(), vectorType, bitWidth};
   } else {
     auto elementType = loadOp.getResult().getType();
-    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    unsigned bitWidth = getBitwidth(elementType);
     return {loadOp.getMemRef(), elementType, bitWidth};
   }
 }
@@ -289,12 +300,11 @@ template <typename StoreOpTy>
 static std::tuple<Value, Type, unsigned> getStoreOpInfo(StoreOpTy storeOp) {
   if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>) {
     auto vectorType = cast<VectorType>(storeOp.getValueToStore().getType());
-    unsigned bitWidth =
-        vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
+    unsigned bitWidth = getBitwidth(vectorType);
     return {storeOp.getBase(), vectorType, bitWidth};
   } else {
     auto elementType = storeOp.getValueToStore().getType();
-    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    unsigned bitWidth = getBitwidth(elementType);
     return {storeOp.getMemRef(), elementType, bitWidth};
   }
 }
@@ -748,20 +758,17 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp,
   if (!srcType.hasStaticShape())
     return copyOp.emitError("source must have static shape");
 
-  unsigned elementBitWidth = srcType.getElementTypeBitWidth();
-  unsigned totalBits = srcType.getNumElements() * elementBitWidth;
+  unsigned totalBits = getBitwidth(srcType);
 
   // Get result type from destination
   auto dstType = cast<MemRefType>(dst.getType());
   if (!dstType.hasStaticShape())
     return copyOp.emitError("destination must have static shape");
 
-  Type resultType;
-  if (dstType.getNumElements() == 1)
-    resultType = dstType.getElementType();
-  else
-    resultType =
-        VectorType::get(dstType.getNumElements(), dstType.getElementType());
+  unsigned resultBitWidth = getBitwidth(dstType);
+  unsigned resultNumElements = (resultBitWidth + 31) / 32;
+  Type resultType =
+      VectorType::get(resultNumElements, rewriter.getIntegerType(32));
 
   // Dispatch based on source memory space
   if (usesBufferAddressSpace(src))
@@ -812,10 +819,22 @@ static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp,
   std::string asmStr = "; reg_load";
 
   Type resultType = loadOp.getResult().getType();
-  auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{}, asmStr,
-                               constraints, /*hasSideEffects=*/false);
+  Type asmType = resultType;
+  unsigned bitWidth = getBitwidth(resultType);
+  if (bitWidth < 32)
+    asmType = rewriter.getIntegerType(32);
 
-  rewriter.replaceOp(loadOp, asmOp.getResult(0));
+  Value asmResult = createInlineAsm(rewriter, loc, asmType, {}, asmStr,
+                                    constraints, /*hasSideEffects=*/false)
+                        .getResult(0);
+
+  if (bitWidth < 32) {
+    auto narrowType = rewriter.getIntegerType(bitWidth);
+    asmResult = arith::TruncIOp::create(rewriter, loc, narrowType, asmResult);
+    asmResult = LLVM::BitcastOp::create(rewriter, loc, resultType, asmResult);
+  }
+
+  rewriter.replaceOp(loadOp, asmResult);
   return success();
 }
@@ -855,14 +874,18 @@ static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp,
   // v_mov to write to VGPR (input constraint 0 ties to output)
   std::string asmStr = "; reg_store";
 
-  Value valueToStore;
-  if constexpr (std::is_same_v<StoreOpTy, vector::StoreOp>)
-    valueToStore = storeOp.getValueToStore();
-  else
-    valueToStore = storeOp.getValueToStore();
+  Value valueToStore = storeOp.getValueToStore();
+  unsigned bitWidth = getBitwidth(valueToStore.getType());
+  if (bitWidth < 32) {
+    auto intType = rewriter.getIntegerType(bitWidth);
+    valueToStore =
+        LLVM::BitcastOp::create(rewriter, loc, intType, valueToStore);
+    auto i32Type = rewriter.getIntegerType(32);
+    valueToStore = arith::ExtUIOp::create(rewriter, loc, i32Type, valueToStore);
+  }
 
-  createInlineAsm(rewriter, loc, valueToStore.getType(),
-                  ValueRange{valueToStore}, asmStr, constraints,
+  createInlineAsm(rewriter, loc, valueToStore.getType(), valueToStore, asmStr,
+                  constraints,
                   /*hasSideEffects=*/true);
 
   rewriter.eraseOp(storeOp);
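
For reference (an illustrative sketch, not part of the patch series): after
this change a sub-32-bit register load always materializes in a full VGPR and
is narrowed back afterwards. Assuming a single f16 value pinned to v255, the
emitted IR looks roughly like:

  %r = llvm.inline_asm "; reg_load", "={v255}" : () -> i32
  %t = arith.trunci %r : i32 to i16
  %v = llvm.bitcast %t : i16 to f16

The store path mirrors it: llvm.bitcast to i16, then arith.extui to i32,
before the "; reg_store" asm consumes the widened value.
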
From 03916de9c0ade2c94f718fde7a9989862f01a2ca Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 12:56:45 +0100
Subject: [PATCH 088/114] set amdgpu-num-vgpr

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 23 +++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 48be50bc9..ffa9fb937 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -903,15 +903,36 @@ class WaterLowerMemoryOpsPass
   void runOnOperation() override {
     auto func = getOperation();
+    MLIRContext *ctx = &getContext();
 
     // Check if function has VGPR allocation and insert inline asm directive.
     auto vgprAttr = func->getAttrOfType<IntegerAttr>("water.total_vgprs");
     unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0;
     unsigned vgprStart = 256 - vgprCount;
 
+    if (vgprCount > 0) {
+      // Add amdgpu-num-vgpr to passthrough attribute list
+      auto vgprStartAttr = StringAttr::get(ctx, std::to_string(vgprStart));
+      auto nameAttr = StringAttr::get(ctx, "amdgpu-num-vgpr");
+
+      Attribute passthroughAttr;
+      // Get existing passthrough or create new one
+      if (auto existingPassthrough =
+              func->getAttrOfType<ArrayAttr>("passthrough")) {
+        SmallVector<Attribute> attrs(existingPassthrough.begin(),
+                                     existingPassthrough.end());
+        attrs.push_back(ArrayAttr::get(ctx, {nameAttr, vgprStartAttr}));
+        passthroughAttr = ArrayAttr::get(ctx, attrs);
+      } else {
+        passthroughAttr = ArrayAttr::get(
+            ctx, {ArrayAttr::get(ctx, {nameAttr, vgprStartAttr})});
+      }
+      func->setAttr("passthrough", passthroughAttr);
+    }
+
     // Insert inline assembly at the beginning of the function.
     Block &entryBlock = func.getFunctionBody().front();
-    IRRewriter rewriter(&getContext());
+    IRRewriter rewriter(ctx);
     rewriter.setInsertionPointToStart(&entryBlock);
 
     if (vgprCount > 0) {
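
For reference (an illustrative aside, not part of the patch series): with
water.total_vgprs = 4 the pass computes vgprStart = 256 - 4 = 252 and attaches
passthrough = [["amdgpu-num-vgpr", "252"]] to the function, which is the exact
form the tests added below check; once the function reaches llvm.func form,
passthrough entries are expected to surface as LLVM string function
attributes.
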
From b4e358273e4959803e7987169b71bb8a2ef43d3d Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 15 Dec 2025 13:21:23 +0100
Subject: [PATCH 089/114] reg lowering tests

Signed-off-by: Ivan Butygin
---
 water/test/Transforms/lower-memory-ops.mlir | 97 +++++++++++++++++++++
 1 file changed, 97 insertions(+)

diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 7dbd51ab4..3c1e410e3 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -338,3 +338,100 @@ func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) {
 
   return
 }
+
+// Test copy to register space with pre-numbered allocas
+
+// CHECK-LABEL: func.func @copy_global_to_reg_scalar
+// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]]
+func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v"
+  memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32>
+  // CHECK: llvm.inline_asm "; reg_load", "={v255}"
+  %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
+  // CHECK-NOT: memref.alloca
+  return %val : f32
+}
+
+// CHECK-LABEL: func.func @copy_global_to_reg_vector
+// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>>
+  // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v"
+  memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32>
+  // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}"
+  %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT: memref.alloca
+  return %val : vector<4xf32>
+}
+
+// CHECK-LABEL: func.func @copy_buffer_to_reg
+// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]]
+func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space<fat_raw_buffer>>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} {
+  %c0 = arith.constant 0 : index
+  %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32>
+  %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+  // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s"
+  memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<4xf32, 128 : i32>
+  // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}"
+  %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK-NOT:
memref.alloca + return %val : vector<4xf32> +} + +// CHECK-LABEL: func.func @copy_workgroup_to_reg +// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK-NOT: memref.alloca + return %val : vector<4xf32> +} + +// CHECK-LABEL: func.func @store_to_reg +// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i32} { + %c0 = arith.constant 0 : index + %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> + // CHECK: llvm.inline_asm has_side_effects "; reg_store", "={v255},0" + memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32> + // CHECK: llvm.inline_asm "; reg_load", "={v255}" + %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32> + // CHECK-NOT: memref.alloca + return %result : f32 +} + +// CHECK-LABEL: func.func @multiple_reg_allocas +// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] +func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} { + %c0 = arith.constant 0 : index + %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> + %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> + // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" + %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> + memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> + // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" + %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> + memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> + // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> + memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> + // CHECK: llvm.inline_asm "; reg_load", "={v247}" + %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> + // CHECK: llvm.inline_asm "; reg_load", "={v[248:251]}" + %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + %val2 = vector.load 
%reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> + // CHECK-NOT: memref.alloca + return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32> +} From 2cd0889a1d01aec75dfccfc8717cbf1d7224d370 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 15 Dec 2025 17:19:29 +0100 Subject: [PATCH 090/114] chipset Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 27 ++++++++++++-------- water/test/Transforms/lower-memory-ops.mlir | 2 +- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index ffa9fb937..bad1eddbb 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -92,9 +93,9 @@ static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, /*operand_attrs=*/ArrayAttr{}); } -/// Detect if chipset is RDNA architecture -static bool isRDNA(StringRef chipset) { - return chipset.starts_with("gfx11") || chipset.starts_with("gfx12"); +/// Detect if chipset is RDNA vs CDNA architecture +static bool isRDNA(const amdgpu::Chipset &chipset) { + return chipset.majorVersion != 9; } /// Compute byte offset as iX for a memref access with indices @@ -892,10 +893,6 @@ static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp, return success(); } -/// Pass that lowers high-level memory operations to AMDGPU memory instructions. -/// Uses buffer operations for memrefs with -/// #amdgpu.address_space, DS operations for memrefs with -/// #gpu.address_space, and global operations for all other memrefs. class WaterLowerMemoryOpsPass : public water::impl::WaterLowerMemoryOpsBase { public: @@ -903,12 +900,21 @@ class WaterLowerMemoryOpsPass void runOnOperation() override { auto func = getOperation(); + auto chip = amdgpu::Chipset::parse(chipset); + if (failed(chip)) { + func->emitError("invalid chipset: ") << chipset; + return signalPassFailure(); + } + MLIRContext *ctx = &getContext(); + // Assume Wave32 for gfx12+ for now. + unsigned totalVGPRs = chip->majorVersion >= 12 ? 1024 : 256; + // Check if function has VGPR allocation and insert inline asm directive. auto vgprAttr = func->getAttrOfType("water.total_vgprs"); unsigned vgprCount = vgprAttr ? vgprAttr.getInt() : 0; - unsigned vgprStart = 256 - vgprCount; + unsigned vgprStart = totalVGPRs - vgprCount; if (vgprCount > 0) { // Add amdgpu-num-vgpr to passthrough attribute list @@ -944,8 +950,9 @@ class WaterLowerMemoryOpsPass /*hasSideEffects=*/true); } - // Determine if we're targeting RDNA architecture. - bool isRDNAArch = isRDNA(chipset); + // Determine if we're targeting RDNA vs CDNA architecture, CDNA has + // different buffer ops format. 
+ bool isRDNAArch = isRDNA(*chip); // Helper to dispatch to the appropriate lowering function based on address // space diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index 3c1e410e3..49732d169 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -1,4 +1,4 @@ -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops))' | FileCheck %s +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s // Test lowering of vector memory operations to AMDGPU global_load/store inline assembly From 0dedccd6abda947a8f3f506dc5301c0bc892c898 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 15 Dec 2025 17:59:03 +0100 Subject: [PATCH 091/114] rdna vs cdna tests Signed-off-by: Ivan Butygin --- water/test/Transforms/lower-memory-ops.mlir | 108 +++++++++++++------- 1 file changed, 72 insertions(+), 36 deletions(-) diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index 49732d169..5bea30313 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -1,4 +1,5 @@ -// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx950}))' | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-lower-memory-ops{chipset=gfx1200}))' | FileCheck %s --check-prefixes=CHECK,GFX12 // Test lowering of vector memory operations to AMDGPU global_load/store inline assembly @@ -105,56 +106,64 @@ func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, % // CHECK-LABEL: func.func @buffer_load_b32 func.func @buffer_load_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<1xf32> { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> return %result : vector<1xf32> } // CHECK-LABEL: func.func @buffer_load_b64 func.func @buffer_load_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<2xf32> { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx2 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b64 $0, $1, $2, 0 offen", "=v,v,s" %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> return %result : vector<2xf32> } // CHECK-LABEL: func.func @buffer_load_b96 func.func @buffer_load_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<3xf32> { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx3 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b96 $0, $1, $2, 0 offen", "=v,v,s" %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> 
return %result : vector<3xf32> } // CHECK-LABEL: func.func @buffer_load_b128 func.func @buffer_load_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> vector<4xf32> { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" %result = vector.load %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> return %result : vector<4xf32> } // CHECK-LABEL: func.func @buffer_store_b32 func.func @buffer_store_b32(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<1xf32>) { - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<1xf32> return } // CHECK-LABEL: func.func @buffer_store_b64 func.func @buffer_store_b64(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<2xf32>) { - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx2 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b64 $0, $1, $2, 0 offen", "v,v,s" vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<2xf32> return } // CHECK-LABEL: func.func @buffer_store_b96 func.func @buffer_store_b96(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<3xf32>) { - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx3 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b96 $0, $1, $2, 0 offen", "v,v,s" vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<3xf32> return } // CHECK-LABEL: func.func @buffer_store_b128 func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: vector<4xf32>) { - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" vector.store %data, %memref[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> return } @@ -166,11 +175,13 @@ func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<10 %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> // Store to buffer memory (should use buffer_store) - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" vector.store %global_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> // Load from buffer memory (should use buffer_load) - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, 
$1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "=v,v,s" %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> // Store to global memory (should use global_store) @@ -253,7 +264,8 @@ func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref %lds_data = vector.load %lds[%offset] : memref<1024xf32, #gpu.address_space>, vector<4xf32> // Store to buffer (should use buffer_store) - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dwordx4 $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b128 $0, $1, $2, 0 offen", "v,v,s" vector.store %lds_data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> return @@ -292,14 +304,16 @@ func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %d // CHECK-LABEL: func.func @scalar_load_buffer_f32 func.func @scalar_load_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) -> f32 { - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dword $0, $1, $2, 0 offen", "=v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b32 $0, $1, $2, 0 offen", "=v,v,s" %result = memref.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> return %result : f32 } // CHECK-LABEL: func.func @scalar_store_buffer_f32 func.func @scalar_store_buffer_f32(%buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index, %data: f32) { - // CHECK: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_store_dword $0, $1, $2, 0 offen", "v,v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_store_b32 $0, $1, $2, 0 offen", "v,v,s" memref.store %data, %buffer[%offset] : memref<1024xf32, #amdgpu.address_space> return } @@ -342,95 +356,117 @@ func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { // Test copy to register space with pre-numbered allocas // CHECK-LABEL: func.func @copy_global_to_reg_scalar -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1023"]] func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1023},v" memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v255}" + // GFX9: llvm.inline_asm "; reg_load", "={v255}" + // GFX12: llvm.inline_asm "; reg_load", "={v1023}" %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> // CHECK-NOT: 
memref.alloca return %val : f32 } // CHECK-LABEL: func.func @copy_global_to_reg_vector -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1020"]] func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1020:1023]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> } // CHECK-LABEL: func.func @copy_buffer_to_reg -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1020"]] func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> - // CHECK: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s" + // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[1020:1023]},v,s" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> to memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> } // CHECK-LABEL: func.func @copy_workgroup_to_reg -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1020"]] func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> - // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX9: llvm.inline_asm 
has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[1020:1023]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> } // CHECK-LABEL: func.func @store_to_reg -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1023"]] func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> - // CHECK: llvm.inline_asm has_side_effects "; reg_store", "={v255},0" + // GFX9: llvm.inline_asm has_side_effects "; reg_store", "={v255},0" + // GFX12: llvm.inline_asm has_side_effects "; reg_store", "={v1023},0" memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v255}" + // GFX9: llvm.inline_asm "; reg_load", "={v255}" + // GFX12: llvm.inline_asm "; reg_load", "={v1023}" %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32> // CHECK-NOT: memref.alloca return %result : f32 } // CHECK-LABEL: func.func @multiple_reg_allocas -// CHECK-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] +// GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1015"]] func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} { %c0 = arith.constant 0 : index %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1015},v" %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1016:1019]},v" %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", 
"={v[1020:1023]},v" %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v247}" + // GFX9: llvm.inline_asm "; reg_load", "={v247}" + // GFX12: llvm.inline_asm "; reg_load", "={v1015}" %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> - // CHECK: llvm.inline_asm "; reg_load", "={v[248:251]}" + // GFX9: llvm.inline_asm "; reg_load", "={v[248:251]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[1016:1019]}" %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> - // CHECK: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32> From 9aa0cb6892538315bcffbc73f450d621ae0c2985 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 15 Dec 2025 19:10:26 +0100 Subject: [PATCH 092/114] lowering pipeline Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 306954e06..790aacec8 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -450,11 +450,18 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] llvm_opt_level = 3 if options.optimization_level else 0 dump_intermediates = options.dump_intermediates or "" + gpu_func = ("gpu.module", "gpu.func") + pipeline = [ + ("water-materialize-reg-copy", {}, gpu_func), + ("water-insert-waitcnt", {}, gpu_func), + "expand-strided-metadata", "lower-affine", *add_opt(canonicalize_cse), *add_opt("loop-invariant-code-motion"), *add_opt("int-range-optimizations"), + ("water-number-registers", {}, gpu_func), + ("water-lower-memory-ops", {"chipset": target_chip}, gpu_func), "convert-scf-to-cf", ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ("water-alloc-to-alloca", {}, "gpu.module"), From 2f58cdb61a31ee8df56bc566c96939b43d0e7721 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 03:22:56 +0100 Subject: [PATCH 093/114] fix < 32 stores Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 45 ++++++++++++-------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index bad1eddbb..715a0c75c 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -208,6 +208,10 @@ static FailureOr getBufferSuffixStore(unsigned bitWidth, if (isRDNAArch) { // RDNA uses b32, b64, etc. switch (bitWidth) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); case 32: return StringRef("b32"); case 64: @@ -400,15 +404,27 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { return success(); } +static Value extendTo32(Value value, IRRewriter &rewriter, Location loc) { + unsigned bitWidth = getBitwidth(value.getType()); + if (bitWidth >= 32) + return value; + + // Sched barrier to prevent moving the expansion before the waitcnt. 
+ ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); + + return arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), value); +} + /// Lower vector/scalar store to AMDGPU buffer store inline assembly template static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, bool isRDNAArch) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - if (bitWidth < 32) - return success(); - FailureOr suffix = getBufferSuffixStore(bitWidth, isRDNAArch); if (failed(suffix)) return storeOp.emitError("unsupported buffer store bit width: ") @@ -441,12 +457,12 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, arith::IntegerOverflowFlags::nsw); - Value valueToStore = storeOp.getValueToStore(); + Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, - ValueRange{valueToStore, finalOffset, bufferDesc}, asmStr, - constraints, /*hasSideEffects=*/true); + {valueToStore, finalOffset, bufferDesc}, asmStr, constraints, + /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); return success(); @@ -457,9 +473,6 @@ template static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - if (bitWidth < 32) - return success(); - FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported store bit width: ") << bitWidth; @@ -482,11 +495,10 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - Value valueToStore = storeOp.getValueToStore(); + Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); // Create the inline assembly operation (no result for store) - createInlineAsm(rewriter, loc, TypeRange{}, ValueRange{addr, valueToStore}, - asmStr, constraints, + createInlineAsm(rewriter, loc, {}, {addr, valueToStore}, asmStr, constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); @@ -535,9 +547,6 @@ template static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - if (bitWidth < 32) - return success(); - FailureOr suffix = getSizeSuffixStore(bitWidth); if (failed(suffix)) return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; @@ -559,12 +568,12 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { Value offset = computeMemrefByteOffset<32>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - Value valueToStore = storeOp.getValueToStore(); + Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); // Create inline assembly operation (no result for store, DS uses 32-bit // addresses) - createInlineAsm(rewriter, loc, TypeRange{}, ValueRange{offset, valueToStore}, - asmStr, constraints, + createInlineAsm(rewriter, loc, {}, {offset, valueToStore}, asmStr, + constraints, /*hasSideEffects=*/true); rewriter.eraseOp(storeOp); From cc8ca84005a0f9dabea3fc4255d2e43469fd49a5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 14:34:36 +0100 Subject: [PATCH 094/114] fix 
shared mem addresses Signed-off-by: Ivan Butygin --- tests/kernel/wave/e2e/test_copy.py | 71 ++++++++++++++++++++ water/lib/Transforms/WaterLowerMemoryOps.cpp | 29 ++++---- 2 files changed, 86 insertions(+), 14 deletions(-) diff --git a/tests/kernel/wave/e2e/test_copy.py b/tests/kernel/wave/e2e/test_copy.py index 90b48caae..a3a4d4852 100644 --- a/tests/kernel/wave/e2e/test_copy.py +++ b/tests/kernel/wave/e2e/test_copy.py @@ -18,6 +18,7 @@ from ..common.utils import param_bool, require_e2e, use_water_backend_bool from ._test_util import get_test_shapes +from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE def get_copy_template( @@ -132,3 +133,73 @@ def test_dynamic_copy( b = device_zeros(shape, dtype=torch.float16) test(a, b) assert_close(a, b) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")) +@param_bool("use_buffer_ops", "buf_ops") +@use_water_backend_bool("use_water_backend") +@check_leaks +def test_copy_shared_memory( + shape: tuple[int, int], + use_buffer_ops: bool, + run_bench: bool, + use_water_backend: bool, +) -> None: + M = tkl.sym.M + N = tkl.sym.N + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on single row of input data, and rows are further + # split into blocks of size up to 256. We have single wave per WG, + # and with default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. + BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + vector_shapes={M: BLOCK_M, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + shared = tkw.allocate((M, N), (BLOCK_M, BLOCK_N), tkl.f16, SHARED_ADDRESS_SPACE) + res = tkw.read(a) + tkw.write(res, shared) + tkw.shared_memory_barrier() + res_shared = tkw.read(shared) + tkw.write(res_shared, b) + + subs = { + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + } + + options = WaveCompileOptions( + subs=subs, + canonicalize=True, + run_bench=run_bench, + use_buffer_ops=use_buffer_ops, + use_water_backend=use_water_backend, + minimize_shared_allocs=False, # TODO: minimize_shared_allocs=True is broken + ) + options = set_default_run_config(options) + test = wave_compile(options, test) + + a = device_randn(shape, dtype=torch.float16) + b = device_zeros(shape, dtype=torch.float16) + test(a, b) + assert_close(a, b) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 715a0c75c..530e0da30 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -133,11 +133,12 @@ static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc, /// Compute the final address for a memref access with indices (for global /// operations) +template static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, Value memref, ValueRange indices, unsigned elementBitWidth) { - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto i64Type = rewriter.getI64Type(); + auto ptrType = 
LLVM::LLVMPointerType::get(rewriter.getContext(), MemSpace); + auto intType = rewriter.getIntegerType(Bits); // Extract base pointer auto metadataOp = @@ -147,11 +148,11 @@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, // Convert base pointer to i64 Value basePtrInt = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtr); - basePtrInt = arith::IndexCastOp::create(rewriter, loc, i64Type, basePtrInt); + basePtrInt = arith::IndexCastOp::create(rewriter, loc, intType, basePtrInt); // Compute byte offset - Value byteOffsetI64 = computeMemrefByteOffset<64>(rewriter, loc, memref, - indices, elementBitWidth); + Value byteOffsetI64 = computeMemrefByteOffset(rewriter, loc, memref, + indices, elementBitWidth); // Add byte offset to base pointer Value finalAddr = @@ -393,8 +394,8 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { std::is_same_v ? cast(resultType).getElementTypeBitWidth() : bitWidth; - Value addr = computeMemrefAddress(rewriter, loc, memref, loadOp.getIndices(), - elementBitWidth); + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); // Create the inline assembly operation with result type directly auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{addr}, @@ -492,8 +493,8 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { std::is_same_v ? cast(valueType).getElementTypeBitWidth() : bitWidth; - Value addr = computeMemrefAddress(rewriter, loc, memref, storeOp.getIndices(), - elementBitWidth); + Value addr = computeMemrefAddress<64, 0>( + rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); @@ -531,7 +532,7 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { std::is_same_v ? cast(resultType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffset<32>( + Value offset = computeMemrefAddress<32, 3>( rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); // Create inline assembly operation (DS operations use 32-bit addresses) @@ -565,7 +566,7 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { std::is_same_v ? 
cast(valueType).getElementTypeBitWidth() : bitWidth; - Value offset = computeMemrefByteOffset<32>( + Value offset = computeMemrefAddress<32, 3>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); @@ -690,7 +691,7 @@ static LogicalResult lowerCopyToRegisterSpaceDS( rewriter.setInsertionPoint(copyOp); // Compute byte offset - Value offset = computeMemrefByteOffset<32>(rewriter, loc, src, /*indices=*/{}, + Value offset = computeMemrefAddress<32, 3>(rewriter, loc, src, /*indices=*/{}, elementBitWidth); // Build constraint with specific VGPR @@ -723,8 +724,8 @@ static LogicalResult lowerCopyToRegisterSpaceGlobal( rewriter.setInsertionPoint(copyOp); // Compute source address - Value addr = - computeMemrefAddress(rewriter, loc, src, /*indices=*/{}, elementBitWidth); + Value addr = computeMemrefAddress<64, 0>(rewriter, loc, src, /*indices=*/{}, + elementBitWidth); // Build constraint with specific VGPR std::string constraints = From 2ce470d33eb345c590e7dd67d9a663d7da0b871b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 15:25:56 +0100 Subject: [PATCH 095/114] refac Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 102 +++++++++---------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 530e0da30..638ed3f44 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -39,6 +39,17 @@ static unsigned getBitwidth(Type type) { return type.getIntOrFloatBitWidth(); } +static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, bool isOutput) { + std::string constraint; + if (vgprCount == 1) + constraint = "{v" + std::to_string(vgprOffset + vgprNum) + "}"; + else + constraint = "{v[" + std::to_string(vgprOffset + vgprNum) + ":" + + std::to_string(vgprOffset + vgprNum + vgprCount - 1) + "]}"; + return isOutput ? "=" + constraint : constraint; +} + /// Get the AMDGPU instruction suffix based on bit width (for loads - unsigned) static FailureOr getSizeSuffixLoad(unsigned bitWidth) { switch (bitWidth) { @@ -321,6 +332,9 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, bool isRDNAArch) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + // TODO: for bitwidths less than 32, we will need to truncate the value to 32 + // immediately after the load, breaking the calculated dependencies. + // For now, just let llvm handle the loading if (bitWidth < 32) return success(); @@ -369,9 +383,6 @@ template static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - // TODO: for bitwidths less than 32, we will need to truncate the value to 32 - // immediately after the load, breaking the calculated dependencies. 
- // For now, just let llvm handle the loading if (bitWidth < 32) return success(); @@ -405,6 +416,43 @@ static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { return success(); } +/// Lower vector/scalar load to AMDGPU DS load inline assembly +template <typename LoadOpTy> +static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { + auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); + + if (bitWidth < 32) + return success(); + + FailureOr<StringRef> suffix = getSizeSuffixLoad(bitWidth); + if (failed(suffix)) + return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; + + Location loc = loadOp.getLoc(); + rewriter.setInsertionPoint(loadOp); + + // Build inline assembly: "ds_read_b32 $0, $1" + std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); + + // Constraints: "=v" for output (VGPR), "v" for address (VGPR) + StringRef constraints = "=v,v"; + + // Compute the byte offset (DS uses 32-bit addresses) + unsigned elementBitWidth = + std::is_same_v<LoadOpTy, vector::LoadOp> + ? cast<VectorType>(resultType).getElementTypeBitWidth() + : bitWidth; + Value offset = computeMemrefAddress<32, 3>( + rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); + + // Create inline assembly operation (DS operations use 32-bit addresses) + auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, + asmStr, constraints, /*hasSideEffects=*/true); + + rewriter.replaceOp(loadOp, asmOp.getResult(0)); + return success(); +} + static Value extendTo32(Value value, IRRewriter &rewriter, Location loc) { unsigned bitWidth = getBitwidth(value.getType()); if (bitWidth >= 32) @@ -506,43 +554,6 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { return success(); } -/// Lower vector/scalar load to AMDGPU DS load inline assembly -template <typename LoadOpTy> -static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { - auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); - - if (bitWidth < 32) - return success(); - - FailureOr<StringRef> suffix = getSizeSuffixLoad(bitWidth); - if (failed(suffix)) - return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; - - Location loc = loadOp.getLoc(); - rewriter.setInsertionPoint(loadOp); - - // Build inline assembly: "ds_read_b32 $0, $1" - std::string asmStr = ("ds_read_" + *suffix + " $0, $1").str(); - - // Constraints: "=v" for output (VGPR), "v" for address (VGPR) - StringRef constraints = "=v,v"; - - // Compute byte offset as i64 - unsigned elementBitWidth = - std::is_same_v<LoadOpTy, vector::LoadOp> - ? cast<VectorType>(resultType).getElementTypeBitWidth() - : bitWidth; - Value offset = computeMemrefAddress<32, 3>( - rewriter, loc, memref, loadOp.getIndices(), elementBitWidth); - - // Create inline assembly operation (DS operations use 32-bit addresses) - auto asmOp = createInlineAsm(rewriter, loc, resultType, ValueRange{offset}, - asmStr, constraints, /*hasSideEffects=*/true); - - rewriter.replaceOp(loadOp, asmOp.getResult(0)); - return success(); -} - /// Lower vector/scalar store to AMDGPU DS store inline assembly template <typename StoreOpTy> static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { @@ -622,17 +633,6 @@ static bool usesRegisterSpace(Value memref) { return false; } -static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, - unsigned vgprCount, bool isOutput) { - std::string constraint; - if (vgprCount == 1) - constraint = "{v" + std::to_string(vgprOffset + vgprNum) + "}"; - else - constraint = "{v[" + std::to_string(vgprOffset + vgprNum) + ":" + - std::to_string(vgprOffset + vgprNum + vgprCount - 1) + "]}"; - return isOutput ?
"=" + constraint : constraint; -} - /// Lower memref.copy when destination is in register space - buffer variant static LogicalResult lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, From e98f51029c4d8e93c055b2f4813b1d2d1599c15c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 16:14:44 +0100 Subject: [PATCH 096/114] refac Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 261 ++++++++++--------- 1 file changed, 137 insertions(+), 124 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 638ed3f44..bcd7610b4 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -50,8 +50,7 @@ static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum, return isOutput ? "=" + constraint : constraint; } -/// Get the AMDGPU instruction suffix based on bit width (for loads - unsigned) -static FailureOr<StringRef> getSizeSuffixLoad(unsigned bitWidth) { +static FailureOr<StringRef> getLoadSizeSuffixRDNA(unsigned bitWidth) { switch (bitWidth) { case 8: return StringRef("u8"); @@ -70,8 +69,7 @@ static FailureOr<StringRef> getSizeSuffixLoad(unsigned bitWidth) { } } -/// Get the AMDGPU instruction suffix based on bit width (for stores) -static FailureOr<StringRef> getSizeSuffixStore(unsigned bitWidth) { +static FailureOr<StringRef> getStoreSizeSuffixRDNA(unsigned bitWidth) { switch (bitWidth) { case 8: return StringRef("b8"); @@ -90,6 +88,90 @@ static FailureOr<StringRef> getSizeSuffixStore(unsigned bitWidth) { } } +static FailureOr<StringRef> getLoadSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("ubyte"); + case 16: + return StringRef("ushort"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr<StringRef> getStoreSizeSuffixCDNA(unsigned bitWidth) { + switch (bitWidth) { + case 8: + return StringRef("byte"); + case 16: + return StringRef("short"); + case 32: + return StringRef("dword"); + case 64: + return StringRef("dwordx2"); + case 96: + return StringRef("dwordx3"); + case 128: + return StringRef("dwordx4"); + default: + return failure(); + } +} + +static FailureOr<StringRef> getBufferLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr<StringRef> getBufferStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr<StringRef> getGlobalLoadSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getLoadSizeSuffixRDNA(bitWidth); + } else { + return getLoadSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr<StringRef> getGlobalStoreSuffix(unsigned bitWidth, + bool isRDNAArch) { + if (isRDNAArch) { + return getStoreSizeSuffixRDNA(bitWidth); + } else { + return getStoreSizeSuffixCDNA(bitWidth); + } +} + +static FailureOr<StringRef> getDSLoadSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getLoadSizeSuffixRDNA(bitWidth); +} + +static FailureOr<StringRef> getDSStoreSuffix(unsigned bitWidth, + bool /*isRDNAArch*/) { + return getStoreSizeSuffixRDNA(bitWidth); +} + /// Create an LLVM inline assembly operation with standard attributes static LLVM::InlineAsmOp createInlineAsm(IRRewriter &rewriter, Location loc, TypeRange resultTypes, @@ -172,90
@@ static Value computeMemrefAddress(IRRewriter &rewriter, Location loc, return LLVM::IntToPtrOp::create(rewriter, loc, ptrType, finalAddr); } -/// Get buffer instruction suffix based on bit width (for loads - unsigned) -static FailureOr<StringRef> getBufferSuffixLoad(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - // RDNA uses b32, b64, etc. - switch (bitWidth) { - case 8: - return StringRef("u8"); - case 16: - return StringRef("u16"); - case 32: - return StringRef("b32"); - case 64: - return StringRef("b64"); - case 96: - return StringRef("b96"); - case 128: - return StringRef("b128"); - default: - return failure(); - } - } else { - // CDNA uses dword, dwordx2, etc. - switch (bitWidth) { - case 8: - return StringRef("ubyte"); - case 16: - return StringRef("ushort"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); - } - } -} - -/// Get buffer instruction suffix based on bit width (for stores) -static FailureOr<StringRef> getBufferSuffixStore(unsigned bitWidth, - bool isRDNAArch) { - if (isRDNAArch) { - // RDNA uses b32, b64, etc. - switch (bitWidth) { - case 8: - return StringRef("b8"); - case 16: - return StringRef("b16"); - case 32: - return StringRef("b32"); - case 64: - return StringRef("b64"); - case 96: - return StringRef("b96"); - case 128: - return StringRef("b128"); - default: - return failure(); - } - } else { - // CDNA uses dword, dwordx2, etc. - switch (bitWidth) { - case 8: - return StringRef("byte"); - case 16: - return StringRef("short"); - case 32: - return StringRef("dword"); - case 64: - return StringRef("dwordx2"); - case 96: - return StringRef("dwordx3"); - case 128: - return StringRef("dwordx4"); - default: - return failure(); - } - } -} - /// Extract buffer descriptor and base offset from a fat_raw_buffer memref /// addrspace(7) format: {<4 x i32> rsrc, i32 offset} (160 bits total) /// Returns: {resource descriptor (i128), base offset (i32)} @@ -338,7 +336,7 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, if (bitWidth < 32) return success(); - FailureOr<StringRef> suffix = getBufferSuffixLoad(bitWidth, isRDNAArch); + FailureOr<StringRef> suffix = getBufferLoadSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return loadOp.emitError("unsupported buffer load bit width: ") << bitWidth; @@ -380,13 +378,15 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, /// Lower vector/scalar load to LLVM inline assembly (global_load_*) template <typename LoadOpTy> -static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter) { +static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { + return success(); auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); if (bitWidth < 32) return success(); - FailureOr<StringRef> suffix = getSizeSuffixLoad(bitWidth); + FailureOr<StringRef> suffix = getGlobalLoadSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return loadOp.emitError("unsupported load bit width: ") << bitWidth; @@ -418,13 +418,14 @@ /// Lower vector/scalar load to AMDGPU DS load inline assembly template <typename LoadOpTy> -static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { +static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter, + bool isRDNAArch) { auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); if (bitWidth < 32) return success(); - FailureOr<StringRef> suffix =
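// Taken together, the suffix helpers encode the mnemonic split between the
// two ISA families; a minimal sketch for a 128-bit access:
//   getBufferLoadSuffix(128, /*isRDNAArch=*/true)  -> "b128"    (buffer_load_b128)
//   getBufferLoadSuffix(128, /*isRDNAArch=*/false) -> "dwordx4" (buffer_load_dwordx4)
//   getDSLoadSuffix(128, /*isRDNAArch=*/any)       -> "b128"    (ds_read_b128)
// DS mnemonics use the b-style suffixes on both families, which is why the DS
// helpers ignore their isRDNAArch parameter.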
getSizeSuffixLoad(bitWidth); + FailureOr<StringRef> suffix = getDSLoadSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return loadOp.emitError("unsupported DS load bit width: ") << bitWidth; @@ -453,10 +454,14 @@ static LogicalResult lowerLoadDS(LoadOpTy loadOp, IRRewriter &rewriter) { return success(); } -static Value extendTo32(Value value, IRRewriter &rewriter, Location loc) { +static Value extendToReg(Value value, IRRewriter &rewriter, Location loc) { unsigned bitWidth = getBitwidth(value.getType()); - if (bitWidth >= 32) + if (bitWidth >= 32) { + Type intType = rewriter.getIntegerType(bitWidth); + if (value.getType() != intType) + value = LLVM::BitcastOp::create(rewriter, loc, intType, value); return value; + } // Sched barrier to prevent moving the expansion before the waitcnt. ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); @@ -474,7 +479,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, bool isRDNAArch) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr<StringRef> suffix = getBufferSuffixStore(bitWidth, isRDNAArch); + FailureOr<StringRef> suffix = getBufferStoreSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return storeOp.emitError("unsupported buffer store bit width: ") << bitWidth; @@ -506,7 +511,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter, Value finalOffset = arith::AddIOp::create(rewriter, loc, offset, baseOffset, arith::IntegerOverflowFlags::nsw); - Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); // Create inline assembly operation (no result for store) createInlineAsm(rewriter, loc, TypeRange{}, @@ -519,10 +524,11 @@ /// Lower vector/scalar store to LLVM inline assembly (global_store_*) template <typename StoreOpTy> -static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { +static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr<StringRef> suffix = getSizeSuffixStore(bitWidth); + FailureOr<StringRef> suffix = getGlobalStoreSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return storeOp.emitError("unsupported store bit width: ") << bitWidth; @@ -544,7 +550,7 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter) { Value addr = computeMemrefAddress<64, 0>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); // Create the inline assembly operation (no result for store) createInlineAsm(rewriter, loc, {}, {addr, valueToStore}, asmStr, constraints, @@ -556,10 +562,11 @@ /// Lower vector/scalar store to AMDGPU DS store inline assembly template <typename StoreOpTy> -static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter) { +static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter, + bool isRDNAArch) { auto [memref, valueType, bitWidth] = getStoreOpInfo(storeOp); - FailureOr<StringRef> suffix = getSizeSuffixStore(bitWidth); + FailureOr<StringRef> suffix = getDSStoreSuffix(bitWidth, isRDNAArch); if (failed(suffix)) return storeOp.emitError("unsupported DS store bit width: ") << bitWidth; @@ -580,7 +587,7 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter
&rewriter) { Value offset = computeMemrefAddress<32, 3>( rewriter, loc, memref, storeOp.getIndices(), elementBitWidth); - Value valueToStore = extendTo32(storeOp.getValueToStore(), rewriter, loc); + Value valueToStore = extendToReg(storeOp.getValueToStore(), rewriter, loc); // Create inline assembly operation (no result for store, DS uses 32-bit // addresses) @@ -643,7 +650,7 @@ lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, auto srcType = cast<MemRefType>(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - FailureOr<StringRef> suffix = getBufferSuffixLoad(totalBits, isRDNAArch); + FailureOr<StringRef> suffix = getBufferLoadSuffix(totalBits, isRDNAArch); if (failed(suffix)) return copyOp.emitError("unsupported buffer copy bit width: ") << totalBits; @@ -676,14 +683,16 @@ lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, } /// Lower memref.copy when destination is in register space - DS variant -static LogicalResult lowerCopyToRegisterSpaceDS( - memref::CopyOp copyOp, IRRewriter &rewriter, unsigned vgprOffset, - unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) { +static LogicalResult +lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset, + unsigned vgprNum, unsigned vgprCount, + unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast<MemRefType>(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - FailureOr<StringRef> suffix = getSizeSuffixLoad(totalBits); + FailureOr<StringRef> suffix = getDSLoadSuffix(totalBits, isRDNAArch); if (failed(suffix)) return copyOp.emitError("unsupported DS copy bit width: ") << totalBits; @@ -709,14 +718,16 @@ static LogicalResult lowerCopyToRegisterSpaceDS( } /// Lower memref.copy when destination is in register space - global variant -static LogicalResult lowerCopyToRegisterSpaceGlobal( - memref::CopyOp copyOp, IRRewriter &rewriter, unsigned vgprOffset, - unsigned vgprNum, unsigned vgprCount, unsigned totalBits, Type resultType) { +static LogicalResult +lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset, + unsigned vgprNum, unsigned vgprCount, + unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast<MemRefType>(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); - FailureOr<StringRef> suffix = getSizeSuffixLoad(totalBits); + FailureOr<StringRef> suffix = getGlobalLoadSuffix(totalBits, isRDNAArch); if (failed(suffix)) return copyOp.emitError("unsupported copy bit width: ") << totalBits; @@ -787,10 +798,12 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp, vgprOffset, vgprNum, vgprCount, totalBits, resultType); if (usesWorkgroupAddressSpace(src)) - return lowerCopyToRegisterSpaceDS(copyOp, rewriter, vgprOffset, vgprNum, - vgprCount, totalBits, resultType); + return lowerCopyToRegisterSpaceDS(copyOp, rewriter, isRDNAArch, vgprOffset, + vgprNum, vgprCount, totalBits, + resultType); - return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, vgprOffset, vgprNum, - vgprCount, totalBits, resultType); + return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, isRDNAArch, + vgprOffset, vgprNum, vgprCount, + totalBits, resultType); } /// Lower load from register space to inline assembly @@ -986,8 +999,8 @@ class WaterLowerMemoryOpsPass return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart); }, [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadDS(loadOp,
rewriter); }, - [&]() { return lowerLoadGlobal(loadOp, rewriter); }); + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); @@ -999,8 +1012,8 @@ class WaterLowerMemoryOpsPass return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart); }, [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreDS(storeOp, rewriter); }, - [&]() { return lowerStoreGlobal(storeOp, rewriter); }); + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); @@ -1012,8 +1025,8 @@ class WaterLowerMemoryOpsPass return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart); }, [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, - [&]() { return lowerLoadDS(loadOp, rewriter); }, - [&]() { return lowerLoadGlobal(loadOp, rewriter); }); + [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, + [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); @@ -1025,8 +1038,8 @@ class WaterLowerMemoryOpsPass return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart); }, [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); }, - [&]() { return lowerStoreDS(storeOp, rewriter); }, - [&]() { return lowerStoreGlobal(storeOp, rewriter); }); + [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); }, + [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); }); if (failed(result)) return WalkResult::interrupt(); return WalkResult::advance(); From 8345c86b4c0e63a6413f7398f6564aebd5cf5eb4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 16:41:18 +0100 Subject: [PATCH 097/114] fixes Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 1 - water/test/Transforms/lower-memory-ops.mlir | 77 +++++++++++++------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index bcd7610b4..6723d68b0 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -380,7 +380,6 @@ static LogicalResult lowerLoadBuffer(LoadOpTy loadOp, IRRewriter &rewriter, template static LogicalResult lowerLoadGlobal(LoadOpTy loadOp, IRRewriter &rewriter, bool isRDNAArch) { - return success(); auto [memref, resultType, bitWidth] = getLoadOpInfo(loadOp); if (bitWidth < 32) diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index 5bea30313..61f6ad826 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -14,7 +14,8 @@ func.func @vector_load(%memref: memref<1024xf32>, %offset: index) -> vector<4xf3 // CHECK: memref.extract_aligned_pointer_as_index // CHECK: arith.index_cast // CHECK: llvm.inttoptr - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> // CHECK: return return %result : vector<4xf32> @@ -25,7 +26,8 @@ 
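// The GFX9/GFX12 check prefixes used from here on pin the family-specific
// mnemonics: the same 128-bit vector.load is expected to lower to
// global_load_dwordx4 on CDNA (gfx9) but to global_load_b128 on RDNA (gfx12).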
func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector // CHECK: memref.extract_aligned_pointer_as_index // CHECK: arith.index_cast // CHECK: llvm.inttoptr - // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> // CHECK: return return @@ -33,56 +35,64 @@ func.func @vector_store(%memref: memref<1024xf32>, %offset: index, %data: vector // CHECK-LABEL: func.func @vector_load_b32 func.func @vector_load_b32(%memref: memref<1024xf32>, %offset: index) -> vector<1xf32> { - // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<1xf32> return %result : vector<1xf32> } // CHECK-LABEL: func.func @vector_load_b64 func.func @vector_load_b64(%memref: memref<1024xf32>, %offset: index) -> vector<2xf32> { - // CHECK: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<2xf32> return %result : vector<2xf32> } // CHECK-LABEL: func.func @vector_load_b96 func.func @vector_load_b96(%memref: memref<1024xf32>, %offset: index) -> vector<3xf32> { - // CHECK: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx3 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b96 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<3xf32> return %result : vector<3xf32> } // CHECK-LABEL: func.func @vector_load_b128 func.func @vector_load_b128(%memref: memref<1024xf32>, %offset: index) -> vector<4xf32> { - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %result = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> return %result : vector<4xf32> } // CHECK-LABEL: func.func @vector_store_b32 func.func @vector_store_b32(%memref: memref<1024xf32>, %offset: index, %data: vector<1xf32>) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" vector.store %data, %memref[%offset] : memref<1024xf32>, vector<1xf32> return } // CHECK-LABEL: func.func @vector_store_b64 func.func @vector_store_b64(%memref: memref<1024xf32>, %offset: index, %data: vector<2xf32>) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" vector.store %data, %memref[%offset] : memref<1024xf32>, vector<2xf32> return } // CHECK-LABEL: 
func.func @vector_store_b96 func.func @vector_store_b96(%memref: memref<1024xf32>, %offset: index, %data: vector<3xf32>) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx3 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b96 $0, $1, off", "v,v" vector.store %data, %memref[%offset] : memref<1024xf32>, vector<3xf32> return } // CHECK-LABEL: func.func @vector_store_b128 func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: vector<4xf32>) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" vector.store %data, %memref[%offset] : memref<1024xf32>, vector<4xf32> return } @@ -91,10 +101,12 @@ func.func @vector_store_b128(%memref: memref<1024xf32>, %offset: index, %data: v func.func @load_store_sequence(%src: memref<1024xf32>, %dst: memref<1024xf32>, %offset: index) { // Test lowering of load/store sequence - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %data = vector.load %src[%offset] : memref<1024xf32>, vector<4xf32> - // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" vector.store %data, %dst[%offset] : memref<1024xf32>, vector<4xf32> // CHECK: return @@ -171,7 +183,8 @@ func.func @buffer_store_b128(%memref: memref<1024xf32, #amdgpu.address_space, %buffer: memref<1024xf32, #amdgpu.address_space>, %offset: index) { // Load from global memory (should use global_load) - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> // Store to buffer memory (should use buffer_store) @@ -185,7 +198,8 @@ func.func @mixed_global_and_buffer(%global: memref<1024xf32>, %buffer: memref<10 %buffer_data = vector.load %buffer[%offset] : memref<1024xf32, #amdgpu.address_space>, vector<4xf32> // Store to global memory (should use global_store) - // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" vector.store %buffer_data, %global[%offset] : memref<1024xf32>, vector<4xf32> return @@ -252,7 +266,8 @@ func.func @ds_store_b128(%lds: memref<1024xf32, #gpu.address_space>, // CHECK-LABEL: func.func @mixed_global_buffer_and_ds func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref<1024xf32, #amdgpu.address_space>, %lds: memref<1024xf32, #gpu.address_space>, %offset: index) { // Load from global (should use global_load) - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", 
"=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %global_data = vector.load %global[%offset] : memref<1024xf32>, vector<4xf32> // Store to LDS (should use ds_write) @@ -276,28 +291,32 @@ func.func @mixed_global_buffer_and_ds(%global: memref<1024xf32>, %buffer: memref // CHECK-LABEL: func.func @scalar_load_global_f32 func.func @scalar_load_global_f32(%memref: memref<1024xf32>, %offset: index) -> f32 { - // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" %result = memref.load %memref[%offset] : memref<1024xf32> return %result : f32 } // CHECK-LABEL: func.func @scalar_load_global_f64 func.func @scalar_load_global_f64(%memref: memref<1024xf64>, %offset: index) -> f64 { - // CHECK: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx2 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b64 $0, $1, off", "=v,v" %result = memref.load %memref[%offset] : memref<1024xf64> return %result : f64 } // CHECK-LABEL: func.func @scalar_store_global_f32 func.func @scalar_store_global_f32(%memref: memref<1024xf32>, %offset: index, %data: f32) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" memref.store %data, %memref[%offset] : memref<1024xf32> return } // CHECK-LABEL: func.func @scalar_store_global_f64 func.func @scalar_store_global_f64(%memref: memref<1024xf64>, %offset: index, %data: f64) { - // CHECK: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx2 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b64 $0, $1, off", "v,v" memref.store %data, %memref[%offset] : memref<1024xf64> return } @@ -335,19 +354,23 @@ func.func @scalar_store_ds_f32(%lds: memref<1024xf32, #gpu.address_space, %offset: index) { // Scalar load - // CHECK: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "=v,v" %scalar = memref.load %memref[%offset] : memref<1024xf32> // Vector load - // CHECK: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "=v,v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "=v,v" %vector = vector.load %memref[%offset] : memref<1024xf32>, vector<4xf32> // Scalar store - // CHECK: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v" memref.store %scalar, %memref[%offset] : memref<1024xf32> // Vector store - // CHECK: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" + // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v" + // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v" vector.store 
%vector, %memref[%offset] : memref<1024xf32>, vector<4xf32> return @@ -362,7 +385,7 @@ func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> - // GFX9: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1023},v" memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v255}" @@ -379,7 +402,7 @@ func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> at %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> - // GFX9: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1020:1023]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" @@ -447,11 +470,11 @@ func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, # %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v247},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1015},v" %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" + // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[248:251]},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1016:1019]},v" %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> From 5d6d64a24cac14cfd7cca19a13abc0a4547aa429 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 17:26:24 +0100 Subject: [PATCH 098/114] update reg count Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 4 +- water/test/Transforms/lower-memory-ops.mlir | 44 ++++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 6723d68b0..13a7fe672 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ 
b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -930,8 +930,8 @@ class WaterLowerMemoryOpsPass MLIRContext *ctx = &getContext(); - // Assume Wave32 for gfx12+ for now. - unsigned totalVGPRs = chip->majorVersion >= 12 ? 1024 : 256; + unsigned totalVGPRs = + chip->majorVersion >= 12 && chip->minorVersion >= 5 ? 1024 : 256; // Check if function has VGPR allocation and insert inline asm directive. auto vgprAttr = func->getAttrOfType("water.total_vgprs"); diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index 61f6ad826..d3257f1c7 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -380,16 +380,16 @@ func.func @mixed_scalar_and_vector(%memref: memref<1024xf32>, %offset: index) { // CHECK-LABEL: func.func @copy_global_to_reg_scalar // GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1023"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "255"]] func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes {water.total_vgprs = 1 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1023},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v255}" - // GFX12: llvm.inline_asm "; reg_load", "={v1023}" + // GFX12: llvm.inline_asm "; reg_load", "={v255}" %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> // CHECK-NOT: memref.alloca return %val : f32 @@ -397,16 +397,16 @@ func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes { // CHECK-LABEL: func.func @copy_global_to_reg_vector // GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1020"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1020:1023]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> @@ -414,16 +414,16 @@ func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> at // CHECK-LABEL: 
func.func @copy_buffer_to_reg // GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] -// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1020"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "252"]] func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #amdgpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> // GFX9: llvm.inline_asm has_side_effects "buffer_load_dwordx4 $0, $1, $2, 0 offen", "={v[252:255]},v,s" - // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[1020:1023]},v,s" + // GFX12: llvm.inline_asm has_side_effects "buffer_load_b128 $0, $1, $2, 0 offen", "={v[252:255]},v,s" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #amdgpu.address_space> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> @@ -431,16 +431,16 @@ func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space>) -> vector<4xf32> attributes {water.total_vgprs = 4 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %subview = memref.subview %arg0[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[1020:1023]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> @@ -448,15 +448,15 @@ func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space f32 attributes {water.total_vgprs = 1 : i32} { %c0 = arith.constant 0 : index %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm has_side_effects "; reg_store", "={v255},0" - // GFX12: llvm.inline_asm has_side_effects "; reg_store", "={v1023},0" + // GFX12: llvm.inline_asm has_side_effects "; reg_store", "={v255},0" memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v255}" - // GFX12: llvm.inline_asm "; reg_load", "={v1023}" + // GFX12: llvm.inline_asm "; reg_load", "={v255}" %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32> // CHECK-NOT: memref.alloca return %result : f32 @@ -464,32 +464,32 @@ func.func @store_to_reg(%val: f32) -> f32 attributes {water.total_vgprs = 1 : i3 // CHECK-LABEL: func.func @multiple_reg_allocas // GFX9-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] -// 
GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "1015"]] +// GFX12-SAME{LITERAL}: passthrough = [["amdgpu-num-vgpr", "247"]] func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #gpu.address_space>) -> (f32, vector<4xf32>, vector<4xf32>) attributes {water.total_vgprs = 9 : i32} { %c0 = arith.constant 0 : index %reg0 = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32> %reg1 = memref.alloca() {water.vgpr_number = 1 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> %reg2 = memref.alloca() {water.vgpr_number = 5 : i32, water.vgpr_count = 4 : i32} : memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v247},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v1015},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v247},v" %sv0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>> memref.copy %sv0, %reg0 : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[248:251]},v" - // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[1016:1019]},v" + // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[248:251]},v" %sv1 = memref.subview %arg0[%c0] [4] [1] : memref<100xf32> to memref<4xf32, strided<[1], offset: ?>> memref.copy %sv1, %reg1 : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" - // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[1020:1023]},v" + // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v" %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space> to memref<4xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v247}" - // GFX12: llvm.inline_asm "; reg_load", "={v1015}" + // GFX12: llvm.inline_asm "; reg_load", "={v247}" %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32> // GFX9: llvm.inline_asm "; reg_load", "={v[248:251]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[1016:1019]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[248:251]}" %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[1020:1023]}" + // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}" %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32> From 047db50c9fc3b2ff8a353779a9ca0130f304649a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Dec 2025 23:31:49 +0100 Subject: [PATCH 099/114] barriers Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterInsertWaitcnt.cpp | 99 +++++++++++++++++++-- water/test/Transforms/insert-waitcnt.mlir | 26 ++++++ 2 files changed, 119 insertions(+), 6 deletions(-) diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp index e1866acc9..618d8a944 100644 --- a/water/lib/Transforms/WaterInsertWaitcnt.cpp +++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp @@ -29,6 +29,10 @@ namespace 
mlir::water { } // namespace mlir::water namespace { +static bool isBarrier(Operation *op) { + return isa(op) || isa(op); +} + static bool isRegisterAddressSpace(MemRefType type) { auto attr = dyn_cast_or_null(type.getMemorySpace()); return attr && attr.getInt() == 128; @@ -114,6 +118,9 @@ struct PendingOperations { if (size() >= 256) llvm::report_fatal_error("Pending operations list is too long"); + if (!ops.empty() && isBarrier(op) && isBarrier(ops.back())) + return opsTokens.back(); + ops.push_back(op); auto &back = opsTokens.emplace_back(); if (auto memref = isStoreOp(op)) @@ -166,6 +173,14 @@ struct PendingOperations { os << "]"; } + bool operator==(const PendingOperations &other) const { + return ops == other.ops && opsTokens == other.opsTokens; + } + + bool operator!=(const PendingOperations &other) const { + return !(*this == other); + } + SmallVector ops; SmallVector opsTokens; }; @@ -314,6 +329,31 @@ class WaitcntState : public AbstractDenseLattice { return changed ? ChangeResult::Change : ChangeResult::NoChange; } + ChangeResult merge(const WaitcntState &rhs) { + bool changed = false; + + if (pendingOpsLists.size() != rhs.pendingOpsLists.size()) { + changed = true; + } else { + for (auto [listSrc, listDst] : + llvm::zip(pendingOpsLists, rhs.pendingOpsLists)) { + if (*listSrc != *listDst) { + changed = true; + break; + } + } + } + + if (changed) { + pendingOpsLists = rhs.pendingOpsLists; + resetPendingOpsSet(); + } + + if (requirement.merge(rhs.requirement)) + changed = true; + return changed ? ChangeResult::Change : ChangeResult::NoChange; + } + void print(raw_ostream &os) const override { os << "WaitcntState: pending ops ["; for (auto &pendingOps : pendingOpsLists) { @@ -416,7 +456,9 @@ class WaitcntState : public AbstractDenseLattice { bool hasRequirement() const { return requirement.hasRequirement(); } /// Check if a value depends on pending operations and compute required wait - WaitcntRequirement checkRequirement(Value val) const { + WaitcntRequirement + checkSSADependency(Value val, + llvm::SmallSetVector &barriers) const { // Check if val is produced by any pending operation Operation *defOp = val.getDefiningOp(); if (!defOp) @@ -430,6 +472,8 @@ class WaitcntState : public AbstractDenseLattice { if (pendingOps->empty()) continue; + Operation *barrier = nullptr; + // Search from the back to find the most recent dependency bool found = false; auto req = WaitcntRequirement::getOperationRequirement(defOp, true); @@ -439,6 +483,9 @@ class WaitcntState : public AbstractDenseLattice { break; } + if (!barrier && isBarrier(op)) + barrier = op; + auto opReq = WaitcntRequirement::getOperationRequirement(op, false); if (!req.isSameCounterType(opReq)) continue; @@ -446,15 +493,20 @@ class WaitcntState : public AbstractDenseLattice { req = req + opReq; } - if (found) + if (found) { result.merge(req); + if (barrier) + barriers.insert(barrier); + } } return result; } /// Check for memory dependencies (RAW, WAR, WAW) and compute required wait - WaitcntRequirement checkMemoryDependency(Operation *op) const { + WaitcntRequirement + checkMemoryDependency(Operation *op, + llvm::SmallSetVector &barriers) const { auto checkMemref = [&](Value memref, bool isCurrentLoad, bool isCurrentStore) -> WaitcntRequirement { WaitcntRequirement result; @@ -465,10 +517,16 @@ class WaitcntState : public AbstractDenseLattice { if (pendingOps->empty()) continue; + Operation *barrier = nullptr; + // Search from the back to find the most recent dependency for (const auto &[pendingOp, pendingTokens] : 
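// The barrier bookkeeping introduced here drives where waits are placed: when
// a dependency on a pending memory operation is discovered behind a barrier,
// the wait requirement is attached to that barrier rather than to the
// consuming op, so the emitted amdgpu.memory_counter_wait lands immediately
// before the barrier. A rough sketch of the intended result, mirroring the
// lds_barriers test added below:
//
//   %a = vector.load ...               // one load outstanding
//   %b = vector.load ...               // two loads outstanding
//   amdgpu.memory_counter_wait load(1) // let %a retire (at most one pending)
//   amdgpu.lds_barrier
//   arith.addf %a, %a                  // %a is now safe to read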
llvm::zip(llvm::reverse(pendingOps->ops), llvm::reverse(pendingOps->opsTokens))) { + + if (!barrier && isBarrier(pendingOp)) + barrier = pendingOp; + auto checkPendingMemref = [&](Value pendingMemref, bool isPendingLoad, bool isPendingStore) -> WaitcntRequirement { @@ -500,6 +558,9 @@ class WaitcntState : public AbstractDenseLattice { } pendingResult.merge(req); } + if (pendingResult.hasRequirement() && barrier) + barriers.insert(barrier); + return pendingResult; }; if (auto loadBase = isLoadOp(pendingOp)) @@ -609,23 +670,49 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { // Start with the state before this operation WaitcntState newState = before; + if (isBarrier(op)) { + LDBG() << " Barrier: " << *op; + newState.addPendingOp(op); + LDBG() << " New state: " << newState; + propagateIfChanged(after, after->join(newState)); + return success(); + } + + llvm::SmallSetVector barriers; + // Check if any operands depend on pending operations (value dependency) WaitcntRequirement opRequirement = after->getRequirement(); for (Value operand : op->getOperands()) { - if (auto req = before.checkRequirement(operand)) { + if (auto req = before.checkSSADependency(operand, barriers)) { // Merge this requirement (take minimum for conservative wait) opRequirement.merge(req); } } // Check for memory dependencies (RAW, WAR, WAW) - if (auto memReq = before.checkMemoryDependency(op)) { + if (auto memReq = before.checkMemoryDependency(op, barriers)) { LDBG() << " Memory dependency: " << memReq; opRequirement.merge(memReq); } else { LDBG() << " No memory dependency"; } + if (opRequirement.hasRequirement() && !barriers.empty()) { + // newState.setRequirement(opRequirement); + LDBG() << " Barriers found, requirement: " << opRequirement; + for (Operation *barrier : barriers) { + LDBG() << " " << *barrier; + WaitcntState *beforeState = + getOrCreate(getProgramPointBefore(barrier)); + WaitcntState *afterState = + getOrCreate(getProgramPointAfter(barrier)); + WaitcntState newBarrierState = *beforeState; + newBarrierState.setRequirement(opRequirement); + propagateIfChanged(afterState, afterState->merge(newBarrierState)); + } + return success(); + } + // Check if this is an existing memory_counter_wait operation if (auto waitOp = dyn_cast(op)) { LDBG() << " Existing waitcnt operation: " << *waitOp; @@ -649,7 +736,7 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis { } LDBG() << " New state: " << newState; - propagateIfChanged(after, after->join(newState)); + propagateIfChanged(after, after->merge(newState)); return success(); } diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir index 84f645269..4af978d1f 100644 --- a/water/test/Transforms/insert-waitcnt.mlir +++ b/water/test/Transforms/insert-waitcnt.mlir @@ -31,6 +31,32 @@ func.func @two_loads_use_in_reverse_order(%memrefA: memref<1024xf32>, %memrefB: return %addB : vector<4xf32> } +// CHECK-LABEL: func.func @lds_barriers +// CHECK-SAME: (%[[ARG0:.*]]: memref<1024xf32>, %[[ARG1:.*]]: memref<1024xf32>, %{{.*}}: index) +func.func @lds_barriers(%memrefA: memref<1024xf32>, %memrefB: memref<1024xf32>, %offset: index) -> vector<4xf32> { + // CHECK: %[[LOAD_A:.*]] = vector.load %[[ARG0]] + // CHECK: %[[LOAD_B:.*]] = vector.load %[[ARG1]] + %loadA = vector.load %memrefA[%offset] : memref<1024xf32>, vector<4xf32> + %loadB = vector.load %memrefB[%offset] : memref<1024xf32>, vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(1) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_A:.*]] = 
arith.addf %[[LOAD_A]], %[[LOAD_A]] + amdgpu.lds_barrier + %addA = arith.addf %loadA, %loadA : vector<4xf32> + + // CHECK: amdgpu.memory_counter_wait load(0) + // CHECK-NEXT: amdgpu.lds_barrier + // CHECK-NEXT: %[[ADD_B:.*]] = arith.addf %[[LOAD_B]], %[[ADD_A]] + amdgpu.lds_barrier + %addB = arith.addf %loadB, %addA : vector<4xf32> + + // CHECK-NOT: amdgpu.memory_counter_wait + + // CHECK: return %[[ADD_B]] + return %addB : vector<4xf32> +} + // CHECK-LABEL: func.func @raw_dependency // CHECK-SAME: (%[[MEM:.*]]: memref<1024xf32>, %[[DATA:.*]]: vector<4xf32>, %{{.*}}: index) func.func @raw_dependency(%memref: memref<1024xf32>, %data: vector<4xf32>, %offset: index) -> vector<4xf32> { From d2cfbe5bc193ec0607094812ffb176dedcb790e0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 17 Dec 2025 00:18:52 +0100 Subject: [PATCH 100/114] more barriers Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 13a7fe672..68386af8b 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -847,6 +847,8 @@ static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp, if (bitWidth < 32) asmType = rewriter.getIntegerType(32); + ROCDL::SchedBarrier::create(rewriter, loc, {}, 0); + Value asmResult = createInlineAsm(rewriter, loc, asmType, {}, asmStr, constraints, /*hasSideEffects=*/false) .getResult(0); From 02e26f3a301f59d0e22a04677b77989ed8c757b1 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 17 Dec 2025 00:48:51 +0100 Subject: [PATCH 101/114] nicer code Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 80 ++++++++------------ 1 file changed, 32 insertions(+), 48 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index 68386af8b..ce950371e 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -640,11 +640,11 @@ static bool usesRegisterSpace(Value memref) { } /// Lower memref.copy when destination is in register space - buffer variant -static LogicalResult -lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, - bool isRDNAArch, unsigned vgprOffset, - unsigned vgprNum, unsigned vgprCount, - unsigned totalBits, Type resultType) { +static LogicalResult lowerCopyToRegBuffer(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -682,11 +682,11 @@ lowerCopyToRegisterSpaceBuffer(memref::CopyOp copyOp, IRRewriter &rewriter, } /// Lower memref.copy when destination is in register space - DS variant -static LogicalResult -lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, - bool isRDNAArch, unsigned vgprOffset, - unsigned vgprNum, unsigned vgprCount, - unsigned totalBits, Type resultType) { +static LogicalResult lowerCopyToRegDS(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, unsigned totalBits, + Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -717,11 +717,11 @@ 
lowerCopyToRegisterSpaceDS(memref::CopyOp copyOp, IRRewriter &rewriter, } /// Lower memref.copy when destination is in register space - global variant -static LogicalResult -lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, - bool isRDNAArch, unsigned vgprOffset, - unsigned vgprNum, unsigned vgprCount, - unsigned totalBits, Type resultType) { +static LogicalResult lowerCopyToRegGlobal(memref::CopyOp copyOp, + IRRewriter &rewriter, bool isRDNAArch, + unsigned vgprOffset, unsigned vgprNum, + unsigned vgprCount, + unsigned totalBits, Type resultType) { Value src = copyOp.getSource(); auto srcType = cast(src.getType()); unsigned elementBitWidth = srcType.getElementTypeBitWidth(); @@ -752,10 +752,8 @@ lowerCopyToRegisterSpaceGlobal(memref::CopyOp copyOp, IRRewriter &rewriter, } /// Lower memref.copy when destination is in register space -static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp, - IRRewriter &rewriter, - bool isRDNAArch, - unsigned vgprOffset) { +static LogicalResult lowerCopyToReg(memref::CopyOp copyOp, IRRewriter &rewriter, + bool isRDNAArch, unsigned vgprOffset) { Value src = copyOp.getSource(); Value dst = copyOp.getTarget(); @@ -793,23 +791,19 @@ static LogicalResult lowerCopyToRegisterSpace(memref::CopyOp copyOp, // Dispatch based on source memory space if (usesBufferAddressSpace(src)) - return lowerCopyToRegisterSpaceBuffer(copyOp, rewriter, isRDNAArch, - vgprOffset, vgprNum, vgprCount, - totalBits, resultType); + return lowerCopyToRegBuffer(copyOp, rewriter, isRDNAArch, vgprOffset, + vgprNum, vgprCount, totalBits, resultType); if (usesWorkgroupAddressSpace(src)) - return lowerCopyToRegisterSpaceDS(copyOp, rewriter, isRDNAArch, vgprOffset, - vgprNum, vgprCount, totalBits, - resultType); - return lowerCopyToRegisterSpaceGlobal(copyOp, rewriter, isRDNAArch, - vgprOffset, vgprNum, vgprCount, - totalBits, resultType); + return lowerCopyToRegDS(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); + return lowerCopyToRegGlobal(copyOp, rewriter, isRDNAArch, vgprOffset, vgprNum, + vgprCount, totalBits, resultType); } /// Lower load from register space to inline assembly template -static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp, - IRRewriter &rewriter, - unsigned vgprOffset) { +static LogicalResult lowerLoadFromReg(LoadOpTy loadOp, IRRewriter &rewriter, + unsigned vgprOffset) { Value memref; if constexpr (std::is_same_v) memref = loadOp.getBase(); @@ -865,9 +859,8 @@ static LogicalResult lowerLoadFromRegisterSpace(LoadOpTy loadOp, /// Lower store to register space to inline assembly template -static LogicalResult lowerStoreToRegisterSpace(StoreOpTy storeOp, - IRRewriter &rewriter, - unsigned vgprOffset) { +static LogicalResult lowerStoreToReg(StoreOpTy storeOp, IRRewriter &rewriter, + unsigned vgprOffset) { Value memref; if constexpr (std::is_same_v) memref = storeOp.getBase(); @@ -996,9 +989,7 @@ class WaterLowerMemoryOpsPass if (auto loadOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( loadOp.getBase(), - [&]() { - return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart); - }, + [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); }, [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); }, [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); }); @@ -1009,9 +1000,7 @@ class WaterLowerMemoryOpsPass if (auto storeOp = dyn_cast(op)) { LogicalResult result = lowerMemoryOp( storeOp.getBase(), - 
-          [&]() {
-            return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart);
-          },
+          [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); },
           [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); },
           [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); },
           [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); });
@@ -1022,9 +1011,7 @@ class WaterLowerMemoryOpsPass
     if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
       LogicalResult result = lowerMemoryOp(
           loadOp.getMemRef(),
-          [&]() {
-            return lowerLoadFromRegisterSpace(loadOp, rewriter, vgprStart);
-          },
+          [&]() { return lowerLoadFromReg(loadOp, rewriter, vgprStart); },
           [&]() { return lowerLoadBuffer(loadOp, rewriter, isRDNAArch); },
           [&]() { return lowerLoadDS(loadOp, rewriter, isRDNAArch); },
           [&]() { return lowerLoadGlobal(loadOp, rewriter, isRDNAArch); });
@@ -1035,9 +1022,7 @@ class WaterLowerMemoryOpsPass
     if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
       LogicalResult result = lowerMemoryOp(
           storeOp.getMemRef(),
-          [&]() {
-            return lowerStoreToRegisterSpace(storeOp, rewriter, vgprStart);
-          },
+          [&]() { return lowerStoreToReg(storeOp, rewriter, vgprStart); },
           [&]() { return lowerStoreBuffer(storeOp, rewriter, isRDNAArch); },
           [&]() { return lowerStoreDS(storeOp, rewriter, isRDNAArch); },
           [&]() { return lowerStoreGlobal(storeOp, rewriter, isRDNAArch); });
@@ -1048,8 +1033,7 @@ class WaterLowerMemoryOpsPass
     if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
       // Only lower copy if destination is in register space
       if (usesRegisterSpace(copyOp.getTarget())) {
-        if (failed(lowerCopyToRegisterSpace(copyOp, rewriter, isRDNAArch,
-                                            vgprStart)))
+        if (failed(lowerCopyToReg(copyOp, rewriter, isRDNAArch, vgprStart)))
           return WalkResult::interrupt();
         return WalkResult::advance();
       }

From 09d6dc1d7c468af884b0d67c66b0924bfdaac984 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 17 Dec 2025 00:56:29 +0100
Subject: [PATCH 102/114] code cleanup

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 618d8a944..0e38a972e 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -141,6 +141,12 @@ struct PendingOperations {
     return llvm::zip(ops, opsTokens);
   }

+  auto opsAndTokensReverse() const {
+    assert(ops.size() == opsTokens.size() &&
+           "ops and opsTokens must have the same size");
+    return llvm::zip(llvm::reverse(ops), llvm::reverse(opsTokens));
+  }
+
   bool hasSameTail(const PendingOperations &other) const {
     for (const auto &[op1, op2, tok1, tok2] :
          llvm::zip(llvm::reverse(ops), llvm::reverse(other.ops),
@@ -521,17 +527,18 @@ class WaitcntState : public AbstractDenseLattice {

     // Search from the back to find the most recent dependency
     for (const auto &[pendingOp, pendingTokens] :
-         llvm::zip(llvm::reverse(pendingOps->ops),
-                   llvm::reverse(pendingOps->opsTokens))) {
+         pendingOps->opsAndTokensReverse()) {
       if (!barrier && isBarrier(pendingOp))
         barrier = pendingOp;

+      // We cannot capture structured bindings into lambda, thanks C++
+      auto &pendingTok = pendingTokens;
       auto checkPendingMemref =
           [&](Value pendingMemref, bool isPendingLoad,
               bool isPendingStore) -> WaitcntRequirement {
         WaitcntRequirement pendingResult;
-        if (!mayAlias(memref, pendingMemref, pendingTokens))
+        if (!mayAlias(memref, pendingMemref, pendingTok))
           return pendingResult;

         // Check for dependencies:
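The "cannot capture" workaround above reflects a real language rule, not a quirk of this code: before C++20, a structured binding cannot be named in a lambda capture list, so it first has to be re-bound to an ordinary variable or reference. A standalone sketch of the limitation (plain C++, not project code):

#include <utility>

int captureStructuredBinding() {
  std::pair<int, int> p{1, 2};
  auto [first, second] = p;
  // auto bad = [second] { return second; }; // ill-formed before C++20
  auto &secondRef = second; // re-bind, as the pass does with pendingTok
  auto good = [&secondRef] { return secondRef; };
  return good(); // returns 2
}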
From cafa81eda858e58663e0b1d693c6c2a27e932bc0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 17 Dec 2025 01:02:58 +0100
Subject: [PATCH 103/114] more code cleanup

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 0e38a972e..52de06e06 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -526,14 +526,15 @@ class WaitcntState : public AbstractDenseLattice {
     Operation *barrier = nullptr;

     // Search from the back to find the most recent dependency
-    for (const auto &[pendingOp, pendingTokens] :
+    for (const auto &[pendingOpVar, pendingTokensVar] :
          pendingOps->opsAndTokensReverse()) {
-      if (!barrier && isBarrier(pendingOp))
+      if (!barrier && isBarrier(pendingOpVar))
         barrier = pendingOp;

-      // We cannot capture structured bindings into lambda, thanks C++
-      auto &pendingTok = pendingTokens;
+      // We cannot capture structured bindings into lambda, thanks C++.
+      auto &pendingTokens = pendingTokensVar;
+      auto &pendingOp = pendingOpVar;
       auto checkPendingMemref =
           [&](Value pendingMemref, bool isPendingLoad,
               bool isPendingStore) -> WaitcntRequirement {

From 740ec70aab28f3caaceb45327da40011efc96165 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 17 Dec 2025 01:09:20 +0100
Subject: [PATCH 104/114] fixes

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 52de06e06..8b77742a7 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -530,7 +530,7 @@ class WaitcntState : public AbstractDenseLattice {
          pendingOps->opsAndTokensReverse()) {
       if (!barrier && isBarrier(pendingOpVar))
-        barrier = pendingOp;
+        barrier = pendingOpVar;

       // We cannot capture structured bindings into lambda, thanks C++.
auto &pendingTokens = pendingTokensVar;
@@ -539,7 +539,7 @@ class WaitcntState : public AbstractDenseLattice {
           [&](Value pendingMemref, bool isPendingLoad,
               bool isPendingStore) -> WaitcntRequirement {
         WaitcntRequirement pendingResult;
-        if (!mayAlias(memref, pendingMemref, pendingTok))
+        if (!mayAlias(memref, pendingMemref, pendingTokens))
           return pendingResult;

         // Check for dependencies:

From c45fc966a46d8552048aa01a4799aa90e846fe2c Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 17 Dec 2025 01:46:11 +0100
Subject: [PATCH 105/114] code cleanup

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 8b77742a7..2131224e3 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -828,10 +828,8 @@ class WaterInsertWaitcntPass
     loadBaselineAnalyses(solver);
     solver.load();

-    if (failed(solver.initializeAndRun(op))) {
-      signalPassFailure();
-      return;
-    }
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();

     // Insert waitcnt operations based on analysis results
     IRRewriter rewriter(&getContext());

From 666e989125521e9a458005d8fec9b3c37fc67286 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 18 Dec 2025 12:51:04 +0100
Subject: [PATCH 106/114] refac

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 22 +++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index ce950371e..ca8611b7b 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -39,15 +39,23 @@ static unsigned getBitwidth(Type type) {
   return type.getIntOrFloatBitWidth();
 }

+static std::string getVGPRRange(unsigned vgprOffset, unsigned vgprNum,
+                                unsigned vgprCount) {
+  assert(vgprCount > 0 && "VGPR count must be greater than 0");
+  unsigned start = vgprOffset + vgprNum;
+  if (vgprCount == 1) {
+    return ("v" + llvm::Twine(start)).str();
+  } else {
+    unsigned end = start + vgprCount - 1;
+    return ("v[" + llvm::Twine(start) + ":" + llvm::Twine(end) + "]").str();
+  }
+}
+
 static std::string getVGPRConstraint(unsigned vgprOffset, unsigned vgprNum,
                                      unsigned vgprCount, bool isOutput) {
-  std::string constraint;
-  if (vgprCount == 1)
-    constraint = "{v" + std::to_string(vgprOffset + vgprNum) + "}";
-  else
-    constraint = "{v[" + std::to_string(vgprOffset + vgprNum) + ":" +
-                 std::to_string(vgprOffset + vgprNum + vgprCount - 1) + "]}";
-  return isOutput ? "=" + constraint : constraint;
+  return (llvm::Twine(isOutput ? "=" : "") + "{" +
+          getVGPRRange(vgprOffset, vgprNum, vgprCount) + "}")
+      .str();
 }

 static FailureOr<StringRef> getLoadSizeSuffixRDNA(unsigned bitWidth) {
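As a worked example of the two helpers above, using the register numbers that appear throughout the tests (a sketch, assuming the helpers are in scope and that vgprOffset + vgprNum is the first register of the block):

#include <cassert>

void checkVgprFormatting() {
  // A single register prints as "vN"...
  assert(getVGPRRange(/*vgprOffset=*/252, /*vgprNum=*/3, /*vgprCount=*/1) == "v255");
  // ...and a contiguous block prints as an inclusive range.
  assert(getVGPRRange(/*vgprOffset=*/252, /*vgprNum=*/0, /*vgprCount=*/4) == "v[252:255]");
  // The constraint form wraps the range in braces, with '=' for outputs.
  assert(getVGPRConstraint(252, 0, 4, /*isOutput=*/true) == "={v[252:255]}");
}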
"=" : "") + "{" + + getVGPRRange(vgprOffset, vgprNum, vgprCount) + "}") + .str(); } static FailureOr getLoadSizeSuffixRDNA(unsigned bitWidth) { From eef3d64f5a81ba25a913de2e8e9a74d8e1e3f820 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 18 Dec 2025 19:06:24 +0100 Subject: [PATCH 107/114] include vgp range in comment Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterLowerMemoryOps.cpp | 6 ++-- water/test/Transforms/lower-memory-ops.mlir | 36 ++++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp index ca8611b7b..05639faa4 100644 --- a/water/lib/Transforms/WaterLowerMemoryOps.cpp +++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp @@ -841,7 +841,8 @@ static LogicalResult lowerLoadFromReg(LoadOpTy loadOp, IRRewriter &rewriter, getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true); // Simple v_mov to read from VGPR (compiler will optimize this away) - std::string asmStr = "; reg_load"; + std::string asmStr = + "; reg_load " + getVGPRRange(vgprOffset, vgprNum, vgprCount); Type resultType = loadOp.getResult().getType(); Type asmType = resultType; @@ -898,7 +899,8 @@ static LogicalResult lowerStoreToReg(StoreOpTy storeOp, IRRewriter &rewriter, getVGPRConstraint(vgprOffset, vgprNum, vgprCount, true) + ",0"; // v_mov to write to VGPR (input constraint 0 ties to output) - std::string asmStr = "; reg_store"; + std::string asmStr = + "; reg_store " + getVGPRRange(vgprOffset, vgprNum, vgprCount); Value valueToStore = storeOp.getValueToStore(); unsigned bitWidth = getBitwidth(valueToStore.getType()); diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir index d3257f1c7..16945bc24 100644 --- a/water/test/Transforms/lower-memory-ops.mlir +++ b/water/test/Transforms/lower-memory-ops.mlir @@ -388,8 +388,8 @@ func.func @copy_global_to_reg_scalar(%arg0: memref<100xf32>) -> f32 attributes { // GFX9: llvm.inline_asm has_side_effects "global_load_dword $0, $1, off", "={v255},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b32 $0, $1, off", "={v255},v" memref.copy %subview, %reg : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load", "={v255}" - // GFX12: llvm.inline_asm "; reg_load", "={v255}" + // GFX9: llvm.inline_asm "; reg_load v255", "={v255}" + // GFX12: llvm.inline_asm "; reg_load v255", "={v255}" %val = memref.load %reg[%c0] : memref<1xf32, 128 : i32> // CHECK-NOT: memref.alloca return %val : f32 @@ -405,8 +405,8 @@ func.func @copy_global_to_reg_vector(%arg0: memref<100xf32>) -> vector<4xf32> at // GFX9: llvm.inline_asm has_side_effects "global_load_dwordx4 $0, $1, off", "={v[252:255]},v" // GFX12: llvm.inline_asm has_side_effects "global_load_b128 $0, $1, off", "={v[252:255]},v" memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}" + // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" + // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}" %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32> // CHECK-NOT: memref.alloca return %val : vector<4xf32> @@ -422,8 +422,8 @@ func.func @copy_buffer_to_reg(%arg0: memref<100xf32, #amdgpu.address_space, #amdgpu.address_space> to memref<4xf32, 128 : i32> - // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}" - // 
-  // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}"
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
   %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
   // CHECK-NOT: memref.alloca
   return %val : vector<4xf32>
@@ -439,8 +439,8 @@ func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space<workg
   memref.copy %subview, %reg : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
-  // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}"
-  // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}"
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
   %val = vector.load %reg[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
   // CHECK-NOT: memref.alloca
   return %val : vector<4xf32>
@@ -452,11 +452,11 @@ func.func @copy_workgroup_to_reg(%arg0: memref<100xf32, #gpu.address_space<workg
 f32 attributes {water.total_vgprs = 1 : i32} {
   %c0 = arith.constant 0 : index
   %reg = memref.alloca() {water.vgpr_number = 0 : i32, water.vgpr_count = 1 : i32} : memref<1xf32, 128 : i32>
-  // GFX9: llvm.inline_asm has_side_effects "; reg_store", "={v255},0"
-  // GFX12: llvm.inline_asm has_side_effects "; reg_store", "={v255},0"
+  // GFX9: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0"
+  // GFX12: llvm.inline_asm has_side_effects "; reg_store v255", "={v255},0"
   memref.store %val, %reg[%c0] : memref<1xf32, 128 : i32>
-  // GFX9: llvm.inline_asm "; reg_load", "={v255}"
-  // GFX12: llvm.inline_asm "; reg_load", "={v255}"
+  // GFX9: llvm.inline_asm "; reg_load v255", "={v255}"
+  // GFX12: llvm.inline_asm "; reg_load v255", "={v255}"
   %result = memref.load %reg[%c0] : memref<1xf32, 128 : i32>
   // CHECK-NOT: memref.alloca
   return %result : f32
@@ -482,14 +482,14 @@ func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #
   // GFX12: llvm.inline_asm has_side_effects "ds_read_b128 $0, $1", "={v[252:255]},v"
   %sv2 = memref.subview %arg1[%c0] [4] [1] : memref<100xf32, #gpu.address_space<workgroup>> to memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>>
   memref.copy %sv2, %reg2 : memref<4xf32, strided<[1], offset: ?>, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
-  // GFX9: llvm.inline_asm "; reg_load", "={v247}"
-  // GFX12: llvm.inline_asm "; reg_load", "={v247}"
+  // GFX9: llvm.inline_asm "; reg_load v247", "={v247}"
+  // GFX12: llvm.inline_asm "; reg_load v247", "={v247}"
   %val0 = memref.load %reg0[%c0] : memref<1xf32, 128 : i32>
-  // GFX9: llvm.inline_asm "; reg_load", "={v[248:251]}"
-  // GFX12: llvm.inline_asm "; reg_load", "={v[248:251]}"
+  // GFX9: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}"
+  // GFX12: llvm.inline_asm "; reg_load v[248:251]", "={v[248:251]}"
   %val1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
-  // GFX9: llvm.inline_asm "; reg_load", "={v[252:255]}"
-  // GFX12: llvm.inline_asm "; reg_load", "={v[252:255]}"
+  // GFX9: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
+  // GFX12: llvm.inline_asm "; reg_load v[252:255]", "={v[252:255]}"
   %val2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
   // CHECK-NOT: memref.alloca
   return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32>

From 7811c1cd2614088deb9d13e656b0e47cbee3f129 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Thu, 18 Dec 2025 20:55:50 +0100
Subject: [PATCH 108/114] fix waitcnt queue

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterInsertWaitcnt.cpp | 16 ++++++---
 water/test/Transforms/insert-waitcnt.mlir  | 38 +++++++++++++++++++--
 2 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/water/lib/Transforms/WaterInsertWaitcnt.cpp b/water/lib/Transforms/WaterInsertWaitcnt.cpp
index 2131224e3..8e9eb5fbb 100644
--- a/water/lib/Transforms/WaterInsertWaitcnt.cpp
+++ b/water/lib/Transforms/WaterInsertWaitcnt.cpp
@@ -264,8 +264,11 @@ struct WaitcntRequirement {
   }

   bool operator>(const WaitcntRequirement &other) const {
-    return load_cnt.value_or(0) > other.load_cnt.value_or(0) ||
-           ds_cnt.value_or(0) > other.ds_cnt.value_or(0);
+    if (load_cnt && other.load_cnt && *load_cnt > *other.load_cnt)
+      return true;
+    if (ds_cnt && other.ds_cnt && *ds_cnt > *other.ds_cnt)
+      return true;
+    return false;
   }

   operator bool() const { return hasRequirement(); }
@@ -743,8 +746,13 @@ class WaitcntAnalysis : public DenseForwardDataFlowAnalysis {
       newState.addPendingOp(op);
     }

-    LDBG() << "  New state: " << newState;
-    propagateIfChanged(after, after->merge(newState));
+    auto changed = after->merge(newState);
+    if (changed == ChangeResult::Change) {
+      LDBG() << "  New state: " << newState;
+    } else {
+      LDBG() << "  No change";
+    }
+    propagateIfChanged(after, changed);
     return success();
   }
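The operator> change matters when only one side carries a count: with the old value_or(0) form, a present count compared greater than an absent one, whereas the new form only reports "greater" when both sides actually impose a count. A minimal illustration of the two behaviors with plain std::optional (not project code):

#include <optional>

// Old behavior: absence was folded to 0, so {1} > {} held.
bool oldGreater(std::optional<unsigned> a, std::optional<unsigned> b) {
  return a.value_or(0) > b.value_or(0);
}

// New behavior: only compare when both sides carry a count; {1} > {} is
// now false.
bool newGreater(std::optional<unsigned> a, std::optional<unsigned> b) {
  return a && b && *a > *b;
}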
diff --git a/water/test/Transforms/insert-waitcnt.mlir b/water/test/Transforms/insert-waitcnt.mlir
index 4af978d1f..a8f82fce8 100644
--- a/water/test/Transforms/insert-waitcnt.mlir
+++ b/water/test/Transforms/insert-waitcnt.mlir
@@ -483,7 +483,7 @@ func.func @double_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %ste
     %data = vector.load %current[%offset] : memref<1024xf32, #gpu.address_space<workgroup>>, vector<4xf32>

     // Cannot skip unfortunately
-    // CHECK: amdgpu.memory_counter_wait ds(0)
+    // CHECK: amdgpu.memory_counter_wait load(0) ds(0)
     // CHECK: vector.store
     vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32>

@@ -524,7 +524,7 @@ func.func @triple_buffering(%src: memref<1024xf32>, %lb: index, %ub: index, %ste
     memref.copy %src, %next_next : memref<1024xf32> to memref<1024xf32, #gpu.address_space<workgroup>>

     // Skip the prev copy
-    // CHECK: amdgpu.memory_counter_wait ds(1)
+    // CHECK: amdgpu.memory_counter_wait load(0) ds(1)
     // CHECK: vector.store
     vector.store %data, %out[%offset] : memref<1024xf32>, vector<4xf32>

@@ -588,3 +588,37 @@ func.func @triple_buffering_reg_space(%src: memref<1024xf32>, %lb: index, %ub: i
   // CHECK: return
   return
 }
+
+// CHECK-LABEL: func.func @load_store_repeated
+func.func @load_store_repeated(%src0: memref<4xf32>, %src1: memref<4xf32>, %offset: index) {
+  %c0 = arith.constant 0 : index
+  %buff0 = memref.alloc() : memref<4xf32, #gpu.address_space<workgroup>>
+  %buff1 = memref.alloc() : memref<4xf32, #gpu.address_space<workgroup>>
+  %reg0 = memref.alloca() : memref<4xf32, 128 : i32>
+  %reg1 = memref.alloca() : memref<4xf32, 128 : i32>
+  %reg2 = memref.alloca() : memref<4xf32, 128 : i32>
+  %reg3 = memref.alloca() : memref<4xf32, 128 : i32>
+
+  // CHECK-COUNT-4: memref.copy
+  memref.copy %src0, %reg0 : memref<4xf32> to memref<4xf32, 128 : i32>
+  memref.copy %src1, %reg1 : memref<4xf32> to memref<4xf32, 128 : i32>
+
+  memref.copy %buff0, %reg2 : memref<4xf32, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
+  memref.copy %buff1, %reg3 : memref<4xf32, #gpu.address_space<workgroup>> to memref<4xf32, 128 : i32>
+
+  // CHECK: amdgpu.memory_counter_wait load(1)
+  // CHECK-NEXT: vector.load
+  %data0 = vector.load %reg0[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK: amdgpu.memory_counter_wait load(0)
+  // CHECK-NEXT: vector.load
+  %data1 = vector.load %reg1[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+
+  // CHECK: amdgpu.memory_counter_wait ds(1)
+  // CHECK-NEXT: vector.load
+  %data2 = vector.load %reg2[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+  // CHECK: amdgpu.memory_counter_wait ds(0)
+  // CHECK-NEXT: vector.load
+  %data3 = vector.load %reg3[%c0] : memref<4xf32, 128 : i32>, vector<4xf32>
+
+  return
+}

From 1aeb0acfb7f69664c33bf07b974b2f79757297a3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 19 Dec 2025 00:24:52 +0100
Subject: [PATCH 109/114] dumb mfma hazard mitigation

Signed-off-by: Ivan Butygin
---
 water/lib/Transforms/WaterLowerMemoryOps.cpp | 50 +++++++++++++++-
 water/test/Transforms/lower-memory-ops.mlir  | 61 ++++++++++++++++++++
 2 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/water/lib/Transforms/WaterLowerMemoryOps.cpp b/water/lib/Transforms/WaterLowerMemoryOps.cpp
index 05639faa4..3fd4e6ad6 100644
--- a/water/lib/Transforms/WaterLowerMemoryOps.cpp
+++ b/water/lib/Transforms/WaterLowerMemoryOps.cpp
@@ -199,6 +199,50 @@ static bool isRDNA(const amdgpu::Chipset &chipset) {
   return chipset.majorVersion != 9;
 }

+static Operation *propagateExtract(Operation *op) {
+  if (auto extract = dyn_cast<vector::ExtractOp>(op))
+    return extract.getSource().getDefiningOp();
+  if (auto extract = dyn_cast<vector::ExtractStridedSliceOp>(op))
+    return extract.getSource().getDefiningOp();
+  return nullptr;
+}
+
+static unsigned checkHazards(Operation *currentOp, Value value) {
+  Operation *op = value.getDefiningOp();
+  if (!op)
+    return 0;
+
+  while (auto nextOp = propagateExtract(op))
+    op = nextOp;
+
+  if (op->getBlock() != currentOp->getBlock())
+    return 0;
+
+  if (!isa<amdgpu::MFMAOp>(op))
+    return 0;
+
+  while (op != currentOp) {
+    if (isa<LLVM::CallIntrinsicOp>(op) &&
+        cast<LLVM::CallIntrinsicOp>(op).getIntrin() == "llvm.amdgcn.s.nop")
+      return 0;
+    op = op->getNextNode();
+  }
+
+  return 5; // HACK for now
+}
+
+static void handleHazards(IRRewriter &rewriter, Location loc, Operation *op,
+                          Value value) {
+  unsigned hazard = checkHazards(op, value);
+  if (hazard > 0) {
+    ROCDL::SchedBarrier::create(rewriter, loc, {}, 0);
+    Value nopCount =
+        arith::ConstantIntOp::create(rewriter, loc, hazard - 1, 16);
+    StringAttr intrin = rewriter.getStringAttr("llvm.amdgcn.s.nop");
+    LLVM::CallIntrinsicOp::create(rewriter, loc, {}, intrin, nopCount);
+  }
+}
+
 /// Compute byte offset as iX for a memref access with indices
 template <typename OpTy>
 static Value computeMemrefByteOffset(IRRewriter &rewriter, Location loc,
@@ -493,6 +537,7 @@ static LogicalResult lowerStoreBuffer(StoreOpTy storeOp, IRRewriter &rewriter,

   Location loc = storeOp.getLoc();
   rewriter.setInsertionPoint(storeOp);
+  handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore());

   // Build inline assembly: "buffer_store_<size> $0, $1, $2, 0 offen"
   std::string asmStr =
@@ -540,6 +585,8 @@ static LogicalResult lowerStoreGlobal(StoreOpTy storeOp, IRRewriter &rewriter,
     return storeOp.emitError("unsupported store bit width: ") << bitWidth;

   Location loc = storeOp.getLoc();
+  rewriter.setInsertionPoint(storeOp);
+  handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore());

   // Build the inline assembly string: "global_store_b64 $0, $1, off"
   std::string asmStr = ("global_store_" + *suffix + " $0, $1, off").str();

   // Constraints: "v" for address (VGPR), "v" for data (VGPR)
   StringRef constraints = "v,v";

-  rewriter.setInsertionPoint(storeOp);
-
   // Compute the final address
   unsigned elementBitWidth =
       std::is_same_v<StoreOpTy, vector::StoreOp>
@@ -579,6 +624,7 @@ static LogicalResult lowerStoreDS(StoreOpTy storeOp, IRRewriter &rewriter,

   Location loc = storeOp.getLoc();
   rewriter.setInsertionPoint(storeOp);
+  handleHazards(rewriter, loc, storeOp, storeOp.getValueToStore());

   // Build inline assembly: "ds_write_b32 $0, $1"
   std::string asmStr = ("ds_write_" + *suffix + " $0, $1").str();
diff --git a/water/test/Transforms/lower-memory-ops.mlir b/water/test/Transforms/lower-memory-ops.mlir
index 16945bc24..0458bc81e 100644
--- a/water/test/Transforms/lower-memory-ops.mlir
+++ b/water/test/Transforms/lower-memory-ops.mlir
@@ -494,3 +494,64 @@ func.func @multiple_reg_allocas(%arg0: memref<100xf32>, %arg1: memref<100xf32, #
   // CHECK-NOT: memref.alloca
   return %val0, %val1, %val2 : f32, vector<4xf32>, vector<4xf32>
 }
+
+// -----
+// Test MFMA hazard handling with s_nop insertion
+
+// CHECK-LABEL: func.func @mfma_hazard_store
+func.func @mfma_hazard_store(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  // Perform MFMA operation
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // Store MFMA result - should trigger hazard handling
+  // CHECK: rocdl.sched.barrier
+  // CHECK: arith.constant 4 : i16
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
+
+// CHECK-LABEL: func.func @mfma_hazard_with_extract
+func.func @mfma_hazard_with_extract(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  // MFMA with vector extract - hazard checking should propagate through extract
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  %extracted = vector.extract %result[0] : f32 from vector<4xf32>
+
+  // Store extracted value - should still detect hazard through propagation
+  // CHECK: rocdl.sched.barrier
+  // CHECK: arith.constant 4 : i16
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dword $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b32 $0, $1, off", "v,v"
+  memref.store %extracted, %arg0[%offset] : memref<1024xf32>
+
+  return
+}
+
+// CHECK-LABEL: func.func @no_hazard_with_existing_nop
+func.func @no_hazard_with_existing_nop(%arg0: memref<1024xf32>, %a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>) {
+  %offset = arith.constant 0 : index
+
+  %result = amdgpu.mfma 16x16x16 %a * %b + %c blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // Manually insert s.nop
+  %nop_count = arith.constant 4 : i16
+  llvm.call_intrinsic "llvm.amdgcn.s.nop"(%nop_count) : (i16) -> ()
+
+  // Store should NOT insert another s.nop since one already exists
+  // CHECK: llvm.call_intrinsic "llvm.amdgcn.s.nop"
+  // CHECK-NOT: rocdl.sched.barrier
+  // GFX9: llvm.inline_asm has_side_effects "global_store_dwordx4 $0, $1, off", "v,v"
+  // GFX12: llvm.inline_asm has_side_effects "global_store_b128 $0, $1, off", "v,v"
+  vector.store %result, %arg0[%offset] : memref<1024xf32>, vector<4xf32>
+
+  return
+}
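The flat count above corresponds to "s_nop 4", i.e. five wait states, used as a conservative bound for storing an MFMA result on gfx9; LLVM's hazard recognizer instead derives the distance from the size of the producing MFMA. A hypothetical refinement along those lines (the thresholds below are illustrative assumptions, not values taken from the hardware documentation):

// Sketch: scale wait states with the accumulator width of the producing
// MFMA, as a stand-in for its pass count. Numbers are placeholders only.
static unsigned mfmaWaitStates(unsigned accumulatorRegs) {
  if (accumulatorRegs <= 4)
    return 5; // matches the flat count the pass uses today
  if (accumulatorRegs <= 16)
    return 11; // assumed mid-size MFMA latency
  return 19;   // assumed worst case for 32-register accumulators
}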
From 82455a5b832030e5d65da0bc3145291f3acef1b8 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 19 Dec 2025 11:48:56 +0100
Subject: [PATCH 110/114] mxfp test

Signed-off-by: Ivan Butygin
---
 tests/kernel/wave/common/utils.py        |  63 ++++++++
 tests/kernel/wave/wave_gemm_mxfp_test.py | 192 ++++++++++++++++++---
 2 files changed, 230 insertions(+), 25 deletions(-)

diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py
index 916e25344..744eaa518 100644
--- a/tests/kernel/wave/common/utils.py
+++ b/tests/kernel/wave/common/utils.py
@@ -7,6 +7,7 @@
 import pytest
 from wave_lang.kernel.wave.utils.run_utils import get_default_arch
 from pathlib import Path
+from dataclasses import dataclass, field

 require_e2e = pytest.mark.require_e2e
 expensive_test = pytest.mark.expensive_test
@@ -91,3 +92,65 @@ def use_water_backend_bool(name: str):

 def glob_asm_files(path: Path) -> list[Path]:
     return list(filter(lambda x: x.suffix in [".s", ".rocmasm"], path.glob("*")))
+
+
+@dataclass
+class KernelMetadata:
+    """Metadata extracted from kernel assembly."""
+
+    vgpr_count: int | None = None
+    vgpr_spill_count: int | None = None
+    sgpr_count: int | None = None
+    sgpr_spill_count: int | None = None
+    waitcnt_ops: list[str] = field(default_factory=list)
+
+
+def extract_kernel_metadata(asm_text: str) -> KernelMetadata:
+    """
+    Extract kernel metadata from ROCm assembly text.
+
+    Args:
+        asm_text: Assembly text content (e.g., from .rocmasm file)
+
+    Returns:
+        KernelMetadata containing:
+        - vgpr_count: Number of VGPRs allocated
+        - vgpr_spill_count: Number of VGPRs spilled
+        - sgpr_count: Number of SGPRs allocated
+        - sgpr_spill_count: Number of SGPRs spilled
+        - waitcnt_ops: List of all waitcnt operations found in the assembly
+    """
+    import re
+
+    metadata = KernelMetadata()
+
+    # Extract from YAML metadata section (more reliable)
+    # Look for patterns like:
+    #   .vgpr_count: 3
+    #   .vgpr_spill_count: 0
+    #   .sgpr_count: 8
+    #   .sgpr_spill_count: 0
+
+    vgpr_count_match = re.search(r"\.vgpr_count:\s+(\d+)", asm_text)
+    if vgpr_count_match:
+        metadata.vgpr_count = int(vgpr_count_match.group(1))
+
+    vgpr_spill_match = re.search(r"\.vgpr_spill_count:\s+(\d+)", asm_text)
+    if vgpr_spill_match:
+        metadata.vgpr_spill_count = int(vgpr_spill_match.group(1))
+
+    sgpr_count_match = re.search(r"\.sgpr_count:\s+(\d+)", asm_text)
+    if sgpr_count_match:
+        metadata.sgpr_count = int(sgpr_count_match.group(1))
+
+    sgpr_spill_match = re.search(r"\.sgpr_spill_count:\s+(\d+)", asm_text)
+    if sgpr_spill_match:
+        metadata.sgpr_spill_count = int(sgpr_spill_match.group(1))
+
+    # Extract all waitcnt operations
+    # Pattern: s_waitcnt followed by any arguments
+    # Examples: s_waitcnt lgkmcnt(0), s_waitcnt vmcnt(0), etc.
+    waitcnt_pattern = re.compile(r"s_waitcnt\s+[^\n]+")
+    metadata.waitcnt_ops = waitcnt_pattern.findall(asm_text)
+
+    return metadata
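A short usage sketch of the helper above, in the shape the test below consumes it (the file name is a placeholder, not something the tests create):

# Read a dumped kernel assembly file and inspect its metadata.
from pathlib import Path

asm_text = Path("kernel.rocmasm").read_text()  # placeholder path
meta = extract_kernel_metadata(asm_text)
assert meta.vgpr_spill_count == 0, "kernel should not spill VGPRs"
print(meta.vgpr_count, len(meta.waitcnt_ops))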
diff --git a/tests/kernel/wave/wave_gemm_mxfp_test.py b/tests/kernel/wave/wave_gemm_mxfp_test.py
index a06a79016..4f5a70a12 100644
--- a/tests/kernel/wave/wave_gemm_mxfp_test.py
+++ b/tests/kernel/wave/wave_gemm_mxfp_test.py
@@ -1,5 +1,6 @@
 import torch
 import pytest
+from pathlib import Path

 import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
@@ -25,7 +26,14 @@
     ScaledMMAType,
 )

-from .common.utils import param_bool, require_e2e, require_cdna4
+from .common.utils import (
+    extract_kernel_metadata,
+    glob_asm_files,
+    param_bool,
+    require_cdna4,
+    require_e2e,
+    use_water_backend_bool,
+)

 # Note this is specified by the HW and cannot be changed.
 SCALE_GROUP_SIZE = 32
@@ -230,32 +238,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

 # BMK @ NK -> BMN represents Linear Layer style BMM.
-@require_e2e
-@require_cdna4
-@pytest.mark.parametrize("batch", [4, 8])
-@pytest.mark.parametrize(
-    "shape",
-    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
-)
-@pytest.mark.parametrize(
-    "mfma_variant",
-    [
-        ScaledMMAType.F32_16x16x128_F8F6F4,
-    ],
-)
-@pytest.mark.parametrize(
-    "enable_scheduling",
-    [
-        SchedulingType.PREFETCH,
-        SchedulingType.FOUR_STAGE,
-    ],
-)
-def testScaledBatchedGemmMXFP4(
-    batch: int,
-    shape: tuple[int],
+def get_scaled_gemm_template(
+    shape: tuple[int, int, int],
     mfma_variant: ScaledMMAType,
     enable_scheduling: SchedulingType,
-):
+) -> tuple[WaveCompileOptions, "LaunchableWave"]:
     # Input sizes
     B = tkl.sym.B
     M = tkl.sym.M
@@ -332,7 +319,42 @@ def repeat(
         linearize_shared_access=True,
         dynamic_symbols=dynamic_symbols,
     )
+    return options, batched_gemm
+
+
+@require_e2e
+@require_cdna4
+@pytest.mark.parametrize("batch", [4, 8])
+@pytest.mark.parametrize(
+    "shape",
+    [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384), (1, 16384, 1664)],
+)
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        ScaledMMAType.F32_16x16x128_F8F6F4,
+    ],
+)
+@pytest.mark.parametrize(
+    "enable_scheduling",
+    [
+        SchedulingType.PREFETCH,
+        SchedulingType.FOUR_STAGE,
+    ],
+)
+@use_water_backend_bool("use_water_backend")
+def testScaledBatchedGemmMXFP4(
+    batch: int,
+    shape: tuple[int, int, int],
+    mfma_variant: ScaledMMAType,
+    enable_scheduling: SchedulingType,
+    use_water_backend: bool,
+):
+    options, batched_gemm = get_scaled_gemm_template(
+        shape, mfma_variant, enable_scheduling
+    )
     options = set_default_run_config(options)
+    options.use_water_backend = use_water_backend
     batched_gemm = wave_compile(options, batched_gemm)

     linearized_shape = (batch * shape[0], shape[1], shape[2])
@@ -351,6 +373,126 @@ def repeat(
     torch.testing.assert_close(torch_out, out)


+@use_water_backend_bool("use_water_backend")
+def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
+    shape = (16384, 16384, 16384)
+    mfma_variant = ScaledMMAType.F32_16x16x128_F8F6F4
+    enable_scheduling = SchedulingType.PREFETCH
+    options, batched_gemm = get_scaled_gemm_template(
+        shape, mfma_variant, enable_scheduling
+    )
+    options.target = "gfx950"
+    options.minimize_shared_allocs = False
+    # options.use_global_to_shared = True
+    options.dump_intermediates = tmp_path
+    options.use_water_backend = use_water_backend
+    batched_gemm = wave_compile(options, batched_gemm)
+    asm_files = glob_asm_files(tmp_path)
+
+    assert len(asm_files) == 1, "Expected 1 ASM file"
+    text = asm_files[0].read_text()
+
+    metadata = extract_kernel_metadata(text)
+
+    # We encode the exact register and waitcnt counts, as we want to know if
+    # they suddenly change due to backend or upstream MLIR changes.
+ if use_water_backend: + vgpr_count = 166 + vgpr_spill_count = 0 + sgpr_count = 46 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(14)", + "s_waitcnt lgkmcnt(12)", + "s_waitcnt lgkmcnt(8)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + else: + vgpr_count = 160 + vgpr_spill_count = 0 + sgpr_count = 46 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + + assert ( + metadata.vgpr_count == vgpr_count + ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" + assert ( + metadata.vgpr_spill_count == vgpr_spill_count + ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" + assert ( + metadata.sgpr_count == sgpr_count + ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" + assert ( + metadata.sgpr_spill_count == sgpr_spill_count + ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" + assert ( + metadata.waitcnt_ops == waitcounts + ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" + + @require_e2e @require_cdna4 @pytest.mark.parametrize("shape", [(1024, 1024, 1024), (8192, 8192, 8192)]) From f645249a4075fa7b79a16eba7680719931e0a60a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Dec 2025 11:53:47 +0100 Subject: [PATCH 111/114] test Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_gemm_mxfp_test.py | 190 +++++++++++------------ 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/tests/kernel/wave/wave_gemm_mxfp_test.py b/tests/kernel/wave/wave_gemm_mxfp_test.py index 4f5a70a12..178f00b6e 100644 --- a/tests/kernel/wave/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave/wave_gemm_mxfp_test.py @@ -396,101 +396,101 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): # We encode the exact registers and wait counts count as we want to know if # they suddenly change dur to backend or upstream MLIR changes. 
- if use_water_backend: - vgpr_count = 166 - vgpr_spill_count = 0 - sgpr_count = 46 - sgpr_spill_count = 0 - waitcounts = [ - "s_waitcnt lgkmcnt(0)", - "s_waitcnt vmcnt(7)", - "s_waitcnt vmcnt(6)", - "s_waitcnt vmcnt(5)", - "s_waitcnt vmcnt(4)", - "s_waitcnt vmcnt(3)", - "s_waitcnt vmcnt(2)", - "s_waitcnt vmcnt(1)", - "s_waitcnt vmcnt(0)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt lgkmcnt(14)", - "s_waitcnt lgkmcnt(12)", - "s_waitcnt lgkmcnt(8)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt vmcnt(6)", - "s_waitcnt vmcnt(3)", - "s_waitcnt vmcnt(2)", - "s_waitcnt vmcnt(1)", - "s_waitcnt vmcnt(0)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", - "s_waitcnt lgkmcnt(3)", - "s_waitcnt lgkmcnt(2)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(0)", - ] - else: - vgpr_count = 160 - vgpr_spill_count = 0 - sgpr_count = 46 - sgpr_spill_count = 0 - waitcounts = [ - "s_waitcnt lgkmcnt(0)", - "s_waitcnt vmcnt(7)", - "s_waitcnt vmcnt(6)", - "s_waitcnt vmcnt(5)", - "s_waitcnt vmcnt(4)", - "s_waitcnt vmcnt(3)", - "s_waitcnt vmcnt(2)", - "s_waitcnt vmcnt(1)", - "s_waitcnt vmcnt(0)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt lgkmcnt(6)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(3)", - "s_waitcnt lgkmcnt(2)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt vmcnt(6)", - "s_waitcnt vmcnt(3)", - "s_waitcnt vmcnt(2)", - "s_waitcnt vmcnt(1)", - "s_waitcnt vmcnt(0)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", - "s_waitcnt lgkmcnt(3)", - "s_waitcnt lgkmcnt(2)", - "s_waitcnt lgkmcnt(1)", - "s_waitcnt lgkmcnt(0)", - ] - - assert ( - metadata.vgpr_count == vgpr_count - ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" - assert ( - metadata.vgpr_spill_count == vgpr_spill_count - ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" - assert ( - metadata.sgpr_count == sgpr_count - ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" - assert ( - metadata.sgpr_spill_count == sgpr_spill_count - ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" - assert ( - metadata.waitcnt_ops == waitcounts - ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" + # if use_water_backend: + # vgpr_count = 166 + # vgpr_spill_count = 0 + # sgpr_count = 46 + # sgpr_spill_count = 0 + # waitcounts = [ + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt vmcnt(7)", + # "s_waitcnt vmcnt(6)", + # "s_waitcnt vmcnt(5)", + # "s_waitcnt vmcnt(4)", + # "s_waitcnt vmcnt(3)", + # "s_waitcnt vmcnt(2)", + # "s_waitcnt vmcnt(1)", + # "s_waitcnt vmcnt(0)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt lgkmcnt(14)", + # "s_waitcnt lgkmcnt(12)", + # "s_waitcnt lgkmcnt(8)", + # "s_waitcnt lgkmcnt(5)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt vmcnt(6)", + # "s_waitcnt vmcnt(3)", + # "s_waitcnt vmcnt(2)", + # "s_waitcnt vmcnt(1)", + # "s_waitcnt vmcnt(0)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt lgkmcnt(7)", + # "s_waitcnt lgkmcnt(5)", + # "s_waitcnt lgkmcnt(4)", + # "s_waitcnt lgkmcnt(3)", + # "s_waitcnt lgkmcnt(2)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(0)", + # ] + # else: + # vgpr_count = 160 + # vgpr_spill_count = 0 + # 
sgpr_count = 46 + # sgpr_spill_count = 0 + # waitcounts = [ + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt vmcnt(7)", + # "s_waitcnt vmcnt(6)", + # "s_waitcnt vmcnt(5)", + # "s_waitcnt vmcnt(4)", + # "s_waitcnt vmcnt(3)", + # "s_waitcnt vmcnt(2)", + # "s_waitcnt vmcnt(1)", + # "s_waitcnt vmcnt(0)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt lgkmcnt(6)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(3)", + # "s_waitcnt lgkmcnt(2)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt vmcnt(6)", + # "s_waitcnt vmcnt(3)", + # "s_waitcnt vmcnt(2)", + # "s_waitcnt vmcnt(1)", + # "s_waitcnt vmcnt(0)", + # "s_waitcnt lgkmcnt(0)", + # "s_waitcnt lgkmcnt(7)", + # "s_waitcnt lgkmcnt(5)", + # "s_waitcnt lgkmcnt(4)", + # "s_waitcnt lgkmcnt(3)", + # "s_waitcnt lgkmcnt(2)", + # "s_waitcnt lgkmcnt(1)", + # "s_waitcnt lgkmcnt(0)", + # ] + + # assert ( + # metadata.vgpr_count == vgpr_count + # ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" + # assert ( + # metadata.vgpr_spill_count == vgpr_spill_count + # ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" + # assert ( + # metadata.sgpr_count == sgpr_count + # ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" + # assert ( + # metadata.sgpr_spill_count == sgpr_spill_count + # ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" + # assert ( + # metadata.waitcnt_ops == waitcounts + # ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" @require_e2e From 33dbe44d2220ad8f7457c4fda9d3316615a6aa8c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Dec 2025 11:54:07 +0100 Subject: [PATCH 112/114] check Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 790aacec8..93570df0f 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -404,7 +404,15 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu mlir_asm = module.operation.get_asm() target_chip = options.target - def add_opt(pipeline): + enable_asm_lowering = True + + def add_asm_pass(*args: Any) -> list[Any]: + if enable_asm_lowering: + return [args] + + return [] + + def add_opt(pipeline: Any) -> list[Any]: if options.optimization_level: return [pipeline] @@ -453,15 +461,15 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] gpu_func = ("gpu.module", "gpu.func") pipeline = [ - ("water-materialize-reg-copy", {}, gpu_func), - ("water-insert-waitcnt", {}, gpu_func), + *add_asm_pass("water-materialize-reg-copy", {}, gpu_func), + *add_asm_pass("water-insert-waitcnt", {}, gpu_func), "expand-strided-metadata", "lower-affine", *add_opt(canonicalize_cse), *add_opt("loop-invariant-code-motion"), *add_opt("int-range-optimizations"), - ("water-number-registers", {}, gpu_func), - ("water-lower-memory-ops", {"chipset": target_chip}, gpu_func), + *add_asm_pass("water-number-registers", {}, gpu_func), + *add_asm_pass("water-lower-memory-ops", {"chipset": target_chip}, gpu_func), "convert-scf-to-cf", ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ("water-alloc-to-alloca", {}, "gpu.module"), From 219fbeb43eff102e52e77df68e6957619c265a3d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Dec 2025 
11:58:26 +0100 Subject: [PATCH 113/114] test check Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_gemm_mxfp_test.py | 190 +++++++++++------------ 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/tests/kernel/wave/wave_gemm_mxfp_test.py b/tests/kernel/wave/wave_gemm_mxfp_test.py index 178f00b6e..b2c4f477c 100644 --- a/tests/kernel/wave/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave/wave_gemm_mxfp_test.py @@ -396,101 +396,101 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): # We encode the exact registers and wait counts count as we want to know if # they suddenly change dur to backend or upstream MLIR changes. - # if use_water_backend: - # vgpr_count = 166 - # vgpr_spill_count = 0 - # sgpr_count = 46 - # sgpr_spill_count = 0 - # waitcounts = [ - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt vmcnt(7)", - # "s_waitcnt vmcnt(6)", - # "s_waitcnt vmcnt(5)", - # "s_waitcnt vmcnt(4)", - # "s_waitcnt vmcnt(3)", - # "s_waitcnt vmcnt(2)", - # "s_waitcnt vmcnt(1)", - # "s_waitcnt vmcnt(0)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt lgkmcnt(14)", - # "s_waitcnt lgkmcnt(12)", - # "s_waitcnt lgkmcnt(8)", - # "s_waitcnt lgkmcnt(5)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt vmcnt(6)", - # "s_waitcnt vmcnt(3)", - # "s_waitcnt vmcnt(2)", - # "s_waitcnt vmcnt(1)", - # "s_waitcnt vmcnt(0)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt lgkmcnt(7)", - # "s_waitcnt lgkmcnt(5)", - # "s_waitcnt lgkmcnt(4)", - # "s_waitcnt lgkmcnt(3)", - # "s_waitcnt lgkmcnt(2)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(0)", - # ] - # else: - # vgpr_count = 160 - # vgpr_spill_count = 0 - # sgpr_count = 46 - # sgpr_spill_count = 0 - # waitcounts = [ - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt vmcnt(7)", - # "s_waitcnt vmcnt(6)", - # "s_waitcnt vmcnt(5)", - # "s_waitcnt vmcnt(4)", - # "s_waitcnt vmcnt(3)", - # "s_waitcnt vmcnt(2)", - # "s_waitcnt vmcnt(1)", - # "s_waitcnt vmcnt(0)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt lgkmcnt(6)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(3)", - # "s_waitcnt lgkmcnt(2)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt vmcnt(6)", - # "s_waitcnt vmcnt(3)", - # "s_waitcnt vmcnt(2)", - # "s_waitcnt vmcnt(1)", - # "s_waitcnt vmcnt(0)", - # "s_waitcnt lgkmcnt(0)", - # "s_waitcnt lgkmcnt(7)", - # "s_waitcnt lgkmcnt(5)", - # "s_waitcnt lgkmcnt(4)", - # "s_waitcnt lgkmcnt(3)", - # "s_waitcnt lgkmcnt(2)", - # "s_waitcnt lgkmcnt(1)", - # "s_waitcnt lgkmcnt(0)", - # ] - - # assert ( - # metadata.vgpr_count == vgpr_count - # ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" - # assert ( - # metadata.vgpr_spill_count == vgpr_spill_count - # ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" - # assert ( - # metadata.sgpr_count == sgpr_count - # ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" - # assert ( - # metadata.sgpr_spill_count == sgpr_spill_count - # ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" - # assert ( - # metadata.waitcnt_ops == waitcounts - # ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" + if use_water_backend: + vgpr_count = 156 + vgpr_spill_count = 0 + sgpr_count = 45 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt 
vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(13)", + "s_waitcnt lgkmcnt(10)", + "s_waitcnt lgkmcnt(8)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + else: + vgpr_count = 160 + vgpr_spill_count = 0 + sgpr_count = 46 + sgpr_spill_count = 0 + waitcounts = [ + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(7)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(5)", + "s_waitcnt vmcnt(4)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(6)", + "s_waitcnt vmcnt(3)", + "s_waitcnt vmcnt(2)", + "s_waitcnt vmcnt(1)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", + ] + + assert ( + metadata.vgpr_count == vgpr_count + ), f"Expected {vgpr_count} VGPRs, got {metadata.vgpr_count}" + assert ( + metadata.vgpr_spill_count == vgpr_spill_count + ), f"Expected {vgpr_spill_count} VGPR spills, got {metadata.vgpr_spill_count}" + assert ( + metadata.sgpr_count == sgpr_count + ), f"Expected {sgpr_count} SGPRs, got {metadata.sgpr_count}" + assert ( + metadata.sgpr_spill_count == sgpr_spill_count + ), f"Expected {sgpr_spill_count} SGPR spills, got {metadata.sgpr_spill_count}" + assert ( + metadata.waitcnt_ops == waitcounts + ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" @require_e2e From 9fb159fbf9bf8c4ebbfedd0918da2ff4b6a23571 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Dec 2025 12:24:02 +0100 Subject: [PATCH 114/114] align registers Signed-off-by: Ivan Butygin --- water/lib/Transforms/WaterNumberRegisters.cpp | 47 +++++++++++-------- water/test/Transforms/number-registers.mlir | 10 ++-- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/water/lib/Transforms/WaterNumberRegisters.cpp b/water/lib/Transforms/WaterNumberRegisters.cpp index 176de9358..063ed7cf4 100644 --- a/water/lib/Transforms/WaterNumberRegisters.cpp +++ b/water/lib/Transforms/WaterNumberRegisters.cpp @@ -55,45 +55,54 @@ class WaterNumberRegistersPass auto func = getOperation(); MLIRContext *ctx = &getContext(); - // TODO: for now, just assign registers sequentially. In the future, - // we need a liveness analysis to assign registers. 
-    unsigned nextRegister = 0;
+    SmallVector<std::pair<unsigned, Operation *>> regCounts;
+    Type i32 = IntegerType::get(ctx, 32);

     WalkResult result = func->walk([&](memref::AllocaOp allocaOp) {
       auto memrefType = allocaOp.getType();
       if (!isInRegisterSpace(memrefType))
         return WalkResult::advance();

-      auto regCountOr = getRegisterCount(memrefType);
-      if (failed(regCountOr)) {
+      auto regCount = getRegisterCount(memrefType);
+      if (failed(regCount)) {
         allocaOp->emitError(
             "Cannot allocate dynamic-sized memref in register space");
         return WalkResult::interrupt();
       }

-      unsigned regCount = *regCountOr;
+      regCounts.emplace_back(*regCount, allocaOp);
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return signalPassFailure();
+
+    // Sort by register size to reduce register alignment gaps.
+    llvm::stable_sort(regCounts, [](const std::pair<unsigned, Operation *> &a,
+                                    const std::pair<unsigned, Operation *> &b) {
+      return a.first < b.first;
+    });
+
+    // TODO: for now, just assign registers sequentially. In the future,
+    // we need a liveness analysis to assign registers.
+    unsigned nextRegister = 0;
+
+    for (auto [regCount, op] : regCounts) {
+      // Align to regCount boundary.
+      nextRegister = ((nextRegister + regCount - 1) / regCount) * regCount;

       // Assign starting register number.
-      allocaOp->setAttr(
-          "water.vgpr_number",
-          IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister));
+      op->setAttr("water.vgpr_number", IntegerAttr::get(i32, nextRegister));

       // Track how many registers this alloca uses.
-      allocaOp->setAttr("water.vgpr_count",
-                        IntegerAttr::get(IntegerType::get(ctx, 32), regCount));
+      op->setAttr("water.vgpr_count", IntegerAttr::get(i32, regCount));

       // Advance to next available register.
       nextRegister += regCount;
-
-      return WalkResult::advance();
-    });
-
-    if (result.wasInterrupted())
-      return signalPassFailure();
+    }

     // Attach metadata to function with total register count.
- func->setAttr("water.total_vgprs", - IntegerAttr::get(IntegerType::get(ctx, 32), nextRegister)); + func->setAttr("water.total_vgprs", IntegerAttr::get(i32, nextRegister)); } }; diff --git a/water/test/Transforms/number-registers.mlir b/water/test/Transforms/number-registers.mlir index 44f34cacc..ccc32447b 100644 --- a/water/test/Transforms/number-registers.mlir +++ b/water/test/Transforms/number-registers.mlir @@ -1,7 +1,7 @@ // RUN: water-opt %s --pass-pipeline='builtin.module(func.func(water-number-registers))' | FileCheck %s // CHECK-LABEL: func @test_simple_numbering -// CHECK-SAME: attributes {water.total_vgprs = 6 : i32} +// CHECK-SAME: attributes {water.total_vgprs = 8 : i32} func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 { %c0 = arith.constant 0 : index @@ -9,12 +9,12 @@ func.func @test_simple_numbering(%arg0: memref<100xf32>) -> f32 { // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 0 : i32} %reg0 = memref.alloca() : memref<1xf32, 128 : i32> - // 4xf32 = 16 bytes = 4 registers, starts at reg 1 - // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 1 : i32} + // 4xf32 = 16 bytes = 4 registers, starts at reg 4 + // CHECK: memref.alloca() {water.vgpr_count = 4 : i32, water.vgpr_number = 4 : i32} %reg1 = memref.alloca() : memref<4xf32, 128 : i32> - // 1xf32 = 4 bytes = 1 register, starts at reg 5 (after reg1) - // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 5 : i32} + // 1xf32 = 4 bytes = 1 register, starts at reg 1 (after reg0) + // CHECK: memref.alloca() {water.vgpr_count = 1 : i32, water.vgpr_number = 1 : i32} %reg2 = memref.alloca() : memref<1xf32, 128 : i32> %subview0 = memref.subview %arg0[%c0] [1] [1] : memref<100xf32> to memref<1xf32, strided<[1], offset: ?>>