diff --git a/tests/kernel/wave/infer_index_exprs.py b/tests/kernel/wave/infer_index_exprs.py new file mode 100644 index 000000000..472f9ac92 --- /dev/null +++ b/tests/kernel/wave/infer_index_exprs.py @@ -0,0 +1,118 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.wave.wave import LaunchableWave +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile + +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.utils.general_utils import torch_dtype_to_wave + +import torch + + +# TODO: use the generic template, currently blocked by water not handling wave constraints. +def _get_gemm_kernel( + shape: tuple[int, int, int], + mfma_variant: MMAType, + dtype: torch.dtype = torch.float16, + block_shape: tuple[int, int, int] | None = None, + waves_per_block: tuple[int, int] | None = None, +) -> tuple[LaunchableWave, dict[tkl.IndexSymbol, tkl.IndexExpr]]: + if not block_shape: + # BLOCK_M, BLOCK_N, BLOCK_K + block_shape = (64, 64, 32) + + if not waves_per_block: + # WAVE_M, WAVE_N + waves_per_block = (2, 2) + + assert len(block_shape) == 3, "block_shape needs to be rank 3 for M, N, K." + assert len(waves_per_block) == 2, "waves_per_block needs to be rank 2 for M, N." 
+ + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE + dtype = torch_dtype_to_wave(dtype) + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + + # TODO: dialect expects waves_per_block to be rank 3, so we append a 1 to the end. + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=mfma_variant, + waves_per_block=waves_per_block + (1,), + ) + ] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype], + b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + # This microkernel encodes the fact that if the iterate + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + # a_reg: tkw.Register[M, K, dtype] + a_reg = tkw.read(a) + # b_reg: tkw.Register[N, K, dtype] + b_reg = tkw.read(b) + # acc: tkw.Register[M, N, tkl.f32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + BLOCK_M: block_shape[0], + BLOCK_N: block_shape[1], + BLOCK_K: block_shape[2], + M: shape[0], + N: shape[1], + K: shape[2], + } + return gemm, hyperparams + + +def testGemm(): + gemm, hyperparams = _get_gemm_kernel( + shape=(1024, 1024, 1024), mfma_variant=MMAType.F32_16x16x16_F16 + ) + options = WaveCompileOptions( + subs=hyperparams, + run_bench=False, + check_water_analysis=True, + ) + compiled_gemm = wave_compile(options, gemm) + assert compiled_gemm is not None + + +if __name__ == "__main__": + testGemm() diff --git a/water/include/water/Dialect/Wave/IR/IndexExpr.h b/water/include/water/Dialect/Wave/IR/IndexExpr.h index ccf3acf99..5137e9614 100644 --- a/water/include/water/Dialect/Wave/IR/IndexExpr.h +++ b/water/include/water/Dialect/Wave/IR/IndexExpr.h @@ -8,6 +8,7 @@ #define WATER_DIALECT_WAVE_IR_INDEXEXPR_H #include "mlir/IR/Attributes.h" +#include "water/Dialect/Wave/IR/WaveAttrs.h" #include "llvm/ADT/SetVector.h" namespace mlir { @@ -41,6 +42,26 @@ aggregateAllSymbols(RangeT &&symbolLists, mlir::AffineMap alignMapSymbols(mlir::AffineMap map, llvm::ArrayRef symbols, llvm::ArrayRef allSymbols); + +// Create an index mapping induced by the given workgroup constraint. Combine +// it with the base index mapping if provided. +WaveIndexMappingAttr +applyConstraint(WorkgroupConstraintAttr constraint, + WaveIndexMappingAttr baseMapping = nullptr); + +// Create an index mapping induced by the given tiling constraint. Combine it +// with the base index mapping if provided. 
+WaveIndexMappingAttr +applyConstraint(TilingConstraintAttr constraint, + WaveIndexMappingAttr baseMapping = nullptr); + +// Create an index mapping induced by the given constraint. Combine it with the +// base index mapping if provided. Call `applyConstraint` if the specific kind +// of constraint is already known. +WaveIndexMappingAttr +applyConstraintGeneric(mlir::Attribute constraint, + WaveIndexMappingAttr baseMapping = nullptr); + } // namespace wave #endif // WATER_DIALECT_WAVE_IR_INDEXEXPR_H diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index f30660b44..b133a94a2 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -104,6 +104,7 @@ def NORMAL_FORM_FULL_TYPES : I32BitEnumAttrCaseGroup< NORMAL_FORM_FUNC_BOUNDARY, NORMAL_FORM_OP_TYPES ], "full_types">; +// When anything changes below, also update the C and Python API. def WaveNormalFormEnum : I32BitEnumAttr<"WaveNormalForm", "", [ I32BitEnumAttrCaseNone<"None", "none">, // Bits. @@ -366,6 +367,19 @@ def WaveSymbolAttr : AttrDef { }]; } +def WaveIterSymbolAttr : AttrDef { + let mnemonic = "iter"; + let description = [{ + Symbol referring to an induction variable that can be used in symbolic + expressions in the Wave dialect. Induction variables may have the same + name as other symbols without clashing with them. 
+ }]; + + let parameters = + (ins StringRefParameter<"name of the induction variable">:$name); + let assemblyFormat = "`<` $name `>`"; +} + def WaveIndexMappingAttr : AttrDef { let mnemonic = "index_mapping"; let description = [{ diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 9af9631f1..736c21335 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -139,7 +139,8 @@ def MmaOp : WaveOp<"mma", def IterateOp : Op]> { + ["areTypesCompatible", "getEntrySuccessorOperands"]>, + DeclareOpInterfaceMethods]> { let summary = "Executes the body repeatedly"; let description = [{ Intrinsically sequential iterative execution that is akin to a loop with @@ -215,7 +216,9 @@ def YieldOp : Op { +def AllocateOp : WaveOp<"allocate", [ + DeclareOpInterfaceMethods + ]> { let summary = "Represents an allocation in an address space"; let description = [{ Allocates memory for a Wave tensor in the address space indicated by the diff --git a/water/include/water/c/Dialects.h b/water/include/water/c/Dialects.h index 02f902a71..53de8ff85 100644 --- a/water/include/water/c/Dialects.h +++ b/water/include/water/c/Dialects.h @@ -7,6 +7,7 @@ #ifndef WATER_C_DIALECTS_H #define WATER_C_DIALECTS_H +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #ifdef __cplusplus @@ -15,6 +16,9 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Wave, wave); +/// Register the Wave dialect passes. +MLIR_CAPI_EXPORTED void mlirWaveDialectRegisterPasses(); + //===---------------------------------------------------------------------===// // Wave Dialect Constants //===---------------------------------------------------------------------===// @@ -36,6 +40,27 @@ mlirWaveSymbolAttrGet(MlirContext mlirCtx, MlirStringRef symbolName); /// Returns the typeID of a WaveSymbolAttr. MLIR_CAPI_EXPORTED MlirTypeID mlirWaveSymbolAttrGetTypeID(); +/// Gets the name of a WaveSymbolAttr. 
+MLIR_CAPI_EXPORTED MlirStringRef mlirWaveSymbolAttrGetName(MlirAttribute attr); + +//===---------------------------------------------------------------------===// +// WaveIterSymbolAttr +//===---------------------------------------------------------------------===// + +/// Checks whether the given MLIR attribute is a WaveIterSymbolAttr. +MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveIterSymbolAttr(MlirAttribute attr); + +/// Creates a new WaveIterSymbolAttr with the given induction variable name. +MLIR_CAPI_EXPORTED MlirAttribute +mlirWaveIterSymbolAttrGet(MlirContext mlirCtx, MlirStringRef symbolName); + +/// Returns the typeID of a WaveIterSymbolAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIterSymbolAttrGetTypeID(); + +/// Gets the induction variable name. +MLIR_CAPI_EXPORTED MlirStringRef +mlirWaveIterSymbolAttrGetName(MlirAttribute attr); + //===---------------------------------------------------------------------===// // WaveIndexSymbolAttr //===---------------------------------------------------------------------===// @@ -86,6 +111,26 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGet( /// Returns the typeID of a WaveIndexMappingAttr. MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexMappingAttrGetTypeID(); +/// Get the offset from a WaveIndexMappingAttr. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirWaveIndexMappingAttrGetStart(MlirAttribute attr); + +/// Get the step from a WaveIndexMappingAttr. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirWaveIndexMappingAttrGetStep(MlirAttribute attr); + +/// Get the stride from a WaveIndexMappingAttr. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirWaveIndexMappingAttrGetStride(MlirAttribute attr); + +/// Get the number of (input) symbols. +MLIR_CAPI_EXPORTED intptr_t +mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr); + +/// Get the (input) symbol at the given index. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr, intptr_t index); + //===---------------------------------------------------------------------===// // WaveHyperparameterAttr //===---------------------------------------------------------------------===// diff --git a/water/lib/CAPI/CMakeLists.txt b/water/lib/CAPI/CMakeLists.txt index ce823e785..25a1cd762 100644 --- a/water/lib/CAPI/CMakeLists.txt +++ b/water/lib/CAPI/CMakeLists.txt @@ -4,4 +4,5 @@ add_mlir_public_c_api_library(WaterCAPI LINK_LIBS PUBLIC MLIRCAPIIR MLIRWaveDialect + MLIRWaveTransforms ) diff --git a/water/lib/CAPI/Dialects.cpp b/water/lib/CAPI/Dialects.cpp index 3dab55e08..458db6fef 100644 --- a/water/lib/CAPI/Dialects.cpp +++ b/water/lib/CAPI/Dialects.cpp @@ -12,10 +12,17 @@ #include "water/Dialect/Wave/IR/WaveAttrs.h" #include "water/Dialect/Wave/IR/WaveDialect.h" +#include "water/Dialect/Wave/Transforms/Passes.h" #include "water/c/Dialects.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Wave, wave, ::wave::WaveDialect) +//===---------------------------------------------------------------------===// +// Wave Dialect Passes +//===---------------------------------------------------------------------===// + +void mlirWaveDialectRegisterPasses() { wave::registerPasses(); } + //===---------------------------------------------------------------------===// // Wave Dialect Constants //===---------------------------------------------------------------------===// @@ -42,6 +49,33 @@ MlirTypeID mlirWaveSymbolAttrGetTypeID() { return wrap(mlir::TypeID::get()); } +MlirStringRef mlirWaveSymbolAttrGetName(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getName()); +} + +//===---------------------------------------------------------------------===// +// WaveIterSymbolAttr +//===---------------------------------------------------------------------===// + +bool mlirAttributeIsAWaveIterSymbolAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute 
mlirWaveIterSymbolAttrGet(MlirContext mlirCtx, + MlirStringRef symbolNameStrRef) { + mlir::MLIRContext *ctx = unwrap(mlirCtx); + llvm::StringRef symbolName = unwrap(symbolNameStrRef); + return wrap(wave::WaveIterSymbolAttr::get(ctx, symbolName)); +} + +MlirTypeID mlirWaveIterSymbolAttrGetTypeID() { + return wrap(mlir::TypeID::get()); +} + +MlirStringRef mlirWaveIterSymbolAttrGetName(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getName()); +} + //===---------------------------------------------------------------------===// // WaveIndexSymbolAttr //===---------------------------------------------------------------------===// @@ -91,7 +125,8 @@ MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx, assert(llvm::all_of( symbolAttrs, - llvm::IsaPred) && + llvm::IsaPred) && "expected mapping to contain only WaveSymbolAttr or " "WaveIndexSymbolAttr attributes"); @@ -103,6 +138,30 @@ MlirTypeID mlirWaveIndexMappingAttrGetTypeID() { return wrap(mlir::TypeID::get()); } +MlirAffineMap mlirWaveIndexMappingAttrGetStart(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getStart()); +} + +MlirAffineMap mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getStep()); +} + +MlirAffineMap mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getStride()); +} + +intptr_t mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getSymbols() + .size(); +} + +MlirAttribute mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr, + intptr_t index) { + return wrap( + llvm::cast(unwrap(attr)).getSymbols()[index]); +} + //===---------------------------------------------------------------------===// // WaveHyperparameterAttr //===---------------------------------------------------------------------===// @@ -215,9 +274,10 @@ MlirAttribute mlirWaveExprListAttrGet(MlirAttribute *symbolNames, assert(llvm::all_of( symbolAttrs, - llvm::IsaPred) && 
- "expected mapping to contain only WaveSymbolAttr or " - "WaveIndexSymbolAttr attributes"); + llvm::IsaPred) && + "expected mapping to contain only WaveSymbolAttr, " + "WaveIndexSymbolAttr or WaveIterSymbolAttr attributes"); return wrap(wave::WaveExprListAttr::get(ctx, symbolAttrs, unwrap(map))); } diff --git a/water/lib/Dialect/Wave/IR/IndexExpr.cpp b/water/lib/Dialect/Wave/IR/IndexExpr.cpp index ebbe42c71..39d56fd1d 100644 --- a/water/lib/Dialect/Wave/IR/IndexExpr.cpp +++ b/water/lib/Dialect/Wave/IR/IndexExpr.cpp @@ -7,11 +7,14 @@ #include "water/Dialect/Wave/IR/IndexExpr.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "water/Dialect/Wave/IR/WaveAttrs.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; -AffineMap wave::alignMapSymbols(AffineMap map, ArrayRef symbols, - ArrayRef allSymbols) { +namespace wave { +AffineMap alignMapSymbols(AffineMap map, ArrayRef symbols, + ArrayRef allSymbols) { assert(map.getNumDims() == 0 && "maps should not involve dimensions"); MLIRContext *ctx = map.getContext(); @@ -29,5 +32,112 @@ AffineMap wave::alignMapSymbols(AffineMap map, ArrayRef symbols, for (AffineExpr expr : map.getResults()) remapped.push_back(expr.replaceSymbols(newSymbols)); - return mlir::AffineMap::get(/*dimCount=*/0, newNumSyms, remapped, ctx); + return AffineMap::get(/*dimCount=*/0, newNumSyms, remapped, ctx); } +} // namespace wave + +// Get the (indexed) affine symbol expression corresponding to the given symbol, +// insert it into the list if it isn't already present. 
+static AffineExpr +getOrInsertSymbolExpr(Attribute symbol, + llvm::SmallVectorImpl &symbolNames) { + auto it = llvm::find(symbolNames, symbol); + unsigned position = [&] { + if (it != symbolNames.end()) + return static_cast(std::distance(symbolNames.begin(), it)); + symbolNames.push_back(symbol); + return static_cast(symbolNames.size() - 1); + }(); + return getAffineSymbolExpr(position, symbol.getContext()); +} + +// Helper function to create an index mapping from a symbol expression and +// constraint tile size. Combines with base mapping if provided. +static wave::WaveIndexMappingAttr createIndexMappingFromSymbolExpr( + AffineExpr symbolExpr, ArrayRef symbols, AffineMap tileSizeMap, + MLIRContext *context, wave::WaveIndexMappingAttr baseMapping) { + assert(tileSizeMap.getNumResults() == 1 && + "expected a single result expression in affine map"); + AffineMap map = AffineMap::get( + /*dimCount=*/0, symbols.size(), symbolExpr * tileSizeMap.getResult(0)); + if (baseMapping == nullptr) + return wave::WaveIndexMappingAttr::get(context, symbols, map, AffineMap(), + AffineMap()); + + SmallVector allSymbols; + wave::aggregateAllSymbols( + std::initializer_list>{baseMapping.getSymbols(), + symbols}, + allSymbols); + + AffineMap baseStart = wave::alignMapSymbols( + baseMapping.getStart(), baseMapping.getSymbols(), allSymbols); + AffineMap baseStep = wave::alignMapSymbols( + baseMapping.getStep(), baseMapping.getSymbols(), allSymbols); + AffineMap baseStride = wave::alignMapSymbols( + baseMapping.getStride(), baseMapping.getSymbols(), allSymbols); + map = wave::alignMapSymbols(map, symbols, allSymbols); + map = AffineMap::get(/*dimCount=*/0, allSymbols.size(), + baseStart.getResult(0) + map.getResult(0)); + return wave::WaveIndexMappingAttr::get(context, allSymbols, map, baseStep, + baseStride); +} + +namespace wave { + +WaveIndexMappingAttr applyConstraint(WorkgroupConstraintAttr constraint, + WaveIndexMappingAttr baseMapping) { + llvm::SmallVector symbols = + 
llvm::map_to_vector(constraint.getTileSize().getSymbols(), + [](mlir::Attribute symbol) { return symbol; }); + + mlir::MLIRContext *context = constraint.getContext(); + + // TODO: we should just use workgroup attributes in expressions directly. + WaveIndexSymbol symbol = [&] { + switch (constraint.getWorkgroupDim().getValue()) { + case WaveWorkgroupDim::X: + return WaveIndexSymbol::WORKGROUP_0; + case WaveWorkgroupDim::Y: + return WaveIndexSymbol::WORKGROUP_1; + case WaveWorkgroupDim::Z: + return WaveIndexSymbol::WORKGROUP_2; + } + llvm::report_fatal_error("unsupported workgroup dimension"); + }(); + mlir::AffineExpr symbolExpr = + getOrInsertSymbolExpr(WaveIndexSymbolAttr::get(context, symbol), symbols); + + return createIndexMappingFromSymbolExpr(symbolExpr, symbols, + constraint.getTileSize().getMap(), + context, baseMapping); +} + +WaveIndexMappingAttr applyConstraint(TilingConstraintAttr constraint, + WaveIndexMappingAttr baseMapping) { + llvm::SmallVector symbols = + llvm::map_to_vector(constraint.getTileSize().getSymbols(), + [](mlir::Attribute symbol) { return symbol; }); + + mlir::MLIRContext *context = constraint.getContext(); + mlir::AffineExpr symbolExpr = getOrInsertSymbolExpr( + WaveIterSymbolAttr::get(context, constraint.getDim().getName()), symbols); + + return createIndexMappingFromSymbolExpr(symbolExpr, symbols, + constraint.getTileSize().getMap(), + context, baseMapping); +} + +WaveIndexMappingAttr applyConstraintGeneric(mlir::Attribute constraint, + WaveIndexMappingAttr baseMapping) { + return llvm::TypeSwitch(constraint) + .Case( + [&](auto constraint) { + // This double dispatching is necessary in absence of interfaces to + // dispatch to a class method based on a specific type. 
+ return applyConstraint(constraint, baseMapping); + }) + .Default([&](mlir::Attribute constraint) { return nullptr; }); +} + +} // namespace wave diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 75ca44ded..b1f62c0a7 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -43,8 +43,16 @@ using namespace wave; /// Helper function to parse an affine wave expression with the wave /// symbol names passed in `names`. static ParseResult parseExprWithNames(ArrayRef names, - AffineExpr &outExpr, AsmParser &parser) { + AffineExpr &outExpr, AsmParser &parser, + bool allowNull = false) { MLIRContext *context = parser.getContext(); + if (allowNull && succeeded(parser.parseOptionalLess()) && + succeeded(parser.parseKeyword("NULL")) && + succeeded(parser.parseGreater())) { + outExpr = nullptr; + return success(); + } + SmallVector> symbolSet; symbolSet.reserve(names.size()); for (auto [i, nm] : llvm::enumerate(names)) @@ -60,6 +68,8 @@ static ParseResult parseExprWithNames(ArrayRef names, // It textually substitutes 's' occurrences with the corresponding names from // the provided `names` array. std::string stringifyWithNames(AffineMap map, ArrayRef names) { + if (!map) + return ""; AffineExpr expr = map.getResult(0); std::string exprStr; llvm::raw_string_ostream os(exprStr); @@ -86,6 +96,34 @@ std::string stringifyWithNames(AffineMap map, ArrayRef names) { return exprStr; } +// Populate `names` with the names of symbols to be used in the expression +// syntax. If new names are generated on-the-fly, store them in +// `owningSymbolNames`. It is the caller's responsibility to keep those names +// alive at least as long as it needs `names`, which only stores references. 
+LogicalResult +getExprSymbolNames(ArrayRef symbols, + SmallVectorImpl &names, + SmallVectorImpl> &owningSymbolNames, + function_ref emitError) { + names.reserve(names.size() + symbols.size()); + for (Attribute attr : symbols) { + if (auto sym = dyn_cast(attr)) { + names.push_back(sym.getName()); + } else if (auto sym = dyn_cast(attr)) { + names.push_back(stringifyWaveIndexSymbol(sym.getValue())); + } else if (auto sym = dyn_cast(attr)) { + llvm::raw_svector_ostream os(owningSymbolNames.emplace_back()); + os << "_Iter_" << sym.getName(); + names.push_back(os.str()); + } else { + emitError("expected symbol names to be one of WaveSymbolAttr, " + "WaveIndexSymbolAttr or WaveIterSymbolAttr"); + return failure(); + } + } + return success(); +} + Attribute WaveIndexMappingAttr::parse(AsmParser &parser, Type type) { // Parse custom syntax: '[' symbol-names ']' '->' '(' start, step, stride ')' // This preserves meaningful symbol names while leveraging the existing @@ -109,28 +147,26 @@ Attribute WaveIndexMappingAttr::parse(AsmParser &parser, Type type) { MLIRContext *context = parser.getContext(); SmallVector symbolNames; - symbolNames.reserve(symbolNameAttrs.size()); - for (Attribute attr : symbolNameAttrs) { - if (auto sym = dyn_cast(attr)) { - symbolNames.push_back(sym.getName()); - } else if (auto sym = dyn_cast(attr)) { - symbolNames.push_back(stringifyWaveIndexSymbol(sym.getValue())); - } else { - parser.emitError(parser.getCurrentLocation(), - "expected symbol names to be either a WaveSymbolAttr or " - "WaveIndexSymbolAttr"); - return {}; - } + SmallVector> owningSymbolNames; + if (failed(getExprSymbolNames(symbolNameAttrs, symbolNames, owningSymbolNames, + [&](StringRef message) { + parser.emitError(parser.getCurrentLocation(), + message); + }))) { + return {}; } AffineExpr startExpr; AffineExpr stepExpr; AffineExpr strideExpr; - if (failed(parseExprWithNames(symbolNames, startExpr, parser)) || + if (failed(parseExprWithNames(symbolNames, startExpr, parser, + 
/*allowNull=*/true)) || parser.parseComma() || - failed(parseExprWithNames(symbolNames, stepExpr, parser)) || + failed(parseExprWithNames(symbolNames, stepExpr, parser, + /*allowNull=*/true)) || parser.parseComma() || - failed(parseExprWithNames(symbolNames, strideExpr, parser)) || + failed(parseExprWithNames(symbolNames, strideExpr, parser, + /*allowNull=*/true)) || parser.parseRParen()) { parser.emitError( parser.getCurrentLocation(), @@ -138,13 +174,21 @@ Attribute WaveIndexMappingAttr::parse(AsmParser &parser, Type type) { return {}; } - // Build maps - auto startMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), startExpr, context); - auto stepMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), stepExpr, context); - auto strideMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), strideExpr, context); + auto startMap = startExpr + ? AffineMap::get( + /*numDims=*/0, /*numSymbols=*/symbolNames.size(), + startExpr, context) + : AffineMap(); + auto stepMap = stepExpr + ? AffineMap::get( + /*numDims=*/0, /*numSymbols=*/symbolNames.size(), + stepExpr, context) + : AffineMap(); + auto strideMap = strideExpr + ? 
AffineMap::get( + /*numDims=*/0, /*numSymbols=*/symbolNames.size(), + strideExpr, context) + : AffineMap(); return get(context, symbolNameAttrs, startMap, stepMap, strideMap); } @@ -162,14 +206,14 @@ void WaveIndexMappingAttr::print(AsmPrinter &printer) const { printer << "] -> "; SmallVector names; - names.reserve(symbols.size()); - for (Attribute symbolAttr : symbols) { - if (auto symbol = llvm::dyn_cast(symbolAttr)) - names.push_back(symbol.getName()); - else if (auto symbol = llvm::dyn_cast(symbolAttr)) - names.emplace_back(wave::stringifyWaveIndexSymbol(symbol.getValue())); - else - llvm_unreachable("Unexpected symbol attribute type"); + SmallVector> owningSymbolNames; + if (failed(getExprSymbolNames( + symbols, names, owningSymbolNames, [&](StringRef message) { + // TODO: double-check that printer doesn't crash on malformed + // attributes. + llvm_unreachable("Unexpected symbol attribute type"); + }))) { + return; } // All three maps share the same symbol set and order. std::string startStr = stringifyWithNames(getStart(), names); @@ -183,10 +227,11 @@ LogicalResult WaveIndexMappingAttr::verify(function_ref emitError, ArrayRef symbols, AffineMap start, AffineMap step, AffineMap stride) { - if (!llvm::all_of(symbols, - llvm::IsaPred)) - return emitError() << "expected all symbols to be a WaveSymbolAttr or " - "WaveIndexSymbolAttr"; + if (!llvm::all_of(symbols, llvm::IsaPred)) { + return emitError() << "expected all symbols to be a WaveSymbolAttr, " + "WaveIndexSymbolAttr or WaveIterSymbolAttr"; + } return success(); } @@ -226,6 +271,9 @@ WaveSymbolAttr::verify(function_ref emitError, diag.attachNote() << "Did you mean to use WaveIndexSymbolAttr instead?"; return diag; } + if (name.starts_with("_")) + return emitError() + << "symbols names starting with '_' are reserved for internal use"; return llvm::success(); } @@ -286,18 +334,13 @@ Attribute WaveExprListAttr::parse(AsmParser &parser, Type) { if (parser.parseArrow()) return {}; MLIRContext *context = 
parser.getContext(); - symbolNames.reserve(symbolNameAttrs.size()); - for (Attribute attr : symbolNameAttrs) { - if (auto sym = dyn_cast(attr)) { - symbolNames.push_back(sym.getName()); - } else if (auto sym = dyn_cast(attr)) { - symbolNames.push_back(stringifyWaveIndexSymbol(sym.getValue())); - } else { - parser.emitError(parser.getCurrentLocation(), - "expected symbol names to be either a WaveSymbolAttr or " - "WaveIndexSymbolAttr"); - return {}; - } + SmallVector> owningSymbolNames; + if (failed(getExprSymbolNames(symbolNameAttrs, symbolNames, owningSymbolNames, + [&](StringRef message) { + parser.emitError(parser.getCurrentLocation(), + message); + }))) { + return {}; } SmallVector results; @@ -337,14 +380,14 @@ void WaveExprListAttr::print(mlir::AsmPrinter &printer) const { }); printer << "] -> ("; - names.reserve(symbols.size()); - for (Attribute symbolAttr : symbols) { - if (auto symbol = llvm::dyn_cast(symbolAttr)) - names.push_back(symbol.getName()); - else if (auto symbol = llvm::dyn_cast(symbolAttr)) - names.emplace_back(wave::stringifyWaveIndexSymbol(symbol.getValue())); - else - llvm_unreachable("Unexpected symbol attribute type"); + SmallVector> owningSymbolNames; + if (failed(getExprSymbolNames( + symbols, names, owningSymbolNames, [&](StringRef message) { + // TODO: double-check that printer doesn't crash on malformed + // attributes. + llvm_unreachable("Unexpected symbol attribute type"); + }))) { + return; } // We have one map with N results. 
For each result expr, make a 1-result map @@ -365,10 +408,10 @@ void WaveExprListAttr::print(mlir::AsmPrinter &printer) const { LogicalResult WaveExprListAttr::verify(function_ref emitError, ArrayRef symbols, AffineMap map) { - if (!llvm::all_of(symbols, - llvm::IsaPred)) - return emitError() << "expected all symbols to be a WaveSymbolAttr or " - "WaveIndexSymbolAttr"; + if (!llvm::all_of(symbols, llvm::IsaPred)) + return emitError() << "expected all symbols to be a WaveSymbolAttr, " + "WaveIndexSymbolAttr or WaveIterSymbolAttr"; return success(); } diff --git a/water/lib/Dialect/Wave/IR/WaveDialect.cpp b/water/lib/Dialect/Wave/IR/WaveDialect.cpp index 49f479643..bea066d6c 100644 --- a/water/lib/Dialect/Wave/IR/WaveDialect.cpp +++ b/water/lib/Dialect/Wave/IR/WaveDialect.cpp @@ -119,12 +119,10 @@ static llvm::LogicalResult verifyAttributeHyperparamUses( } } } - // TODO: somehow get rid of hardcoded magic names (_ARG). mlir::WalkResult walkResult = namedAttr.getValue().walk([&](wave::WaveSymbolAttr symbolAttr) { usedSymbols.insert(symbolAttr.getName()); - if (hyperparam.getMapping().contains(symbolAttr.getName()) || - symbolAttr.getName().starts_with("_ARG")) + if (hyperparam.getMapping().contains(symbolAttr.getName())) return mlir::WalkResult::advance(); mlir::InFlightDiagnostic diag = emitError() diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index 8b9ce4ece..e855a4c50 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "water/Dialect/Wave/IR/WaveInterfaces.h" +#include "mlir/IR/AffineExpr.h" #include "water/Dialect/Wave/IR/IndexExpr.h" #include "water/Dialect/Wave/IR/WaveAttrs.h" #include "water/Dialect/Wave/IR/WaveDialect.h" @@ -13,7 +14,9 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include 
"llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringSet.h" using namespace mlir; @@ -71,7 +74,7 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { unsigned declared = mapping.getSymbols().size(); if (startMap.getNumSymbols() != declared || stepMap.getNumSymbols() != declared || - strideMap.getNumSymbols() != declared) + strideMap.getNumSymbols() != declared) { return op->emitError( "inconsistent symbol count between symbol_names and " "affine maps for index symbol '") @@ -79,6 +82,35 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { << ", got start=" << startMap.getNumSymbols() << ", step=" << stepMap.getNumSymbols() << ", stride=" << strideMap.getNumSymbols() << ")"; + } + + for (auto symbol : mapping.getSymbols()) { + auto iterSymbol = dyn_cast(symbol); + if (!iterSymbol) + continue; + + bool found = false; + for (Operation *currentOp = op->getParentOp(); currentOp != nullptr; + currentOp = currentOp->getParentOp()) { + // TODO: we don't want to depend on the IterateOp specifically from + // the interface (though technically we can), so we use the opaque + // attribute name. We should add something like a "wave control flow + // interface" that would provide it without hardcoded strings. + auto parentIterSymbol = + currentOp->getAttrOfType("iterator"); + if (!parentIterSymbol) + continue; + if (parentIterSymbol.getName() == iterSymbol.getName()) { + found = true; + break; + } + } + if (!found) { + return op->emitError("index expression uses iterator symbol ") + << iterSymbol.getName() + << " which is not defined by any parent op"; + } + } } } return success(); @@ -582,22 +614,220 @@ wave::IndexExprsAnalysisInit::create(mlir::Location loc, return initObject; } -// Populate `positions` with the positions of thread and GPR-related symbols in -// the given list. 
-static void -getThreadLikeSymbolPositions(llvm::ArrayRef symbols, - llvm::SmallVectorImpl &positions) { - for (auto &&[i, symbol] : llvm::enumerate(symbols)) { - auto indexSymbol = llvm::dyn_cast(symbol); - if (!indexSymbol) +// Create a new map with per-result sum between a and b maps. +static AffineMap addMaps(AffineMap a, AffineMap b) { + assert(a.getNumResults() == b.getNumResults() && + "expected both maps to have the same number of expressions"); + SmallVector subtracted = llvm::map_to_vector( + llvm::zip_equal(a.getResults(), b.getResults()), + [&](auto &&pair) { return std::get<0>(pair) + std::get<1>(pair); }); + return AffineMap::get(a.getNumDims(), a.getNumSymbols(), subtracted, + a.getContext()); +} + +// Create a new map with per-result difference between a and b maps. +static AffineMap subtractMaps(AffineMap a, AffineMap b) { + assert(a.getNumResults() == b.getNumResults() && + "expected both maps to have the same number of expressions"); + SmallVector subtracted = llvm::map_to_vector( + llvm::zip_equal(a.getResults(), b.getResults()), + [&](auto &&pair) { return std::get<0>(pair) - std::get<1>(pair); }); + return AffineMap::get(a.getNumDims(), a.getNumSymbols(), subtracted, + a.getContext()); +} + +// Return the list of thread-like index symbols. +// TODO: It would be nice to cache the list somehow, but we need to tie it to +// the context and ensure thread safety, potentially by storing it as an +// attribute or some other named/typed entity in the context object. +static SmallVector getThreadLikeIndexSymbols(MLIRContext *ctx) { + return llvm::map_to_vector( + ArrayRef{wave::WaveIndexSymbol::THREAD_0, wave::WaveIndexSymbol::THREAD_1, + wave::WaveIndexSymbol::THREAD_2, + wave::WaveIndexSymbol::GPR_NUMBER}, + [&](wave::WaveIndexSymbol symbol) -> Attribute { + return wave::WaveIndexSymbolAttr::get(ctx, symbol); + }); +} + +// Return the list of index symbols other than thread-like. 
+static SmallVector getNonThreadLikeIndexSymbols(MLIRContext *ctx) { + return llvm::map_to_vector(ArrayRef{wave::WaveIndexSymbol::WORKGROUP_0, + wave::WaveIndexSymbol::WORKGROUP_1, + wave::WaveIndexSymbol::WORKGROUP_2, + wave::WaveIndexSymbol::DEVICE_DIM_0, + wave::WaveIndexSymbol::DEVICE_DIM_1, + wave::WaveIndexSymbol::DEVICE_DIM_2}, + [&](wave::WaveIndexSymbol symbol) -> Attribute { + return wave::WaveIndexSymbolAttr::get(ctx, + symbol); + }); +} + +// Get the positions of `symbols` in `allSymbols`. Missing symbols are ignored. +SmallVector getPositionsOfSymbols(ArrayRef symbols, + ArrayRef allSymbols) { + // Find positions of threadLikeIndexSymbols in symbols + SmallVector positions; + for (Attribute symbol : symbols) { + auto it = llvm::find(allSymbols, symbol); + if (it != allSymbols.end()) + positions.push_back(std::distance(allSymbols.begin(), it)); + } + return positions; +} + +// Return true if any expression in the map is a function of a symbol at any of +// the given positions. +static bool isIndexExprMapFunctionOf(AffineMap map, + ArrayRef positions) { + return llvm::any_of(positions, [&](unsigned position) { + return map.isFunctionOfSymbol(position); + }); +} + +// Affine maps used in an index expression conceptually consists of multiple +// additive components: +// +// - thread independent component (workgroup and device indices) +// - thread dependent component (thread and GPR indices) +// - one component per iter dimension +// +// Two maps can be joined if, for all pairwise components: +// +// - the components are equal; +// - the component is absent from one of the maps. +// +// The join is then the sum of unique components from both maps. +// +// We check this by examining the difference between the two maps, which should +// only contain symbols absent from one of the maps, i.e. symbols from the +// symmetric difference of the symbol sets or, alternatively, not contain any +// symbols from the intersection of the symbol sets. 
The difference should also +// not contain a constant term. +// +// Additional care is taken for index (non-iter) dimensions as they must appear +// or not appear simultaneously. For example, lhs may only have thread 0 index +// and rhs may only have thread 1 index, so the difference will depend on both +// thread 0 and thread 1 indices without either of them being common, so the +// regular check won't detect that. Check for dependency on any such symbol +// instead. +// +// The function takes the list of symbols used in LHS and RHS and separately a +// list containing a union thereof and a list of positions in that list of the +// common symbols. This is done for efficiency reasons to avoid re-computing +// these several times when handling start/step/stride maps that share the +// symbol lists. +// +// TODO: consider having separate expressions for each component for simplicity; +// even further, consider having a lattice that is a collection of constraints +// applicable to the value + metadata (like it being used in LHS/RHS/Acc of an +// MMA) without creating the expression immediately. 
+static FailureOr getIndexExprsJoinedMap( + AffineMap lhs, AffineMap rhs, ArrayRef lhsSymbols, + ArrayRef rhsSymbols, ArrayRef allSymbols, + ArrayRef disallowedSymbols) { + if (!lhs) + return rhs; + if (!rhs) + return lhs; + + lhs = wave::alignMapSymbols(lhs, lhsSymbols, allSymbols); + rhs = wave::alignMapSymbols(rhs, rhsSymbols, allSymbols); + + if (lhs == rhs) + return lhs; + + AffineMap difference = simplifyAffineMap(subtractMaps(lhs, rhs)); + + MLIRContext *ctx = rhs.getContext(); + SmallVector threadLikePositions = + getPositionsOfSymbols(getThreadLikeIndexSymbols(ctx), allSymbols); + SmallVector nonThreadLikePositions = + getPositionsOfSymbols(getNonThreadLikeIndexSymbols(ctx), allSymbols); + if (isIndexExprMapFunctionOf(difference, threadLikePositions) && + isIndexExprMapFunctionOf(lhs, threadLikePositions) && + isIndexExprMapFunctionOf(rhs, threadLikePositions)) { + return failure(); + } + if (isIndexExprMapFunctionOf(difference, nonThreadLikePositions) && + isIndexExprMapFunctionOf(lhs, nonThreadLikePositions) && + isIndexExprMapFunctionOf(rhs, nonThreadLikePositions)) { + return failure(); + } + + SmallVector symReplacements(allSymbols.size(), + getAffineConstantExpr(0, ctx)); + for (AffineExpr expr : difference.getResults()) { + // The symbolic part of the difference should not depend on any of the + // disallowed symbols. + for (unsigned symbol : disallowedSymbols) { + if (expr.isFunctionOfSymbol(symbol)) + return failure(); + } + + // The numeric part of the difference should be zero. + AffineExpr differenceRemainder = expr.replaceSymbols(symReplacements); + if (auto constant = llvm::dyn_cast(differenceRemainder); + !constant || constant.getValue() != 0) + return failure(); + } + + // Obtain a part of the RHS map that is only a function of RHS-specific + // symbols. For this, we replace all symbols appearing in the LHS map with + // zero. Symbol replacements contain zeros at this point. 
Reuse that but set + // RHS-only symbols to be replaced with themselves. + SmallVector lhsSymbolPositions = + getPositionsOfSymbols(lhsSymbols, allSymbols); + for (unsigned i = 0, e = symReplacements.size(); i < e; ++i) { + if (llvm::is_contained(lhsSymbolPositions, i)) continue; - if (llvm::is_contained({wave::WaveIndexSymbol::THREAD_0, - wave::WaveIndexSymbol::THREAD_1, - wave::WaveIndexSymbol::THREAD_2, - wave::WaveIndexSymbol::GPR_NUMBER}, - indexSymbol.getValue())) - positions.push_back(i); + symReplacements[i] = getAffineSymbolExpr(i, ctx); + } + AffineMap rhsOnly = rhs.replaceDimsAndSymbols( + {}, symReplacements, /*numResultDims=*/0, rhs.getNumSymbols()); + return simplifyAffineMap(addMaps(lhs, rhsOnly)); +} + +// Join two concrete index expressions mappings by joining their +// start/step/stride maps independently. See getIndexExprsJoinedMap for more +// details. +static wave::WaveIndexMappingAttr +getIndexExprsJoinMappings(wave::WaveIndexMappingAttr lhs, + wave::WaveIndexMappingAttr rhs) { + // Collect all unique symbol names from both index mappings in order. 
+ llvm::SmallVector allSymbols; + llvm::SetVector lhsSymbols(llvm::from_range, + lhs.getSymbols()); + llvm::SetVector rhsSymbols(llvm::from_range, + rhs.getSymbols()); + wave::aggregateAllSymbols(llvm::ArrayRef{lhsSymbols, rhsSymbols}, allSymbols); + llvm::SetVector disallowedSymbols = + llvm::set_intersection(lhsSymbols, rhsSymbols); + SmallVector disallowedSymbolsPositions; + for (unsigned i = 0, e = allSymbols.size(); i < e; ++i) { + if (disallowedSymbols.contains(allSymbols[i])) + disallowedSymbolsPositions.push_back(i); } + + FailureOr joinedStart = getIndexExprsJoinedMap( + lhs.getStart(), rhs.getStart(), lhsSymbols.getArrayRef(), + rhsSymbols.getArrayRef(), allSymbols, disallowedSymbolsPositions); + if (failed(joinedStart)) + return nullptr; + FailureOr joinedStep = getIndexExprsJoinedMap( + lhs.getStep(), rhs.getStep(), lhsSymbols.getArrayRef(), + rhsSymbols.getArrayRef(), allSymbols, disallowedSymbolsPositions); + if (failed(joinedStep)) + return nullptr; + FailureOr joinedStride = getIndexExprsJoinedMap( + lhs.getStride(), rhs.getStride(), lhsSymbols.getArrayRef(), + rhsSymbols.getArrayRef(), allSymbols, disallowedSymbolsPositions); + if (failed(joinedStride)) + return nullptr; + + return wave::WaveIndexMappingAttr::get( + lhs.getContext(), allSymbols, *joinedStart, *joinedStep, *joinedStride); } wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( @@ -606,9 +836,11 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( if (lhs.value == rhs.value) return lhs; + // Top is saturating. if (lhs.isTop() || rhs.isTop()) return top(); + // Only named symbols may be ignored. llvm::StringSet<> ignoredRhsSymbolNames; for (mlir::Attribute attr : ignoredRhsSymbols) { auto symbolAttr = llvm::dyn_cast(attr); @@ -617,6 +849,7 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( ignoredRhsSymbolNames.insert(symbolAttr.getName()); } + // Even if LHS is bottom, we still need to filter out ignored symbols. 
if (lhs.isBottom()) { if (ignoredRhsSymbols.empty() || rhs.isBottom()) return rhs; @@ -637,6 +870,7 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( mlir::DictionaryAttr lhsValue = lhs.getConcreteValue(); mlir::DictionaryAttr rhsValue = rhs.getConcreteValue(); + // Join specific values per symbol. llvm::DenseMap result; for (mlir::NamedAttribute namedAttr : lhsValue) { result[namedAttr.getName()] = namedAttr.getValue(); @@ -652,105 +886,19 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( if (inserted) continue; - // Actually join otherwise. + // The symbol has a mapping on both LHS and RHS, join them. auto lhsValue = llvm::cast(it->getSecond()); auto rhsValue = llvm::cast(namedAttr.getValue()); if (lhsValue == rhsValue) continue; - auto isThreadDependent = [&](wave::WaveIndexMappingAttr val) -> bool { - llvm::SmallVector threadLikeSymbolPositions; - getThreadLikeSymbolPositions(val.getSymbols(), threadLikeSymbolPositions); - return llvm::any_of( - llvm::ArrayRef{val.getStart(), val.getStep(), val.getStride()}, - [&](mlir::AffineMap map) { - return llvm::any_of(threadLikeSymbolPositions, [&](unsigned pos) { - return map.isFunctionOfSymbol(pos); - }); - }); - }; - - // If both are thread-dependent or thread-independent, the only acceptable - // join is when they are equal, which was handled above. - bool lhsIsThreadDependent = isThreadDependent(lhsValue); - bool rhsIsThreadDependent = isThreadDependent(rhsValue); - if (!(lhsIsThreadDependent ^ rhsIsThreadDependent)) - return top(); - - wave::WaveIndexMappingAttr threadDependentMapping = - lhsIsThreadDependent ? lhsValue : rhsValue; - wave::WaveIndexMappingAttr threadIndependentMapping = - lhsIsThreadDependent ? rhsValue : lhsValue; - - // Collect all unique symbol names from both index mappings in order. 
- llvm::SmallVector allSymbols; - llvm::ArrayRef threadDependentSymbols = - threadDependentMapping.getSymbols(); - llvm::ArrayRef threadIndependentSymbols = - threadIndependentMapping.getSymbols(); - wave::aggregateAllSymbols( - llvm::ArrayRef{threadIndependentSymbols, threadDependentSymbols}, - allSymbols); - - mlir::AffineMap threadDependentStart = alignMapSymbols( - threadDependentMapping.getStart(), threadDependentSymbols, allSymbols); - mlir::AffineMap threadIndependentStart = - alignMapSymbols(threadIndependentMapping.getStart(), - threadIndependentSymbols, allSymbols); - - mlir::AffineMap threadDependentStep = alignMapSymbols( - threadDependentMapping.getStep(), threadDependentSymbols, allSymbols); - mlir::AffineMap threadIndependentStep = - alignMapSymbols(threadIndependentMapping.getStep(), - threadIndependentSymbols, allSymbols); - - mlir::AffineMap threadDependentStride = alignMapSymbols( - threadDependentMapping.getStride(), threadDependentSymbols, allSymbols); - mlir::AffineMap threadIndependentStride = - alignMapSymbols(threadIndependentMapping.getStride(), - threadIndependentSymbols, allSymbols); - - // Subtract the thread-independent from thread-dependent for each. - auto subtractMaps = [&](mlir::AffineMap a, - mlir::AffineMap b) -> mlir::AffineMap { - // Assert there is only one result expression in each map. 
- assert(a.getNumResults() == 1 && - "expected a single result expression in affine map 'a'"); - assert(b.getNumResults() == 1 && - "expected a single result expression in affine map 'b'"); - mlir::AffineExpr subtracted = a.getResult(0) - b.getResult(0); - return mlir::AffineMap::get(a.getNumDims(), a.getNumSymbols(), subtracted, - ctx); - }; - mlir::AffineMap newStart = - subtractMaps(threadDependentStart, threadIndependentStart); - mlir::AffineMap newStep = - subtractMaps(threadDependentStep, threadIndependentStep); - mlir::AffineMap newStride = - subtractMaps(threadDependentStride, threadIndependentStride); - - llvm::SmallVector threadLikeSymbolPositions; - getThreadLikeSymbolPositions(allSymbols, threadLikeSymbolPositions); - auto isOnlyThreadDependent = [&](mlir::AffineMap map) { - mlir::WalkResult walkResult = - map.getResult(0).walk([&](mlir::AffineExpr expr) { - auto symExpr = llvm::dyn_cast(expr); - if (!symExpr) - return mlir::WalkResult::advance(); - if (!llvm::is_contained(threadLikeSymbolPositions, - symExpr.getPosition())) - return mlir::WalkResult::interrupt(); - return mlir::WalkResult::advance(); - }); - return !walkResult.wasInterrupted(); - }; - - if (!isOnlyThreadDependent(newStart) || !isOnlyThreadDependent(newStep) || - !isOnlyThreadDependent(newStride)) - return top(); + wave::WaveIndexMappingAttr joinedMapping = + getIndexExprsJoinMappings(lhsValue, rhsValue); + if (!joinedMapping) + return IndexExprsLatticeStorage::top(); - result[namedAttr.getName()] = threadDependentMapping; + result[namedAttr.getName()] = joinedMapping; } return IndexExprsLatticeStorage(mlir::DictionaryAttr::get( ctx, llvm::map_to_vector(result, [](auto &&pair) { diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 69849e3e1..c29c24bd0 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -94,6 +94,30 @@ static void printSingleSymbol(mlir::OpAsmPrinter &printer, mlir::Operation *, // 
AllocateOp //----------------------------------------------------------------------------- +// Index expressions don't propagate through allocations. + +llvm::FailureOr +wave::AllocateOp::propagateIndexExprsForward( + llvm::ArrayRef /*operands*/, + llvm::MutableArrayRef /*results*/, + wave::EmitErrorFn /*emitError*/) { + return mlir::ChangeResult::NoChange; +} + +llvm::FailureOr +wave::AllocateOp::propagateIndexExprsBackward( + llvm::MutableArrayRef /*operands*/, + llvm::ArrayRef /*results*/, + wave::EmitErrorFn /*emitError*/) { + return mlir::ChangeResult::NoChange; +} + +llvm::LogicalResult wave::AllocateOp::setIndexFromLattices( + llvm::ArrayRef /*operands*/, + llvm::ArrayRef /*results*/) { + return llvm::success(); +} + llvm::LogicalResult wave::AllocateOp::verify() { bool hasParent = getParent() != Value(); bool hasOffset = getOffset() != std::nullopt; @@ -218,6 +242,27 @@ void wave::IterateOp::getSuccessorRegions( getLoopBody()->getArguments().drop_back(getCaptures().size()))); } +llvm::FailureOr wave::IterateOp::propagateIndexExprsForward( + llvm::ArrayRef /*operands*/, + llvm::MutableArrayRef /*results*/, + wave::EmitErrorFn /*emitError*/) { + llvm_unreachable("IterateOp should be handled by control flow interfaces"); +} + +llvm::FailureOr +wave::IterateOp::propagateIndexExprsBackward( + llvm::MutableArrayRef /*operands*/, + llvm::ArrayRef /*results*/, + wave::EmitErrorFn /*emitError*/) { + llvm_unreachable("IterateOp should be handled by control flow interfaces"); +} + +llvm::LogicalResult wave::IterateOp::setIndexFromLattices( + llvm::ArrayRef operands, + llvm::ArrayRef resultExprs) { + return detail::identitySetIndexFromLattices(*this, operands, resultExprs); +} + mlir::LogicalResult wave::IterateOp::verify() { mlir::TypeRange iterArgTypes = getIterArgs().getTypes(); mlir::TypeRange resultTypes = getResultTypes(); @@ -840,107 +885,6 @@ static llvm::LogicalResult populateMmaIndexingExpr( } } -// Get the (indexed) affine symbol expression corresponding to 
the given symbol, -// insert it into the list if it isn't already present. -static mlir::AffineExpr -getOrInsertSymbolExpr(mlir::Attribute symbol, - llvm::SmallVectorImpl &symbolNames) { - auto it = llvm::find(symbolNames, symbol); - unsigned position = [&] { - if (it != symbolNames.end()) - return static_cast(std::distance(symbolNames.begin(), it)); - symbolNames.push_back(symbol); - return static_cast(symbolNames.size() - 1); - }(); - return mlir::getAffineSymbolExpr(position, symbol.getContext()); -} - -// Create an index mapping induced by the given constraint. Combine it with the -// base index mapping if provided. Call `applyConstraintGeneric` if the -// constraint is only available as an Attribute. -template -static wave::WaveIndexMappingAttr -applyConstraint(ConstraintAttrT constraint, - wave::WaveIndexMappingAttr baseMapping = nullptr) { - static_assert(llvm::is_one_of(), - "unsupported constraint type for applyConstraint"); - - llvm::SmallVector symbols = - llvm::map_to_vector(constraint.getTileSize().getSymbols(), - [](mlir::Attribute symbol) { return symbol; }); - - mlir::AffineExpr symbolExpr; - mlir::MLIRContext *context = constraint.getContext(); - - if constexpr (std::is_same_v) { - // TODO: we should just use workgroup attributes in expressions directly. 
- wave::WaveIndexSymbol symbol = [&] { - switch (constraint.getWorkgroupDim().getValue()) { - case wave::WaveWorkgroupDim::X: - return wave::WaveIndexSymbol::WORKGROUP_0; - case wave::WaveWorkgroupDim::Y: - return wave::WaveIndexSymbol::WORKGROUP_1; - case wave::WaveWorkgroupDim::Z: - return wave::WaveIndexSymbol::WORKGROUP_2; - } - llvm::report_fatal_error("unsupported workgroup dimension"); - }(); - symbolExpr = getOrInsertSymbolExpr( - wave::WaveIndexSymbolAttr::get(context, symbol), symbols); - - } else if constexpr (std::is_same_v) { - symbolExpr = getOrInsertSymbolExpr(constraint.getDim(), symbols); - } - - assert(constraint.getTileSize().getMap().getNumResults() == 1 && - "expected a single result expression in affine map"); - mlir::AffineMap map = mlir::AffineMap::get( - /*dimCount=*/0, symbols.size(), - symbolExpr * constraint.getTileSize().getMap().getResult(0)); - if (baseMapping == nullptr) - return wave::WaveIndexMappingAttr::get( - context, symbols, map, mlir::AffineMap::getConstantMap(1, context), - mlir::AffineMap::getConstantMap(1, context)); - - llvm::SmallVector allSymbols; - wave::aggregateAllSymbols( - std::initializer_list>{ - baseMapping.getSymbols(), symbols}, - allSymbols); - - mlir::AffineMap baseStart = alignMapSymbols( - baseMapping.getStart(), baseMapping.getSymbols(), allSymbols); - mlir::AffineMap baseStep = alignMapSymbols( - baseMapping.getStep(), baseMapping.getSymbols(), allSymbols); - mlir::AffineMap baseStride = alignMapSymbols( - baseMapping.getStride(), baseMapping.getSymbols(), allSymbols); - map = alignMapSymbols(map, symbols, allSymbols); - map = mlir::AffineMap::get(/*dimCount=*/0, allSymbols.size(), - baseStart.getResult(0) + map.getResult(0)); - return wave::WaveIndexMappingAttr::get(context, allSymbols, map, baseStep, - baseStride); -} - -// Create an index mapping induced by the given constraint. Combine it with the -// base index mapping if provided. 
Call `applyConstraint` if the specific kind -// of constraint is already known. -static wave::WaveIndexMappingAttr -applyConstraintGeneric(mlir::Attribute constraint, - wave::WaveIndexMappingAttr baseMapping = nullptr) { - return llvm::TypeSwitch( - constraint) - .Case( - [&](auto constraint) { - // This double dispatching is necessary in absence of interfaces to - // dispatch to a class method based on a specific type. - return applyConstraint(constraint, baseMapping); - }) - .Default([&](mlir::Attribute constraint) { return nullptr; }); -} - /// Create per-symbol thread-independent index expressions for `indexingSymbols` /// given constraints on them and put them into `symbolMappings` as named pairs /// (symbol, index mapping attribute). Thread-independent means affected by @@ -969,7 +913,9 @@ static void mixInThreadIndependentConstraints( wave::WaveIndexMappingAttr mapping = llvm::cast(mappingIt->getValue()); for (mlir::Attribute constraint : it->second) { - mapping = applyConstraintGeneric(constraint, mapping); + // Tiling constraints are handled separately in loop propagation. 
+ if (!isa(constraint)) + mapping = wave::applyConstraintGeneric(constraint, mapping); } mappingIt->setValue(mapping); } diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index dbe81c2d8..f163bb8ac 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -9,8 +9,10 @@ #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "water/Dialect/Wave/IR/IndexExpr.h" #include "water/Dialect/Wave/IR/WaveAttrs.h" #include "water/Dialect/Wave/IR/WaveDialect.h" #include "water/Dialect/Wave/IR/WaveInterfaces.h" @@ -19,6 +21,7 @@ #include "water/Dialect/Wave/Transforms/Passes.h" #include "water/Dialect/Wave/Transforms/Utils.h" #include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" @@ -862,7 +865,6 @@ class PrintNoRegions { void print(llvm::raw_ostream &os) const { operation->print(os, mlir::OpPrintingFlags().skipRegions()); - os << "\n"; } private: @@ -966,11 +968,40 @@ class IndexExprsForwardAnalysis if (!llvm::isa(value.getType())) continue; + LDBG() << "setting block argument lattice " << value + << " from " << PrintNoRegions(op) << " to bottom"; unsafeSet(getLatticeElement(value), IndexExprsLatticeStorage::bottom()); } } } + + if (auto iterateOp = llvm::dyn_cast(op)) { + // Set lattices of captured block arguments to the relevant tiling + // constraint, it will be then propagated by joining with + // expressions induced by other constraints. 
+ wave::WaveSymbolAttr iterSymbolAttr = iterateOp.getIterator(); + llvm::SmallVector symbolConstraints = + initObject->symbolConstraints.lookup(iterSymbolAttr); + auto it = llvm::find_if( + symbolConstraints, llvm::IsaPred); + if (it != symbolConstraints.end()) { + wave::TilingConstraintAttr tilingConstraint = + llvm::cast(*it); + for (mlir::Value capture : iterateOp.getCaptureBlockArgs()) { + if (!llvm::isa(capture.getType())) + continue; + auto dict = mlir::DictionaryAttr::get( + iterSymbolAttr.getContext(), + {{iterSymbolAttr.getName(), + wave::applyConstraint(tilingConstraint)}}); + LDBG() << "setting iterate block argument lattice " << capture + << " from " << PrintNoRegions(iterateOp) << " to " + << dict; + unsafeSet(getLatticeElement(capture), dict); + } + } + } return llvm::success(); }); if (walkResult.wasInterrupted()) @@ -1003,6 +1034,8 @@ class IndexExprsForwardAnalysis lattice->join(initialized ? IndexExprsLatticeStorage::top() : IndexExprsLatticeStorage::bottom())); + if (initialized) + LDBG() << "top fixpoint for " << lattice->getAnchor(); } llvm::LogicalResult @@ -1011,21 +1044,25 @@ class IndexExprsForwardAnalysis llvm::ArrayRef results) override { LLVM_DEBUG({ - LDBG() << "visiting operation " << PrintNoRegions(op); - LDBG() << " Operands lattices:"; + LDBG() << "visiting operation forward " << PrintNoRegions(op); + LDBG() << " operand lattices:"; for (auto [i, operand] : llvm::enumerate(operands)) { - LDBG() << " operand #" << i << ": "; - operand->getValue().print(LDBG_STREAM); - LDBG() << ""; // This will generate a newline. + LDBG() << " operand #" << i << ": " << *operand; } // Print all result lattices. - LDBG() << " Results lattices:"; + LDBG() << " result lattices:"; for (auto [i, result] : llvm::enumerate(results)) { - LDBG() << " result #" << i << ": "; - result->getValue().print(LDBG_STREAM); - LDBG() << ""; // This will generate a newline. 
+ LDBG() << " result #" << i << ": " << *result; } }); + auto scope = llvm::make_scope_exit([&] { + LLVM_DEBUG({ + LDBG() << " updated result lattices:"; + for (auto [i, result] : llvm::enumerate(results)) { + LDBG() << " result #" << i << ": " << *result; + } + }); + }); // Check if the operation implements the interface. if (!llvm::isa(op)) { @@ -1061,11 +1098,42 @@ class IndexExprsForwardAnalysis for (auto &&[resultLattice, lattice] : llvm::zip_equal(resultLattices, results)) { + // In release mode, just set the lattice value instead of calling join. + // The interface should have returned the correctly joined lattice and we + // don't want to re-join it and don't need the expensive check of the + // lattice direction. +#ifndef NDEBUG propagateIfChanged(lattice, lattice->join(resultLattice)); +#else + unsafeSet(lattice, resultLattice); +#endif } return llvm::success(); } + void + visitNonControlFlowArguments(mlir::Operation *op, + const mlir::RegionSuccessor &successor, + llvm::ArrayRef argLattices, + unsigned firstIndex) override { + auto iterateOp = llvm::dyn_cast(op); + if (!iterateOp) + return; + + LDBG() << "visiting " << PrintNoRegions(iterateOp); + + for (auto &&[capture, lattice] : llvm::zip_equal( + iterateOp.getCaptures(), + argLattices.take_back(iterateOp.getCaptures().size()))) { + const IndexExprsLattice *captureLattice = + getLatticeElementFor(getProgramPointBefore(iterateOp), capture); + LDBG() << "captured lattice: " << *captureLattice; + LDBG() << "block lattice: " << *lattice; + propagateIfChanged(lattice, lattice->join(captureLattice->getValue())); + LDBG() << "new block lattice: " << *lattice; + } + } + private: bool initialized = false; wave::OverrideInitializationFn overrideInitialization; @@ -1158,6 +1226,24 @@ class IndexExprsBackwardAnalysis void visitBranchOperand(mlir::OpOperand &opOperand) override { if (!llvm::isa(opOperand.get().getType())) return; + + // Captures of the iterate need to be propagated from the corresponding + // 
block arguments manually without the tiling constraint. + if (auto iterateOp = + llvm::dyn_cast(opOperand.getOwner())) { + unsigned position = opOperand.getOperandNumber(); + mlir::Value blockArgument = + iterateOp.getLoopBody()->getArgument(position); + const IndexExprsLattice *blockArgLattice = + getLatticeElement(blockArgument); + IndexExprsLattice *lattice = getLatticeElement(opOperand.get()); + IndexExprsLatticeStorage joined = IndexExprsLatticeStorage::join( + lattice->getValue(), blockArgLattice->getValue(), + iterateOp.getIterator()); + unsafeSet(lattice, joined); + return; + } + setToExitState(getLatticeElement(opOperand.get())); } @@ -1178,6 +1264,8 @@ class IndexExprsBackwardAnalysis lattice->join(initialized ? IndexExprsLatticeStorage::top() : IndexExprsLatticeStorage::bottom())); + if (initialized) + LDBG() << "top fixpoint (backward) for " << lattice->getAnchor(); } llvm::LogicalResult @@ -1185,20 +1273,24 @@ class IndexExprsBackwardAnalysis llvm::ArrayRef operands, llvm::ArrayRef results) override { LLVM_DEBUG({ - LDBG() << "visiting operation backward " << PrintNoRegions(op) << "\n"; - LDBG() << " Operands lattices:\n"; + LDBG() << "visiting operation backward " << PrintNoRegions(op); + LDBG() << " operand lattices:"; for (auto [i, operand] : llvm::enumerate(operands)) { - LDBG() << " operand #" << i << ": "; - operand->getValue().print(llvm::dbgs()); - LDBG() << "\n"; + LDBG() << " operand #" << i << ": " << *operand; } - LDBG() << " Results lattices:\n"; + LDBG() << " results lattices:"; for (auto [i, result] : llvm::enumerate(results)) { - LDBG() << " result #" << i << ": "; - result->getValue().print(llvm::dbgs()); - LDBG() << "\n"; + LDBG() << " result #" << i << ": " << *result; } }); + auto scope = llvm::make_scope_exit([&] { + LLVM_DEBUG({ + LDBG() << " updated operand lattices:"; + for (auto [i, operand] : llvm::enumerate(operands)) { + LDBG() << " operand #" << i << ": " << *operand; + } + }); + }); // Check if the operation implements the 
interface. if (!llvm::isa(op)) { @@ -1234,7 +1326,15 @@ class IndexExprsBackwardAnalysis for (auto &&[operandLattice, lattice] : llvm::zip_equal(operandLattices, operands)) { + // In release mode, just set the lattice value instead of calling join. + // The interface should have returned the correctly joined lattice and we + // don't want to re-join it and don't need the expensive check of the + // lattice direction. +#ifndef NDEBUG propagateIfChanged(lattice, lattice->join(operandLattice)); +#else + unsafeSet(lattice, operandLattice); +#endif } return llvm::success(); } @@ -1256,6 +1356,10 @@ class InferIndexExprsPass getArgument()))) return signalPassFailure(); + mlir::IRRewriter rewriter(&getContext()); + getOperation()->walk( + [&](wave::IterateOp iterateOp) { iterateOp.makeIsolated(rewriter); }); + mlir::SymbolTableCollection symbolTable; mlir::DataFlowConfig config; config.setInterprocedural(false); @@ -1273,6 +1377,10 @@ class InferIndexExprsPass wave::setWaveIndexExprAnalysisResults(getOperation(), solver))) return signalPassFailure(); + getOperation()->walk([&](wave::IterateOp iterateOp) { + iterateOp.makeNonIsolated(rewriter); + }); + if (llvm::failed(wave::setNormalFormPassPostcondition( wave::WaveNormalForm::IndexExprsSpecified, getOperation()))) return signalPassFailure(); diff --git a/water/python/WaterExtensionNanobind.cpp b/water/python/WaterExtensionNanobind.cpp index 43f43e5a7..814103117 100644 --- a/water/python/WaterExtensionNanobind.cpp +++ b/water/python/WaterExtensionNanobind.cpp @@ -29,6 +29,9 @@ NB_MODULE(_waterDialects, m) { mlirDialectHandleLoadDialect(h, context); }, nb::arg("context").none() = nb::none(), nb::arg("load") = true); + d.def( + "register_passes", []() { mlirWaveDialectRegisterPasses(); }, + "Registers the wave dialect passes."); // Export dialect constants d.attr("WAVE_CONSTRAINTS_ATTR_NAME") = mlirWaveDialectConstraintsAttrName; @@ -52,7 +55,32 @@ NB_MODULE(_waterDialects, m) { }, nb::arg("cls"), nb::arg("symbolName"), 
nb::arg("context") = nb::none(), - "Gets a wave.WaveSymbolAttr from parameters."); + "Gets a wave.WaveSymbolAttr from parameters.") + .def_property_readonly("name", [](MlirAttribute self) { + return mlirWaveSymbolAttrGetName(self); + }); + + //===---------------------------------------------------------------------===// + // WaveIterSymbolAttr + //===---------------------------------------------------------------------===// + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + d, "WaveIterSymbolAttr", mlirAttributeIsAWaveIterSymbolAttr, + mlirWaveIterSymbolAttrGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, const std::string &symbolName, + MlirContext context) { + MlirStringRef symbolNameStrRef = + mlirStringRefCreate(symbolName.data(), symbolName.size()); + return cls(mlirWaveIterSymbolAttrGet(context, symbolNameStrRef)); + }, + nb::arg("cls"), nb::arg("symbolName"), + nb::arg("context") = nb::none(), + "Gets a wave.WaveIterSymbolAttr from parameters.") + .def_property_readonly("name", [](MlirAttribute self) { + return mlirWaveIterSymbolAttrGetName(self); + }); //===---------------------------------------------------------------------===// // WaveIndexSymbolAttr @@ -124,7 +152,28 @@ NB_MODULE(_waterDialects, m) { }, nb::arg("cls"), nb::arg("symbols"), nb::arg("start"), nb::arg("step"), nb::arg("stride"), nb::arg("context") = nb::none(), - "Gets a wave.WaveIndexMappingAttr from a list of symbol attributes."); + "Gets a wave.WaveIndexMappingAttr from a list of symbol attributes.") + .def_property_readonly("start", + [](MlirAttribute self) { + return mlirWaveIndexMappingAttrGetStart(self); + }) + .def_property_readonly("step", + [](MlirAttribute self) { + return mlirWaveIndexMappingAttrGetStep(self); + }) + .def_property_readonly("stride", + [](MlirAttribute self) { + return mlirWaveIndexMappingAttrGetStride(self); + }) + .def_property_readonly("symbols", [](MlirAttribute self) { + std::vector symbols; + intptr_t numSymbols = 
mlirWaveIndexMappingAttrGetNumSymbols(self); + symbols.reserve(numSymbols); + for (intptr_t i = 0; i < numSymbols; i++) { + symbols.push_back(mlirWaveIndexMappingAttrGetSymbol(self, i)); + } + return symbols; + }); //===---------------------------------------------------------------------===// // WaveHyperparameterAttr diff --git a/water/test/Dialect/Wave/attr-type-invalid.mlir b/water/test/Dialect/Wave/attr-type-invalid.mlir index 9d72bc7c4..702ffcf26 100644 --- a/water/test/Dialect/Wave/attr-type-invalid.mlir +++ b/water/test/Dialect/Wave/attr-type-invalid.mlir @@ -22,3 +22,8 @@ module attributes {wave.elements_per_thread = "abc"} {} // expected-error @below {{unexpected wave dialect attribute "wave.unexpected"}} module attributes {wave.unexpected = 42} {} + +// ----- + +// expected-error @below {{symbols names starting with '_' are reserved for internal use}} +module attributes {wave_test.symbol = #wave.symbol<"_A">} diff --git a/water/test/Dialect/Wave/attr-type.mlir b/water/test/Dialect/Wave/attr-type.mlir index 2270918a0..76d7b117c 100644 --- a/water/test/Dialect/Wave/attr-type.mlir +++ b/water/test/Dialect/Wave/attr-type.mlir @@ -17,3 +17,9 @@ func.func private @address_space_full() -> !wave.tensor func.func private @address_space_default() -> !wave.tensor> + +// CHECK: #wave, #wave.symbol<"B">] -> (_Iter_A + B, 2, 2)> +func.func private @iter_symbol_in_mapping() attributes { wave_test.index = #wave, #wave.symbol<"B">] -> (_Iter_A + B, 2, 2)>} + +// CHECK: #wave.expr_list<[#wave.iter<"A">, #wave.symbol<"B">] -> (_Iter_A + B)> +func.func private @iter_symbol_in_expr() attributes { wave_test.index = #wave.expr_list<[#wave.iter<"A">, #wave.symbol<"B">] -> (_Iter_A + B)>} diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index f784ea35f..2952c484a 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -98,7 +98,7 @@ func.func @iterate_mismatching_results(%arg0: 
!wave.tensor<[@A] of f32>, %arg1: // must provide the full triple (start, step, stride) func.func @index_attr_wrong_attr_type(%arg0: f32) { - // expected-error @below {{custom op 'wave.register' expected symbol names to be either a WaveSymbolAttr or WaveIndexSymbolAttr}} + // expected-error @below {{expected symbol names to be one of WaveSymbolAttr, WaveIndexSymbolAttr or WaveIterSymbolAtt}} wave.register %arg0 index [{X : [#wave.workgroup_dim] -> (WG0)}] : !wave.tensor<[@M] of f32, > return } @@ -142,6 +142,14 @@ func.func @index_attr_wrong_value_type(%arg0: f32) { // ----- +func.func @index_attr_iter_not_allowed(%arg0: f32) { + // expected-error @below {{index expression uses iterator symbol M which is not defined by any parent op}} + wave.register %arg0 index [{M : [#wave.iter<"M">] -> (0, 1, 1)}] : !wave.tensor<[@M] of f32, > + return +} + +// ----- + func.func @mismatch_shape_binary(%lhs: !wave.tensor<[@A, @B] of f32>, %rhs: !wave.tensor<[@B, @C] of f32>) { // expected-error @below {{expected operand #1 dimension #0 (#wave.symbol<"B">) to match operand #0 dimension #0 (#wave.symbol<"A">)}} wave.add %lhs, %rhs : (!wave.tensor<[@A, @B] of f32>, !wave.tensor<[@B, @C] of f32>) -> !wave.tensor diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index 282278b11..4022b6b82 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -71,6 +71,16 @@ func.func @iterate(%input0: !wave.tensor, %input1: !wave.tensor, !wave.tensor } +func.func @using_iter_symbol(%arg0: f32) { + %0 = wave.register %arg0 : !wave.tensor<[@M] of f32, > + wave.iterate @M iter_args(%0) { + ^bb0(%arg1: !wave.tensor<[@M] of f32, >): + wave.register %arg0 index [{M : [#wave.iter<"M">] -> (0, 1, 1)}] : !wave.tensor<[@M] of f32, > + wave.yield %arg1 : !wave.tensor<[@M] of f32, > + } : (!wave.tensor<[@M] of f32, >) -> !wave.tensor + return +} + // CHECK-LABEL: @register func.func @register() { %0 = arith.constant 0.0 : f32 diff --git 
a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index a79649057..e4f385a8a 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -14,6 +14,8 @@ import wave_lang.kernel.lang as tkl from wave_lang.kernel._support.dtype import DataType +from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect +from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.logging import get_logger from ..._support.indexing import IndexSequence, IndexSymbol @@ -259,6 +261,86 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]): ), f"Vector shapes not set for node {custom.fx_node}: {custom}" +def _set_water_id(trace: CapturedTrace): + for node in trace.walk(lambda x: x): + setattr(node, "_water_id", str(id(node))) + + +def _reset_water_id(trace: CapturedTrace): + for node in trace.walk(lambda x: x): + delattr(node, "_water_id") + + +def _index_diff_is_zero( + index1: dict[IndexSymbol, IndexSequence], index2: dict[IndexSymbol, IndexSequence] +) -> dict[IndexSymbol, IndexSequence]: + def f(seq1: IndexSequence, seq2: IndexSequence) -> bool: + start = sympy.simplify(seq1.start - seq2.start) + size = sympy.simplify(seq1.size - seq2.size) + stride = sympy.simplify(seq1.stride - seq2.stride) + if start != 0: + print(f"Start difference: {start}") + if size != 0: + print(f"Size difference: {size}") + if stride != 0: + print(f"Stride difference: {stride}") + return start == 0 and size == 0 and stride == 0 + + return index1.keys() == index2.keys() and all( + f(seq, index2[dim]) for dim, seq in index1.items() + ) + + +def _check_water_indices(trace: CapturedTrace, inferred: dict[str, IndexSequence]): + for node in trace.walk(lambda x: x): + water_id = getattr(node, "_water_id") + custom = get_custom(node) + if isinstance(custom, (Placeholder, Output)): + continue + if 
water_id not in inferred: + raise RuntimeError( + f"Node {get_custom(node)} with id {water_id} not found in water-inferred index expressions." + ) + inferred_index = inferred[water_id].get("index", None) + if not getattr(node, "index", None): + assert isinstance( + custom, NestedRegionOp + ), "Index may only be missing for NestedRegionOps." + continue + # Skip GetResult because they are special-cased in Python propagation, + # making them have incorrect indexes in dataflow sense. + if isinstance(custom, GetResult): + continue + if node.index != inferred_index and not _index_diff_is_zero( + node.index, inferred_index + ): + raise RuntimeError( + f"Index for node {get_custom(node)}, {get_custom(node).index} does not match inferred index {inferred_index}" + ) + + +# pipeline: set water ids, run water pass and record, run actual pass, verify + + +def set_node_indices_water_checked( + trace: CapturedTrace, + constraints: list[Constraint], + options: WaveCompileOptions, + print_ir_before: Sequence[str] = [], + print_ir_after: Sequence[str] = [], +): + _set_water_id(trace) + # TODO: make sure _water_id gets printed as an attribute + # TODO: recover an `extras` field from here, which would contain a dictionary between _water_id and the inferred index expressions for that node + _, diagnostics, inferred_attributes = emit_wave_dialect(trace, constraints, options) + if diagnostics: + raise RuntimeError(f"Water indices check failed: {diagnostics}") + set_node_indices(trace, constraints, print_ir_before, print_ir_after) + # TODO: use the extras field to fetch the inferred indices + _check_water_indices(trace, inferred_attributes) + _reset_water_id(trace) + + def set_node_indices( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index 382bf06e0..49f88b215 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -75,6 +75,7 @@ class 
WaveCompileOptions: ) use_local_scope: bool = False use_water_leak_check: bool | str = False # If string, check the given IR instead. + check_water_analysis: bool = False enforce_locations: bool = True drop_debug_info_before_mlir: bool = True diff --git a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py index 71b67890d..797512fa2 100644 --- a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py +++ b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py @@ -20,6 +20,7 @@ import subprocess import sys from pathlib import Path +from typing import Any import dill from wave_lang.kernel._support.tracing import CapturedTrace from wave_lang.kernel.wave.compile_options import WaveCompileOptions @@ -30,10 +31,9 @@ def emit_wave_dialect( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions, - *, test_diagnostic_emission: bool = False, pipeline: str = "", -) -> tuple[str, list[str]]: +) -> tuple[str, list[str], dict[str, dict[str, Any]]]: """Emit Wave MLIR by sending the pickled trace and options to the emitter. The `subs` field of options is the only option used during emission. 
If @@ -65,6 +65,18 @@ def emit_wave_dialect( stderr=subprocess.PIPE, ) + assert ( + not options.check_water_analysis or not pipeline + ), "Cannot check water analysis and use a pipeline" + if options.check_water_analysis: + pipeline = """ +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + transform.apply_registered_pass "water-wave-infer-index-exprs" to %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +}""" + output, err = proc.communicate( dill.dumps( { @@ -93,9 +105,16 @@ def emit_wave_dialect( ) from e diagnostics = unpickled.get("diagnostics") if isinstance(unpickled, dict) else None module = unpickled.get("module") if isinstance(unpickled, dict) else None + inferred_attributes = ( + unpickled.get("inferred_attributes") if isinstance(unpickled, dict) else None + ) # Preserve stderr messages. if err: print(err.decode("utf-8", errors="replace"), file=sys.stderr) - return module.decode("utf-8"), [d.decode("utf-8") for d in diagnostics] + return ( + module.decode("utf-8"), + [d.decode("utf-8") for d in diagnostics], + inferred_attributes, + ) diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index 38373b56e..6f15ca4c1 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -34,6 +34,7 @@ from wave_lang.kernel.ops.wave_ops import * from wave_lang.support.location_config import LocationCaptureLevel +from wave_lang.kernel._support.indexing import index_symbol, IndexSequence, IndexSymbol try: from water_mlir.water_mlir import ir @@ -41,24 +42,26 @@ AddOp, AllocateOp, DivOp, - ExtractSliceOp, Exp2Op, + ExtractSliceOp, + IterateOp, MmaOp, MulOp, ReadOp, RegisterOp, WriteOp, - IterateOp, YieldOp, - WaveExprListAttr, + DeviceConstraintAttr, HardwareConstraintAttr, - WorkgroupConstraintAttr, - WaveConstraintAttr, 
TilingConstraintAttr, - DeviceConstraintAttr, + WaveConstraintAttr, + WaveExprListAttr, WaveMmaKind, WaveMmaKindAttr, + WaveNormalForm, + WaveNormalFormAttr, WaveWorkgroupDimAttr, + WorkgroupConstraintAttr, ) from water_mlir.water_mlir.sympy_to_affine_converter import ( convert_sympy_to_affine_map, @@ -240,19 +243,177 @@ def _convert_sympy_expr_to_affine_map( ) +def _convert_affine_expr_to_sympy_expr( + expr: ir.AffineExpr, + symbol_mapping: Sequence[sympy.Symbol], +) -> sympy.Expr: + if ir.AffineConstantExpr.isinstance(expr): + return sympy.Integer(ir.AffineConstantExpr(expr).value) + if ir.AffineSymbolExpr.isinstance(expr): + return symbol_mapping[ir.AffineSymbolExpr(expr).position] + if ir.AffineAddExpr.isinstance(expr): + add_expr = ir.AffineAddExpr(expr) + return _convert_affine_expr_to_sympy_expr( + add_expr.lhs, symbol_mapping + ) + _convert_affine_expr_to_sympy_expr(add_expr.rhs, symbol_mapping) + if ir.AffineMulExpr.isinstance(expr): + mul_expr = ir.AffineMulExpr(expr) + return _convert_affine_expr_to_sympy_expr( + mul_expr.lhs, symbol_mapping + ) * _convert_affine_expr_to_sympy_expr(mul_expr.rhs, symbol_mapping) + if ir.AffineFloorDivExpr.isinstance(expr): + floor_div_expr = ir.AffineFloorDivExpr(expr) + return sympy.floor( + _convert_affine_expr_to_sympy_expr(floor_div_expr.lhs, symbol_mapping) + / _convert_affine_expr_to_sympy_expr(floor_div_expr.rhs, symbol_mapping) + ) + if ir.AffineCeilDivExpr.isinstance(expr): + ceil_div_expr = ir.AffineCeilDivExpr(expr) + return sympy.ceil( + _convert_affine_expr_to_sympy_expr(ceil_div_expr.lhs, symbol_mapping) + / _convert_affine_expr_to_sympy_expr(ceil_div_expr.rhs, symbol_mapping) + ) + if ir.AffineModExpr.isinstance(expr): + mod_expr = ir.AffineModExpr(expr) + return _convert_affine_expr_to_sympy_expr( + mod_expr.lhs, symbol_mapping + ) % _convert_affine_expr_to_sympy_expr(mod_expr.rhs, symbol_mapping) + raise ValueError(f"Unsupported affine expression: {expr} of type {type(expr)}") + + +def 
_convert_index_mapping_attr_to_sympy( + attr: wave.WaveIndexMappingAttr, +) -> IndexSequence: + def wrap_symbol(symbol_name: ir.Attribute) -> sympy.Symbol: + if isinstance(symbol_name, wave.WaveSymbolAttr): + return index_symbol(symbol_name.name) + elif isinstance(symbol_name, wave.WaveIterSymbolAttr): + return index_symbol("$ARG" + symbol_name.name) + elif isinstance(symbol_name, wave.WaveIndexSymbolAttr): + match symbol_name.value: + case wave.WaveIndexSymbol.WORKGROUP_0: + return index_symbol("$WG0") + case wave.WaveIndexSymbol.WORKGROUP_1: + return index_symbol("$WG1") + case wave.WaveIndexSymbol.WORKGROUP_2: + return index_symbol("$WG2") + case wave.WaveIndexSymbol.THREAD_0: + return index_symbol("$T0") + case wave.WaveIndexSymbol.THREAD_1: + return index_symbol("$T1") + case wave.WaveIndexSymbol.THREAD_2: + return index_symbol("$T2") + case wave.WaveIndexSymbol.DEVICE_DIM_0: + return sympy.Symbol("$DD0") + case wave.WaveIndexSymbol.DEVICE_DIM_1: + return index_symbol("$DD1") + case wave.WaveIndexSymbol.DEVICE_DIM_2: + return index_symbol("$DD2") + case wave.WaveIndexSymbol.GPR_NUMBER: + return index_symbol("$GPR_NUM") + case _: + raise ValueError(f"Unsupported index symbol: {symbol_name.value}") + else: + raise ValueError(f"Unsupported symbol attribute: {symbol_name}") + + symbols = list(map(wrap_symbol, attr.symbols)) + assert ( + len(attr.start.results) == 1 + ), f"Expected start map to have one expression, got {attr.start}" + assert ( + len(attr.step.results) == 1 + ), f"Expected step map to have one expression, got {attr.step}" + assert ( + len(attr.stride.results) == 1 + ), f"Expected stride map to have one expression, got {attr.stride}" + start = _convert_affine_expr_to_sympy_expr(attr.start.results[0], symbols) + step = _convert_affine_expr_to_sympy_expr(attr.step.results[0], symbols) + stride = _convert_affine_expr_to_sympy_expr(attr.stride.results[0], symbols) + return IndexSequence(start, step, stride) + + +def _convert_index_mapping_dict_to_sympy( 
+ dict_attr: ir.DictAttr, +) -> dict[IndexSymbol, IndexSequence]: + result = {} + for named_attr in dict_attr: + key = named_attr.name + value = named_attr.attr + assert isinstance( + value, wave.WaveIndexMappingAttr + ), f"Unsupported index mapping attribute: {value}" + result[index_symbol(key)] = _convert_index_mapping_attr_to_sympy(value) + return result + + +def _make_piecewise_sequence( + *components: tuple[IndexSequence, sympy.Expr] +) -> IndexSequence: + return IndexSequence( + start=sympy.Piecewise( + *[(component[0].start, component[1]) for component in components] + ), + size=sympy.Piecewise( + *[(component[0].size, component[1]) for component in components] + ), + stride=sympy.Piecewise( + *[(component[0].stride, component[1]) for component in components] + ), + ) + + +def _convert_index_mapping_array_to_sympy( + op: ir.Operation, array_attr: ir.ArrayAttr +) -> dict[IndexSymbol, IndexSequence]: + # TODO: for some reason, isinstance(op.opview, MmaOp) is not working. Something is off with dialect loading/registration. 
+ if op.name != "wave.mma": + assert ( + len(array_attr) == 1 + ), f"Expected exactly one index mapping attribute for non-MMA op: {op}" + return _convert_index_mapping_dict_to_sympy(array_attr[0]) + + assert ( + len(array_attr) == 4 + ), f"Expected exactly four index mapping attributes for MMA op: {op}" + lhs_index = _convert_index_mapping_dict_to_sympy(array_attr[0]) + rhs_index = _convert_index_mapping_dict_to_sympy(array_attr[1]) + acc_index = _convert_index_mapping_dict_to_sympy(array_attr[2]) + result_index = _convert_index_mapping_dict_to_sympy(array_attr[3]) + mk_symbols = set(lhs_index.keys()) + nk_symbols = set(rhs_index.keys()) + m_symbol = (mk_symbols - nk_symbols).pop() + n_symbol = (nk_symbols - mk_symbols).pop() + k_symbol = (mk_symbols.intersection(nk_symbols)).pop() + assert lhs_index[k_symbol] == rhs_index[k_symbol] + assert rhs_index[n_symbol] == acc_index[n_symbol] + assert acc_index[m_symbol] == result_index[m_symbol] + assert acc_index[n_symbol] == result_index[n_symbol] + return { + m_symbol: _make_piecewise_sequence( + (lhs_index[m_symbol], ~index_symbol("$MMA_ACC")), + (acc_index[m_symbol], index_symbol("$MMA_ACC")), + ), + n_symbol: rhs_index[n_symbol], + k_symbol: lhs_index[k_symbol], + } + + +# convert index expression attribute to the sympy equivalent + + def _preprocess_symbols( symbols: Sequence[sympy.Symbol], ) -> dict[sympy.Symbol, sympy.Symbol]: """ Preprocess symbols by: (1) adding assumptions about all symbols being positive to later enable more simplifications. - (2) replacing `$ARG` prefix of argument symbols (e.g. `ARG0`) by `_ARG` for consistency. + (2) replacing `$ARG` prefix of argument symbols (e.g. `ARG0`) by `_Iter_` to match dialect expectations. """ result = {} for sym in symbols: - # Special case: rename ARG* symbols to _ARG* + # Special case: rename $ARG* symbols to _Iter_*. 
if sym.name.startswith("$ARG"): - new_name = sym.name.replace("$", "_") + new_name = sym.name.replace("$ARG", "_Iter_") result[sym] = sympy.Symbol(new_name, positive=True) else: result[sym] = sympy.Symbol(sym.name, positive=True) @@ -282,6 +443,8 @@ def _symbol_name_to_attribute(name: str) -> ir.Attribute: if name in INDEX_SYMBOL_MAP: return wave.WaveIndexSymbolAttr.get(INDEX_SYMBOL_MAP[name]) + if name.startswith("_Iter_"): + return wave.WaveIterSymbolAttr.get(name.replace("_Iter_", "")) else: return wave.WaveSymbolAttr.get(name) @@ -346,7 +509,9 @@ def _build_index_mapping_dict( return ir.DictAttr.get(index_mappings) -def _attach_attributes(node: CustomOp, op: ir.Operation): +def _attach_attributes( + node: CustomOp, op: ir.Operation, known_ids: set[str] | None = None +): from wave_lang.kernel.ops.wave_ops import Iterate, MMA, get_custom from wave_lang.kernel.wave.utils.symbol_utils import get_induction_symbol @@ -403,6 +568,13 @@ def _attach_attributes(node: CustomOp, op: ir.Operation): bounds[dim.name] = wave.WaveExprListAttr.get(symbol_attrs, result) op.attributes["bounds"] = wave.WaveReadWriteBoundsAttr.get(bounds) + if water_id := getattr(node.fx_node, "_water_id", None): + op.attributes[_INTERNAL_WATER_ID_ATTR_NAME] = ir.StringAttr.get(water_id) + if known_ids is not None: + known_ids.add(water_id) + elif known_ids is not None: + raise RuntimeError(f"Water id requested but not specified for node {node}.") + def _convert_to_wave_expr_list_tuple( exprs: Sequence[sympy.Expr | int], @@ -443,6 +615,7 @@ def _emit_ops_from_graph( trace: CapturedTrace, value_map: dict[fx.Node | fx.Proxy, ir.Value], ctx: ir.Context, + known_ids: set[str] | None = None, ): # Import wave types locally to avoid clashing with iree bindings from wave_lang.kernel.ops.wave_ops import ( @@ -485,6 +658,29 @@ def _emit_ops_from_graph( f"GetResult index is higher than number of results of corresponding iterate node ({node.res_idx} vs {len(iterate_op.results)})" ) value_map[fx_node] = 
iterate_op.results[node.res_idx] + + # Attach IDs of `get_result` to the loop instead so we can recover them + # later because `get_result` doesn't exist in the dialect. + if known_ids is not None: + water_id = getattr(fx_node, "_water_id", None) + if water_id is None: + raise RuntimeError( + f"Water id requested for 'get_result' but not specified: {node}" + ) + known_ids.add(water_id) + current_attribute = ( + iterate_op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] + if _INTERNAL_RESULT_WATER_IRS_ATTR_NAME in iterate_op.attributes + else ir.ArrayAttr.get( + [ir.UnitAttr.get()] * len(iterate_op.results) + ) + ) + attribute_list = list(current_attribute) + attribute_list[node.res_idx] = ir.StringAttr.get(water_id) + iterate_op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] = ( + ir.ArrayAttr.get(attribute_list) + ) + # additional handling for this op is not needed, skip rest continue if isinstance(node, SharedMemoryBarrier): @@ -524,6 +720,8 @@ def _emit_ops_from_graph( result_types = [] result_locs = [] outputs = node.outputs() + if not isinstance(outputs, Sequence): + outputs = [outputs] for fx_output in outputs: output = get_custom(fx_output) output.infer_type() @@ -553,16 +751,20 @@ def _emit_ops_from_graph( trace, value_map, ctx, + known_ids, ) # create YieldOp YieldOp([value_map[output] for output in outputs]) elif isinstance(node, MMA): + # TODO: FIXME: need to call the propagation pass upfront if node.mma_type is None: - raise RuntimeError("MMA op missing mma_type") - mma_kind = ir.Attribute.parse( - f"#wave.mma_kind<{node.mma_type.name.lower()}>", context=ctx - ) + mma_kind = WaveMmaKindAttr.get(WaveMmaKind.F32_16x16x16_F16) + # raise RuntimeError("MMA op missing mma_type") + else: + mma_kind = ir.Attribute.parse( + f"#wave.mma_kind<{node.mma_type.name.lower()}>", context=ctx + ) mlir_op = op_builder(result_type, *mlir_operands, mma_kind) elif isinstance(node, Allocate): mlir_op = op_builder( @@ -591,7 +793,7 @@ def _emit_ops_from_graph( f"Missing 
support for '{node.tkw_op_name}' operation" ) - _attach_attributes(node, mlir_op.operation) + _attach_attributes(node, mlir_op.operation, known_ids) # Add results to the value map in case they are used as # operands later @@ -663,11 +865,18 @@ def _emit_wave_constraints(constraint: Constraint) -> ir.Attribute: raise NotImplementedError(f"Unsupported constraint type: {type(constraint)}") -def _flush_output(module_str: str, diagnostics: list[str]) -> None: +def _flush_output( + module_str: str, + diagnostics: list[str], + inferred_attributes: dict[str, dict[str, Any]] | None = None, +) -> None: output = dill.dumps( { "diagnostics": [d.encode("utf-8") for d in diagnostics], "module": module_str.encode("utf-8"), + "inferred_attributes": ( + inferred_attributes if inferred_attributes is not None else {} + ), } ) sys.stdout.buffer.write(output) @@ -680,7 +889,7 @@ def _create_kernel_module( constraints: list[Constraint], options: WaveCompileOptions, test_diagnostics: bool = False, -) -> tuple[ir.Module | None, list[str]]: +) -> tuple[ir.Module | None, list[str], set[str]]: """Creates an MLIR module containing the kernel function from the captured trace. Args: @@ -688,15 +897,18 @@ def _create_kernel_module( trace: Captured Wave trace to convert. constraints: List of Wave constraints to attach to the function. options: Compilation options including hyperparameters. + node_backmap: Map from hash of node to node. test_diagnostics: Whether to emit a test diagnostic Returns: - The created MLIR module, or None if creation failed. - List of diagnostic messages. + - Set of known water IDs if options require checking water analysis. 
""" from wave_lang.kernel.ops.wave_ops import get_custom, IterArg # type: ignore diagnostics: list[str] = [] + known_ids: set[str] | None = set() if options.check_water_analysis else None def diagnostics_handler(d): diagnostics.append(f"{d.location}: {d.message}") @@ -709,9 +921,9 @@ def diagnostics_handler(d): module = ir.Module.parse(options.override_mlir, context=ctx) except ir.MLIRError as e: diagnostics.append(str(e)) - return None, diagnostics + return None, diagnostics, known_ids else: - return module, diagnostics + return module, diagnostics, known_ids # Keep track of which emitted value stems from what node to wire # arguments correctly. @@ -783,17 +995,23 @@ def diagnostics_handler(d): ] with ir.InsertionPoint(entry_block): - _emit_ops_from_graph(trace.get_root_graph(), trace, value_map, ctx) + _emit_ops_from_graph( + trace.get_root_graph(), trace, value_map, ctx, known_ids + ) func.ReturnOp(operands_=[]) - return module, diagnostics + return module, diagnostics, known_ids + + +_INTERNAL_WATER_ID_ATTR_NAME = "_water_internal.id" +_INTERNAL_RESULT_WATER_IRS_ATTR_NAME = "_water_internal.result_ids" def _emit_from_captured_trace( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions, - pipeline: str, + pipeline: str = "", test_diagnostics=False, ) -> int: @@ -806,20 +1024,28 @@ def _emit_from_captured_trace( if enable_debug_info and not trace.location: diagnostics.append("Missing debug location for wave trace") - with ir.Context() as ctx, ( - trace.location.to_water() if trace.location else ir.Location.unknown() + with ( + ir.Context() as ctx, + trace.location.to_water() if trace.location else ir.Location.unknown(), ): ctx.allow_unregistered_dialects = False wave.register_dialect(ctx) + wave.register_passes() - module, creation_diagnostics = _create_kernel_module( + module, creation_diagnostics, known_ids = _create_kernel_module( ctx, trace, constraints, options, test_diagnostics ) diagnostics.extend(creation_diagnostics) if module 
is None: - _flush_output("", diagnostics) + _flush_output("", diagnostics, None) return 0 + # TODO: this should be a pass to detect the normal form... + if pipeline: + module.operation.attributes["wave.normal_form"] = WaveNormalFormAttr.get( + WaveNormalForm.AllTypesSpecified + ) + # Verify the module before transforming or printing. try: module.operation.verify() @@ -832,6 +1058,7 @@ def _emit_from_captured_trace( enable_debug_info=enable_debug_info, print_generic_op_form=True ), diagnostics, + None, ) return 0 @@ -857,8 +1084,69 @@ def _emit_from_captured_trace( except Exception as e: diagnostics.append(f"Failed to apply transform script: {e}") + print(module.operation, file=sys.stderr) + module_str = module.operation.get_asm(enable_debug_info=enable_debug_info) + print(module_str, file=sys.stderr) + + # TODO: this special-cases index attributes + inferred_attributes: dict[str, dict[str, Any]] = {id: {} for id in known_ids} + if options.check_water_analysis: + + def extractor(op: ir.Operation) -> ir.WalkResult: + attribute: ir.Attribute | None = ( + op.attributes[_INTERNAL_WATER_ID_ATTR_NAME] + if _INTERNAL_WATER_ID_ATTR_NAME in op.attributes + else None + ) + result_attribute: ir.Attribute | None = ( + op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] + if _INTERNAL_RESULT_WATER_IRS_ATTR_NAME in op.attributes + else None + ) + if attribute is None and result_attribute is None: + return ir.WalkResult.ADVANCE + + def record_index( + attribute: ir.Attribute, + inferred_attributes: dict[str, dict[str, Any]], + ): + assert isinstance( + attribute, ir.StringAttr + ), f"Unexpected attribute type: {attribute}." + assert ( + attribute.value in inferred_attributes + ), f"Unknown water id {attribute.value}." + assert ( + "index" not in inferred_attributes[attribute.value] + ), f"Index already set for water id {attribute.value}." + assert "index" in op.attributes, f"Index not inferred for {op}." 
+ + inferred_attributes[attribute.value].update( + { + "index": _convert_index_mapping_array_to_sympy( + op, op.attributes["index"] + ) + } + ) + + if attribute is not None: + record_index(attribute, inferred_attributes) + if result_attribute is not None: + assert isinstance( + result_attribute, ir.ArrayAttr + ), f"Unexpected attribute type: {result_attribute}." + for attribute in result_attribute: + record_index(attribute, inferred_attributes) + + return ir.WalkResult.ADVANCE + + module.operation.walk(extractor) + for water_id, inferred_attribute in inferred_attributes.items(): + if "index" not in inferred_attribute: + raise RuntimeError(f"Index not inferred for water id {water_id}.") + module_str = module.operation.get_asm(enable_debug_info=enable_debug_info) - _flush_output(module_str, diagnostics) + _flush_output(module_str, diagnostics, inferred_attributes) return 0 @@ -873,9 +1161,9 @@ def _emit_from_captured_trace( args = parser.parse_args() - trace, constraints, options, pipeline = _parse_input() + trace, constraints, options, pass_pipeline = _parse_input() sys.exit( _emit_from_captured_trace( - trace, constraints, options, pipeline, args.test_diagnostic_emission + trace, constraints, options, pass_pipeline, args.test_diagnostic_emission ) ) diff --git a/wave_lang/kernel/wave/templates/gemm.py b/wave_lang/kernel/wave/templates/gemm.py index 591ba4f97..bebab5e7a 100644 --- a/wave_lang/kernel/wave/templates/gemm.py +++ b/wave_lang/kernel/wave/templates/gemm.py @@ -55,8 +55,8 @@ def get_gemm_kernel( constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.TilingConstraint(K, BLOCK_K)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / waves_per_block[0])] - constraints += [tkw.WaveConstraint(N, BLOCK_N / waves_per_block[1])] + constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / waves_per_block[0]))] + constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N 
/ waves_per_block[1]))] constraints += [ tkw.HardwareConstraint(threads_per_wave=threads_per_wave, mma_type=mfma_variant) diff --git a/wave_lang/kernel/wave/wave.py b/wave_lang/kernel/wave/wave.py index 4f477fd2c..a38f9dc2e 100644 --- a/wave_lang/kernel/wave/wave.py +++ b/wave_lang/kernel/wave/wave.py @@ -52,6 +52,7 @@ # Passes from .analysis.index_sequence_analysis import ( + set_node_indices_water_checked, set_node_indices, set_post_expansion_indices, ) @@ -764,41 +765,55 @@ def finalize_indices(): def substitute_vector_shapes(): self.hardware_constraints[0].subs_vector_shapes(idxc.subs) - return [ - partial(debug_log_hoist, trace, debug_handlers), - partial(initialize_iter_args, trace), - partial(self.create_induction_vars, trace), - partial(self.initialize_reductions, trace), - finalize_indices, - substitute_vector_shapes, - partial(add_get_results, trace), - partial(infer_types, trace, self.constraints), - partial(construct_index_mapping, trace, self.constraints), - partial( - debug_log_write_replace, - trace, - self.constraints, - options, - debug_arg_info, - ), - partial( - promote_placeholders, - trace, - self.constraints, - options.reorder_allocs, - ), - partial( - set_node_indices, - trace, - self.constraints, - print_ir_before, - print_ir_after, - ), - partial(reorder_workgroups, trace, self.reordering_constraints), - partial(expand_graph, trace, self.constraints), - partial(set_post_expansion_indices, trace, self.constraints), - partial(remove_chained_getresult, trace), - ] + return ( + [ + partial(debug_log_hoist, trace, debug_handlers), + partial(initialize_iter_args, trace), + partial(self.create_induction_vars, trace), + partial(self.initialize_reductions, trace), + finalize_indices, + substitute_vector_shapes, + partial(add_get_results, trace), + partial(infer_types, trace, self.constraints), + partial(construct_index_mapping, trace, self.constraints), + partial( + debug_log_write_replace, + trace, + self.constraints, + options, + debug_arg_info, + 
), + partial( + promote_placeholders, + trace, + self.constraints, + options.reorder_allocs, + ), + ] + + ( + [ + partial( + set_node_indices_water_checked, trace, self.constraints, options + ) + ] + if options.check_water_analysis + else [ + partial( + set_node_indices, + trace, + self.constraints, + print_ir_before, + print_ir_after, + ) + ] + ) + + [ + partial(reorder_workgroups, trace, self.reordering_constraints), + partial(expand_graph, trace, self.constraints), + partial(set_post_expansion_indices, trace, self.constraints), + partial(remove_chained_getresult, trace), + ] + ) def _trace_and_get_kernel_signature( self,