-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][XeGPU][TransformOps] Add set_op_layout_attr op #166854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Tuomas Kärnä (tkarna) Changes: Adds the `transform.xegpu.set_op_layout_attr` transform op that attaches an `xegpu.layout` attribute to the target op. Also adds the `getLayoutAttrFromOperands` utility function. For reference, the rationale behind xegpu transform ops is outlined in this RFC document. Patch is 23.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166854.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Set xegpu.layout attribute of an op.";
+ let description = [{
+ Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
+ `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
+ target operand/result value is defined by the `index` argument. The layout
+ is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface : $target,
+ DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedAttr<UnitAttr, "false">:$result
+ );
+
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "int64_t":$index,
+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
+ "ArrayRef<OpFoldResult>":$mixedSgData,
+ "ArrayRef<OpFoldResult>":$mixedInstData,
+ CArg<"bool", "false">:$result
+ )>,
+ ];
+
+ let assemblyFormat = [{
+ $target (`result` $result^)? (`index` `=` $index^)?
+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+ (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgData(), getSgData(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticInstData(), getInstData(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+ transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr =
+ createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();
- SmallVector<int32_t> sgLayout;
- DiagnosedSilenceableFailure status =
- convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;
- SmallVector<int32_t> sgData;
- status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
- if (!status.succeeded())
- return status;
-
- SmallVector<int32_t> instData;
- status =
- convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
- if (!status.succeeded())
- return status;
- auto maybeInstData = instData.empty()
- ? std::nullopt
- : std::optional<ArrayRef<int32_t>>(instData);
-
// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attr in desc op's return type. Replaces old desc op.
- auto layoutAttr =
- createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
// Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget) {
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+ } else {
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+ """Specialization for SetOpLayoutAttrOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ sg_layout: MixedValues,
+ sg_data: MixedValues,
+ *,
+ inst_data: MixedValues = None,
+ index: Union[int, Attribute] = None,
+ result: Union[bool, Attribute] = None,
+ loc=None,
+ ip=None,
+ ):
+ inst_data = [] if inst_data is None else inst_data
+ (
+ dynamic_sg_layout,
+ static_sg_layout,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_layout)
+ (
+ dynamic_sg_data,
+ static_sg_data,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_data)
+ (
+ dynamic_inst_data,
+ static_inst_data,
+ _,
+ ) = _dispatch_dynamic_index_list(inst_data)
+ super().__init__(
+ _get_op_result_or_value(target),
+ dynamic_sg_layout,
+ dynamic_sg_data,
+ dynamic_inst_data,
+ static_sg_layout=static_sg_layout,
+ static_sg_data=static_sg_data,
+ static_inst_data=static_inst_data,
+ index=index,
+ result=result,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op results}}
+ transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op operands}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+ %2 = arith.extf %1 : vector<256x32...
[truncated]
|
Adds the `transform.xegpu.set_op_layout_attr` transform op that attaches an `xegpu.layout` attribute to the target op.

Also adds the `getLayoutAttrFromOperands` utility function.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.