Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/AmdArchDb.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define MLIR_DIALECT_ROCK_IR_AMDARCHDB_H

#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
namespace rock {
Expand Down Expand Up @@ -52,9 +54,44 @@ struct AmdArchInfo {
/// Get the default features for the pair <arch, datatype>
GemmFeatures getDefaultFeatures(Type dataType);

/// Get the default features for multiple types (intersects features)
GemmFeatures getDefaultFeatures(ArrayRef<Type> types);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the idea was to get rid of GemmFeatures? At least that's what I think the end goal should be. Isn't it possible to remove GemmFeatures completely in this PR?


/// Get the maximum LDS vector length for the given architecture and element
/// bit width
int64_t getMaxLDSVectorLength(int64_t elementBitWidth);

// Feature check methods

/// Check if accelerator (mfma/wmma) is supported for given types
bool isAccel(Type dataTypeA, Type dataTypeB);

/// Check if mfma is supported for given types
bool isMfma(Type dataTypeA, Type dataTypeB);

/// Check if wmma is supported for given types
bool isWmma(Type dataTypeA, Type dataTypeB);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do isAccel() instead of isMfma/isWmma?


/// Check if direct-to-LDS is supported for given type and numBytes
bool isDirectToLDS(Type dataType, int64_t numBytes = 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't the dataType enough? why do we need numBytes?


/// Check if async direct-to-LDS is supported (needs arch string + type)
bool isAsyncDirectToLDS(StringRef arch, Type dataType, int64_t numBytes);

/// Check if dot product is supported (arch-only, no type dependency)
bool hasDot() const;

/// Check if atomic add is supported for given type
bool hasAtomicAdd(Type dataType);

/// Check if f16 atomic add is supported (arch-only)
bool hasAtomicAddF16() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: hasAtomicAdd(Type) instead of three similar functions.


/// Check if bf16 atomic add is supported (arch-only)
bool hasAtomicAddBF16() const;

/// Check if f32 atomic fmax is supported (arch-only)
bool hasAtomicFmaxF32() const;
};

AmdArchInfo lookupArchInfo(StringRef arch);
Expand Down
22 changes: 11 additions & 11 deletions mlir/include/mlir/Dialect/Rock/IR/GetRockInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
#define MLIR_DIALECT_ROCK_IR_GETROCKINFO_H

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
Expand All @@ -29,12 +31,6 @@ namespace rock {
// This function returns the func or gpu.func of a given op
Operation *getParentFuncOp(Operation *op);

// Return a boolean if the features contain accel properties
bool isAccel(rock::GemmFeatures features);

// Get the arch from the op
FailureOr<StringAttr> getArch(Operation *op);

// Get the arch from the op and error out if it cannot be found
StringAttr getArchValue(Operation *op);

Expand All @@ -49,13 +45,17 @@ inline rock::GemmFeatures intersectGemmFeatures(rock::GemmFeatures a,
return a & b;
}

// Get the features enabled for the specified op. These will be dependent on
// the architecture being used, and the type of the op.
rock::GemmFeatures getFeatures(Operation *op);

// Check if a schedule version is supported by the hardware
LogicalResult isScheduleVersionSupported(int64_t scheduleVersion,
GemmFeatures features, StringRef arch);
AmdArchInfo archInfo,
ArrayRef<Type> types,
StringRef arch);

// Check if features contain accelerator (mfma or wmma)
// This is a helper function for code that still has GemmFeatures
inline bool isAccel(GemmFeatures features) {
return bitEnumContainsAny(features, GemmFeatures::wmma | GemmFeatures::mfma);
}

} // End namespace rock
} // End namespace mlir
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ struct Serializable {
s = std::string(s.begin() + matches[0].size(), s.end());
} else {
// unknown perf config version
llvm::errs() << "deserialize: unknown perf config version: " << value << "\n";
return false;
}
} else {
Expand All @@ -106,6 +107,7 @@ struct Serializable {

if (!checkVersionFormat(s)) {
// incorrect perf config format
llvm::errs() << "deserialize: incorrect perf config format: " << s << "\n";
return false;
}

Expand All @@ -117,8 +119,10 @@ struct Serializable {
std::bind(DeserializeField{}, std::ref(ok), std::ref(ss),
Seperator, std::placeholders::_1));

if (!ok)
if (!ok) {
llvm::errs() << "deserialize: failed to deserialize perf config: " << s << "\n";
return false;
}

static_cast<Derived &>(*this) = out;
return true;
Expand Down
88 changes: 43 additions & 45 deletions mlir/lib/Conversion/RockToGPU/RockToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,46 +150,41 @@ static void runWavesPerEUHeuristic(OpBuilder b, gpu::GPUFuncOp gpuFunc,
return;
}
int64_t gridSize = gpuFunc->getAttrOfType<IntegerAttr>("grid_size").getInt();
FailureOr<StringAttr> maybeArch = rock::getArch(gpuFunc);
if (succeeded(maybeArch)) {
StringAttr arch = maybeArch.value();
rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch);
FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc);
int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU);
int64_t totalEUs = archInfo.numEUPerCU * numCU;
int64_t wavesPerBlock = (blockSize / archInfo.waveSize);
int64_t totalWaves = wavesPerBlock * gridSize;
int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU;
int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs;
int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid);
LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n");
LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n");
LLVM_DEBUG(llvm::dbgs() << " numEUPerCU:" << archInfo.numEUPerCU << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n");
LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n");
// limit wavesPerEU based on lds usage
if (ldsUsage > 0) {
wavesPerEU =
std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage);
}
// Currently limiting wavesPerEU to be two
// it is a future to ticket to remove this constraint with further
// analysis
constexpr int64_t wavesPerEUUpperBound = 2;
wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound);
if (wavesPerEU > 1) {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n");
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
} else {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set"
<< "\n");
}
StringAttr arch = rock::getArchValue(gpuFunc);
rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch);
FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc);
int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU);
int64_t totalEUs = archInfo.numEUPerCU * numCU;
int64_t wavesPerBlock = (blockSize / archInfo.waveSize);
int64_t totalWaves = wavesPerBlock * gridSize;
int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU;
int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs;
int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid);
LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n");
LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n");
LLVM_DEBUG(llvm::dbgs() << " numEUPerCU:" << archInfo.numEUPerCU << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n");
LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n");
// limit wavesPerEU based on lds usage
if (ldsUsage > 0) {
wavesPerEU =
std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage);
}
// Currently limiting wavesPerEU to be two
// it is a future to ticket to remove this constraint with further
// analysis
constexpr int64_t wavesPerEUUpperBound = 2;
wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound);
if (wavesPerEU > 1) {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n");
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
} else {
LLVM_DEBUG(llvm::dbgs() << "arch not found.\n");
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set"
<< "\n");
}
}

Expand Down Expand Up @@ -219,14 +214,20 @@ void LowerRockOpsToGPUPass::runOnOperation() {
// Make sure that the function has the necessary attributes.
auto blockSizeAttr = theFunc->getAttr("block_size");
auto gridSizeAttr = theFunc->getAttr("grid_size");
auto archAttr = theFunc->getAttr("arch");
auto mhalArchAttr = theFunc->getAttr("mhal.arch");
if (!blockSizeAttr) {
return theFunc->emitError()
<< "kernel func op is missing the block_size attribute";
<< "kernel func op '" << theFunc.getName() << "' is missing the block_size attribute";
}
if (!gridSizeAttr) {
return theFunc->emitError()
<< "kernel func op is missing the grid_size attribute";
<< "kernel func op '" << theFunc.getName() << "' is missing the grid_size attribute";
}
// if (!archAttr && !mhalArchAttr) {
// return theFunc->emitError()
// << "kernel func op '" << theFunc.getName() << "' is missing both arch and mhal.arch attributes";
// }

// Set up the symbol table for the GPU ModuleOp.
SymbolTable gpuModuleSymbolTable(gpuMod);
Expand Down Expand Up @@ -266,10 +267,7 @@ void LowerRockOpsToGPUPass::runOnOperation() {
gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
}

FailureOr<StringAttr> maybeArch = rock::getArch(theFunc);
if (succeeded(maybeArch)) {
gpuFunc->setAttr("arch", maybeArch.value());
}
gpuFunc->setAttr("arch", rock::getArchValue(theFunc));
FailureOr<int64_t> maybeNumCU = rock::getNumCU(theFunc);
if (succeeded(maybeNumCU)) {
gpuFunc->setAttr("num_cu", b.getI64IntegerAttr(maybeNumCU.value()));
Expand Down
6 changes: 1 addition & 5 deletions mlir/lib/Conversion/TosaToRock/TosaToRock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ static Value expandTensor(PatternRewriter &rw, Operation *op, Value operand,

static rock::GemmFeatures getGemmFeaturesFromOp(Operation *op, Type inputType) {
// Start by getting the arch from the Tosa op
StringAttr arch = StringAttr::get(op->getContext(), "");
FailureOr<StringAttr> maybeArch = rock::getArch(op);
if (succeeded(maybeArch)) {
arch = maybeArch.value();
}
StringAttr arch = rock::getArchValue(op);

// Now we can lookup the default features from the arch
rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch);
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/Rock/Generator/ConvGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ LogicalResult ConvGenerator::getBwdWeightKernelCount(OpBuilder &builder,
assert(config.operation.value() == ConvOpType::BwdWeight);

kernelCount = 1;
if (isAccel(config.features)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we remove isAccel()?

bool isAccel = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma);
if (isAccel) {
bool needExtraPad = false;
if (failed(needExtraPadBwdWeight(builder, needExtraPad))) {
return failure();
Expand Down Expand Up @@ -365,7 +366,8 @@ LogicalResult ConvGenerator::needExtraPadBwdWeight(OpBuilder &builder,
/*batchSize=*/convDims.n,
/*numCU=*/getNumCU()};

if (isAccel(config.features)) {
bool isAccel2 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma);
if (isAccel2) {
auto populateParamsAccelPtr = PopulateParamsAccel::select(config.features);
InitParamsAccel validParams;
auto res = populateParamsAccelPtr->obtainTuningParameters(
Expand Down Expand Up @@ -403,7 +405,8 @@ LogicalResult ConvGenerator::hasWorkspace(OpBuilder &builder,
if (config.operation.has_value()) {
Type dataType = getInputDataType(builder);
ConvOpType dir = config.operation.value();
if ((dir == ConvOpType::BwdWeight) && isAccel(config.features) &&
bool isAccel3 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma);
if ((dir == ConvOpType::BwdWeight) && isAccel3 &&
(dataType == builder.getF16Type())) {
// In case we need extra padding, do not use workspace.
bool needPadding = false;
Expand Down Expand Up @@ -982,7 +985,8 @@ LogicalResult ConvGenerator::genConvModule(ModuleOp &module, int kernelId,

bool needsZeroInit = false;
bool needExtraPad = false;
if (rock::isAccel(config.features) &&
bool isAccel4 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma);
if (isAccel4 &&
succeeded(needExtraPadBwdWeight(builder, needExtraPad))) {
if (!needExtraPad) {
auto dataType = getInputDataType(builder);
Expand Down
Loading