Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions tests/kernel/wave/infer_index_exprs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.wave.wave import LaunchableWave
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile

from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.utils.general_utils import torch_dtype_to_wave

import torch


# TODO: use the generic template, currently blocked by water not handling wave constraints.
def _get_gemm_kernel(
shape: tuple[int, int, int],
mfma_variant: MMAType,
dtype: torch.dtype = torch.float16,
block_shape: tuple[int, int, int] | None = None,
waves_per_block: tuple[int, int] | None = None,
) -> tuple[LaunchableWave, dict[tkl.IndexSymbol, tkl.IndexExpr]]:
if not block_shape:
# BLOCK_M, BLOCK_N, BLOCK_K
block_shape = (64, 64, 32)

if not waves_per_block:
# WAVE_M, WAVE_N
waves_per_block = (2, 2)

assert len(block_shape) == 3, "block_shape needs to be rank 3 for M, N, K."
assert len(waves_per_block) == 2, "waves_per_block needs to be rank 2 for M, N."

# Input sizes
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE
dtype = torch_dtype_to_wave(dtype)
# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]

# TODO: dialect expects waves_per_block to be rank 3, so we append a 1 to the end.
constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
mma_type=mfma_variant,
waves_per_block=waves_per_block + (1,),
)
]

# Wave-level micro-kernel.
# Since warps are not directly addressable, there is no
# explicit notion of a warp id (like a workgroup or thread id).
# This kernel uses the input sizes M, N, K throughout, as the tiling
# and data movement strategy is determined during the compilation process.
# These can be influenced by introducing constraints.
@tkw.wave(constraints)
def gemm(
a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype],
b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

# This microkernel encodes the fact that if the iterate
# dimension were tiled, then we would need to materialize a loop.
@tkw.iterate(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# a_reg: tkw.Register[M, K, dtype]
a_reg = tkw.read(a)
# b_reg: tkw.Register[N, K, dtype]
b_reg = tkw.read(b)
# acc: tkw.Register[M, N, tkl.f32]
acc = tkw.mma(a_reg, b_reg, acc)
return acc

# repeat represents the results of the loop
tkw.write(repeat, c)

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
BLOCK_M: block_shape[0],
BLOCK_N: block_shape[1],
BLOCK_K: block_shape[2],
M: shape[0],
N: shape[1],
K: shape[2],
}
return gemm, hyperparams


def testGemm():
gemm, hyperparams = _get_gemm_kernel(
shape=(1024, 1024, 1024), mfma_variant=MMAType.F32_16x16x16_F16
)
options = WaveCompileOptions(
subs=hyperparams,
run_bench=False,
check_water_analysis=True,
)
compiled_gemm = wave_compile(options, gemm)
assert compiled_gemm is not None


if __name__ == "__main__":
testGemm()
21 changes: 21 additions & 0 deletions water/include/water/Dialect/Wave/IR/IndexExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define WATER_DIALECT_WAVE_IR_INDEXEXPR_H

#include "mlir/IR/Attributes.h"
#include "water/Dialect/Wave/IR/WaveAttrs.h"
#include "llvm/ADT/SetVector.h"

namespace mlir {
Expand Down Expand Up @@ -41,6 +42,26 @@ aggregateAllSymbols(RangeT &&symbolLists,
mlir::AffineMap alignMapSymbols(mlir::AffineMap map,
llvm::ArrayRef<mlir::Attribute> symbols,
llvm::ArrayRef<mlir::Attribute> allSymbols);

// Create an index mapping induced by the given workgroup constraint. Combine
// it with the base index mapping if provided.
WaveIndexMappingAttr
applyConstraint(WorkgroupConstraintAttr constraint,
WaveIndexMappingAttr baseMapping = nullptr);

// Create an index mapping induced by the given tiling constraint. Combine it
// with the base index mapping if provided.
WaveIndexMappingAttr
applyConstraint(TilingConstraintAttr constraint,
WaveIndexMappingAttr baseMapping = nullptr);

// Create an index mapping induced by the given constraint. Combine it with the
// base index mapping if provided. Call `applyConstraint` if the specific kind
// of constraint is already known.
WaveIndexMappingAttr
applyConstraintGeneric(mlir::Attribute constraint,
WaveIndexMappingAttr baseMapping = nullptr);

} // namespace wave

#endif // WATER_DIALECT_WAVE_IR_INDEXEXPR_H
14 changes: 14 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def NORMAL_FORM_FULL_TYPES : I32BitEnumAttrCaseGroup<
NORMAL_FORM_FUNC_BOUNDARY, NORMAL_FORM_OP_TYPES
], "full_types">;

// When anything changes below, also update the C and Python API.
def WaveNormalFormEnum : I32BitEnumAttr<"WaveNormalForm", "", [
I32BitEnumAttrCaseNone<"None", "none">,
// Bits.
Expand Down Expand Up @@ -366,6 +367,19 @@ def WaveSymbolAttr : AttrDef<WaveDialect, "WaveSymbol"> {
}];
}

def WaveIterSymbolAttr : AttrDef<WaveDialect, "WaveIterSymbol"> {
let mnemonic = "iter";
let description = [{
Symbol referring to an induction variable that can be used in symbolic
expressions in the Wave dialect. Induction variables may have the same
name as other symbols without clashing with them.
}];

let parameters =
(ins StringRefParameter<"name of the induction variable">:$name);
let assemblyFormat = "`<` $name `>`";
}

def WaveIndexMappingAttr : AttrDef<WaveDialect, "WaveIndexMapping"> {
let mnemonic = "index_mapping";
let description = [{
Expand Down
7 changes: 5 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def MmaOp : WaveOp<"mma",
def IterateOp : Op<WaveDialect, "iterate", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["areTypesCompatible", "getEntrySuccessorOperands"]>]> {
["areTypesCompatible", "getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
let summary = "Executes the body repeatedly";
let description = [{
Intrinsically sequential iterative execution that is akin to a loop with
Expand Down Expand Up @@ -215,7 +216,9 @@ def YieldOp : Op<WaveDialect, "yield",
// Memory-related operations
//-----------------------------------------------------------------------------

def AllocateOp : WaveOp<"allocate"> {
def AllocateOp : WaveOp<"allocate", [
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>
]> {
let summary = "Represents an allocation in an address space";
let description = [{
Allocates memory for a Wave tensor in the address space indicated by the
Expand Down
45 changes: 45 additions & 0 deletions water/include/water/c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef WATER_C_DIALECTS_H
#define WATER_C_DIALECTS_H

#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"

#ifdef __cplusplus
Expand All @@ -15,6 +16,9 @@ extern "C" {

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Wave, wave);

/// Register the Wave dialect passes.
MLIR_CAPI_EXPORTED void mlirWaveDialectRegisterPasses();

//===---------------------------------------------------------------------===//
// Wave Dialect Constants
//===---------------------------------------------------------------------===//
Expand All @@ -36,6 +40,27 @@ mlirWaveSymbolAttrGet(MlirContext mlirCtx, MlirStringRef symbolName);
/// Returns the typeID of a WaveSymbolAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveSymbolAttrGetTypeID();

/// Gets the name of a WaveSymbolAttr.
MLIR_CAPI_EXPORTED MlirStringRef mlirWaveSymbolAttrGetName(MlirAttribute attr);

//===---------------------------------------------------------------------===//
// WaveIterSymbolAttr
//===---------------------------------------------------------------------===//

/// Checks whether the given MLIR attribute is a WaveIterSymbolAttr.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveIterSymbolAttr(MlirAttribute attr);

/// Creates a new WaveIterSymbolAttr with the given induction variable name.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveIterSymbolAttrGet(MlirContext mlirCtx, MlirStringRef symbolName);

/// Returns the typeID of a WaveIterSymbolAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIterSymbolAttrGetTypeID();

/// Gets the induction variable name.
MLIR_CAPI_EXPORTED MlirStringRef
mlirWaveIterSymbolAttrGetName(MlirAttribute attr);

//===---------------------------------------------------------------------===//
// WaveIndexSymbolAttr
//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -86,6 +111,26 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGet(
/// Returns the typeID of a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexMappingAttrGetTypeID();

/// Get the offset from a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirAffineMap
mlirWaveIndexMappingAttrGetStart(MlirAttribute attr);

/// Get the step from a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirAffineMap
mlirWaveIndexMappingAttrGetStep(MlirAttribute attr);

/// Get the stride from a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirAffineMap
mlirWaveIndexMappingAttrGetStride(MlirAttribute attr);

/// Get the number of (input) symbols.
MLIR_CAPI_EXPORTED intptr_t
mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr);

/// Get the (input) symbol at the given index.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr, intptr_t index);

//===---------------------------------------------------------------------===//
// WaveHyperparameterAttr
//===---------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions water/lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ add_mlir_public_c_api_library(WaterCAPI
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRWaveDialect
MLIRWaveTransforms
)
68 changes: 64 additions & 4 deletions water/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,17 @@

#include "water/Dialect/Wave/IR/WaveAttrs.h"
#include "water/Dialect/Wave/IR/WaveDialect.h"
#include "water/Dialect/Wave/Transforms/Passes.h"
#include "water/c/Dialects.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Wave, wave, ::wave::WaveDialect)

//===---------------------------------------------------------------------===//
// Wave Dialect Passes
//===---------------------------------------------------------------------===//

void mlirWaveDialectRegisterPasses() { wave::registerPasses(); }

//===---------------------------------------------------------------------===//
// Wave Dialect Constants
//===---------------------------------------------------------------------===//
Expand All @@ -42,6 +49,33 @@ MlirTypeID mlirWaveSymbolAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveSymbolAttr>());
}

MlirStringRef mlirWaveSymbolAttrGetName(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveSymbolAttr>(unwrap(attr)).getName());
}

//===---------------------------------------------------------------------===//
// WaveIterSymbolAttr
//===---------------------------------------------------------------------===//

bool mlirAttributeIsAWaveIterSymbolAttr(MlirAttribute attr) {
return llvm::isa<wave::WaveIterSymbolAttr>(unwrap(attr));
}

MlirAttribute mlirWaveIterSymbolAttrGet(MlirContext mlirCtx,
MlirStringRef symbolNameStrRef) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);
llvm::StringRef symbolName = unwrap(symbolNameStrRef);
return wrap(wave::WaveIterSymbolAttr::get(ctx, symbolName));
}

MlirTypeID mlirWaveIterSymbolAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveIterSymbolAttr>());
}

MlirStringRef mlirWaveIterSymbolAttrGetName(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIterSymbolAttr>(unwrap(attr)).getName());
}

//===---------------------------------------------------------------------===//
// WaveIndexSymbolAttr
//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -91,7 +125,8 @@ MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx,

assert(llvm::all_of(
symbolAttrs,
llvm::IsaPred<wave::WaveSymbolAttr, wave::WaveIndexSymbolAttr>) &&
llvm::IsaPred<wave::WaveSymbolAttr, wave::WaveIndexSymbolAttr,
wave::WaveIterSymbolAttr>) &&
"expected mapping to contain only WaveSymbolAttr or "
"WaveIndexSymbolAttr attributes");

Expand All @@ -103,6 +138,30 @@ MlirTypeID mlirWaveIndexMappingAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveIndexMappingAttr>());
}

MlirAffineMap mlirWaveIndexMappingAttrGetStart(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStart());
}

MlirAffineMap mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStep());
}

MlirAffineMap mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStride());
}

intptr_t mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr) {
return llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr))
.getSymbols()
.size();
}

MlirAttribute mlirWaveIndexMappingAttrGetSymbol(MlirAttribute attr,
intptr_t index) {
return wrap(
llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getSymbols()[index]);
}

//===---------------------------------------------------------------------===//
// WaveHyperparameterAttr
//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -215,9 +274,10 @@ MlirAttribute mlirWaveExprListAttrGet(MlirAttribute *symbolNames,

assert(llvm::all_of(
symbolAttrs,
llvm::IsaPred<wave::WaveSymbolAttr, wave::WaveIndexSymbolAttr>) &&
"expected mapping to contain only WaveSymbolAttr or "
"WaveIndexSymbolAttr attributes");
llvm::IsaPred<wave::WaveSymbolAttr, wave::WaveIndexSymbolAttr,
wave::WaveIterSymbolAttr>) &&
"expected mapping to contain only WaveSymbolAttr, "
"WaveIndexSymbolAttr or WaveIterSymbolAttr attributes");

return wrap(wave::WaveExprListAttr::get(ctx, symbolAttrs, unwrap(map)));
}
Expand Down
Loading