-
Notifications
You must be signed in to change notification settings - Fork 51
[WIP] Refactor GemmFeatures #2185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
a446365
b04eaa6
186656b
7fd7ad9
9f63716
500ce7e
9074575
9da1d69
48c5f96
f8b565d
a41ad62
e4867bd
9a38276
e5caab4
a1a2512
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,9 @@ | |
| #define MLIR_DIALECT_ROCK_IR_AMDARCHDB_H | ||
|
|
||
| #include "mlir/Dialect/Rock/IR/Rock.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/Support/LLVM.h" | ||
| #include "llvm/ADT/ArrayRef.h" | ||
|
|
||
| namespace mlir { | ||
| namespace rock { | ||
|
|
@@ -52,9 +54,44 @@ struct AmdArchInfo { | |
| /// Get the default features for the pair <arch, datatype> | ||
| GemmFeatures getDefaultFeatures(Type dataType); | ||
|
|
||
| /// Get the default features for multiple types (intersects features) | ||
| GemmFeatures getDefaultFeatures(ArrayRef<Type> types); | ||
|
|
||
| /// Get the maximum LDS vector length for the given architecture and element | ||
| /// bit width | ||
| int64_t getMaxLDSVectorLength(int64_t elementBitWidth); | ||
|
|
||
| // Feature check methods | ||
|
|
||
| /// Check if accelerator (mfma/wmma) is supported for given types | ||
| bool isAccel(Type dataTypeA, Type dataTypeB); | ||
|
|
||
| /// Check if mfma is supported for given types | ||
| bool isMfma(Type dataTypeA, Type dataTypeB); | ||
|
|
||
| /// Check if wmma is supported for given types | ||
| bool isWmma(Type dataTypeA, Type dataTypeB); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we do isAccel() instead of isMfma/isWmma? |
||
|
|
||
| /// Check if direct-to-LDS is supported for given type and numBytes | ||
| bool isDirectToLDS(Type dataType, int64_t numBytes = 0); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't the dataType enough? Why do we need numBytes? |
||
|
|
||
| /// Check if async direct-to-LDS is supported (needs arch string + type) | ||
| bool isAsyncDirectToLDS(StringRef arch, Type dataType, int64_t numBytes); | ||
|
|
||
| /// Check if dot product is supported (arch-only, no type dependency) | ||
| bool hasDot() const; | ||
|
|
||
| /// Check if atomic add is supported for given type | ||
| bool hasAtomicAdd(Type dataType); | ||
|
|
||
| /// Check if f16 atomic add is supported (arch-only) | ||
| bool hasAtomicAddF16() const; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: hasAtomicAdd(Type) instead of three similar functions. |
||
|
|
||
| /// Check if bf16 atomic add is supported (arch-only) | ||
| bool hasAtomicAddBF16() const; | ||
|
|
||
| /// Check if f32 atomic fmax is supported (arch-only) | ||
| bool hasAtomicFmaxF32() const; | ||
| }; | ||
|
|
||
| AmdArchInfo lookupArchInfo(StringRef arch); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -271,7 +271,8 @@ LogicalResult ConvGenerator::getBwdWeightKernelCount(OpBuilder &builder, | |
| assert(config.operation.value() == ConvOpType::BwdWeight); | ||
|
|
||
| kernelCount = 1; | ||
| if (isAccel(config.features)) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we remove isAccel()? |
||
| bool isAccel = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma); | ||
| if (isAccel) { | ||
| bool needExtraPad = false; | ||
| if (failed(needExtraPadBwdWeight(builder, needExtraPad))) { | ||
| return failure(); | ||
|
|
@@ -365,7 +366,8 @@ LogicalResult ConvGenerator::needExtraPadBwdWeight(OpBuilder &builder, | |
| /*batchSize=*/convDims.n, | ||
| /*numCU=*/getNumCU()}; | ||
|
|
||
| if (isAccel(config.features)) { | ||
| bool isAccel2 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma); | ||
| if (isAccel2) { | ||
| auto populateParamsAccelPtr = PopulateParamsAccel::select(config.features); | ||
| InitParamsAccel validParams; | ||
| auto res = populateParamsAccelPtr->obtainTuningParameters( | ||
|
|
@@ -403,7 +405,8 @@ LogicalResult ConvGenerator::hasWorkspace(OpBuilder &builder, | |
| if (config.operation.has_value()) { | ||
| Type dataType = getInputDataType(builder); | ||
| ConvOpType dir = config.operation.value(); | ||
| if ((dir == ConvOpType::BwdWeight) && isAccel(config.features) && | ||
| bool isAccel3 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma); | ||
| if ((dir == ConvOpType::BwdWeight) && isAccel3 && | ||
| (dataType == builder.getF16Type())) { | ||
| // In case we need extra padding, do not use workspace. | ||
| bool needPadding = false; | ||
|
|
@@ -982,7 +985,8 @@ LogicalResult ConvGenerator::genConvModule(ModuleOp &module, int kernelId, | |
|
|
||
| bool needsZeroInit = false; | ||
| bool needExtraPad = false; | ||
| if (rock::isAccel(config.features) && | ||
| bool isAccel4 = bitEnumContainsAny(config.features, GemmFeatures::wmma | GemmFeatures::mfma); | ||
| if (isAccel4 && | ||
| succeeded(needExtraPadBwdWeight(builder, needExtraPad))) { | ||
| if (!needExtraPad) { | ||
| auto dataType = getInputDataType(builder); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the idea was to get rid of `GemmFeatures`? At least that's what I think the end goal should be. Isn't it possible to remove `GemmFeatures` completely in this PR?