
Commit fe0e18c

[TOSA] Add legalization for avg_pool2d
Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked support for pooling with count_include_pad=True. This patch introduces that support.

Signed-off-by: Vitalii Shutov <[email protected]>
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667
1 parent 8d563af commit fe0e18c
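The new lowering relies on a simple equivalence: with zero padding, averaging with count_include_pad=True is the same as explicitly zero-padding the input and then pooling with padding=0, because the divisor is the full kernel size in both cases. A minimal PyTorch sketch of that equivalence (illustrative only; assumes torch is available, with the 1x192x35x35 shape borrowed from the new lit test):

import torch
import torch.nn.functional as F

# count_include_pad=True counts the implicit zeros in the divisor, so it matches
# pooling an explicitly zero-padded input with no implicit padding at all.
x = torch.randn(1, 192, 35, 35)

included = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                        count_include_pad=True)
explicit = F.avg_pool2d(F.pad(x, (1, 1, 1, 1)),  # (left, right, top, bottom) zeros
                        kernel_size=3, stride=1, padding=0)

assert torch.allclose(included, explicit)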

5 files changed: +180 −59 lines

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 7 additions & 0 deletions
@@ -107,6 +107,13 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                         Type inputElemTy, Type outputElemTy,
                                         ArrayRef<int64_t> weightShape);
 
+// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
+// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
+// {top, bottom, left, right}. Returns the padded tensor value.
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents);
+
 } // namespace tosa
 } // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 56 additions & 18 deletions
@@ -6123,7 +6123,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
-    DenseI64ArrayAttr &pad) {
+    DenseI64ArrayAttr &pad,
+    SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {
 
   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
@@ -6163,21 +6164,43 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, we will materialize an
+    // explicit pad after transposing to NHWC. Track the padding extents and
+    // zero out the TOSA op padding so the divisor matches the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
 
          countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+      if (!explicitNHWCPad)
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported `count_include_pad` value, for tosa AvgPool "
+                "`count_include_pad` value should be `False`.");
+
+      // Remember the spatial padding so we can emit an NHWC tosa.pad right
+      // after the transpose.
+      explicitNHWCPad->assign(
+          {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
+
+      auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+        if (ShapedType::isDynamic(dim))
+          return ShapedType::kDynamic;
+        return dim + before + after;
+      };
+
+      // Update the logical input type used for shape computations to include
+      // the extra zeros supplied by the explicit pad.
+      SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
+                                       inputTy.getShape().end());
+      // Height stored at rank-2, width at rank-1 for NCHW shapes.
+      paddedShape[inputRank - 2] =
+          addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
+      paddedShape[inputRank - 1] =
+          addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
+      inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 

@@ -6321,15 +6344,23 @@ class ConvertAtenAvgPool2dOp
    }
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
+            &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, self);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, self);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }
@@ -6372,16 +6403,23 @@ class ConvertAtenAvgPool1dOp
             .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
-            pad)))
+            pad, &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }
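Note that tosa.avg_pool2d counts only the in-bounds elements of each window in its divisor, which corresponds to PyTorch's count_include_pad=False; that is why the old code rejected count_include_pad=True, and why pre-padding with tosa.pad and zeroing the op's own pad recovers the full-kernel divisor. An illustrative PyTorch comparison of the two border divisors for a 3x3 kernel with padding 1 (assumes torch; the 4x4 input is arbitrary):

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Same window sums, different divisors at the border.
excluded = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)
included = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=True)

corner_sum = x[0, 0, :2, :2].sum()  # only 4 of the 9 kernel positions are in bounds
assert torch.isclose(excluded[0, 0, 0, 0], corner_sum / 4)  # divisor = valid count
assert torch.isclose(included[0, 0, 0, 0], corner_sum / 9)  # divisor = kernel size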

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 38 additions & 0 deletions
@@ -624,5 +624,43 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
 }
 }
 
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents) {
+  assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");
+
+  if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
+    return inputNHWC;
+
+  SmallVector<int64_t, 8> nhwcPadding = {
+      0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
+  Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);
+
+  auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
+  SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
+                                      inputTy.getShape().end());
+  auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+    if (ShapedType::isDynamic(dim))
+      return ShapedType::kDynamic;
+    return dim + before + after;
+  };
+  resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
+  resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);
+
+  auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());
+
+  Type elemTy = inputTy.getElementType();
+  Value padConst;
+  if (isa<mlir::FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  return tosa::PadOp::create(rewriter, loc, resultTy, inputNHWC, nhwcPadShape,
+                             padConst)
+      .getResult();
+}
+
 } // namespace tosa
 } // namespace mlir
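The flattened pad vector {0, 0, top, bottom, left, right, 0, 0} built above is one (before, after) pair per NHWC dimension, and the result shape grows only along H and W. A NumPy sketch of the same bookkeeping (illustrative; assumes numpy, with the 1x35x35x192 NHWC shape from the new 2-D test):

import numpy as np

top, bottom, left, right = 1, 1, 1, 1
x = np.zeros((1, 35, 35, 192), dtype=np.float32)  # NHWC input

# One (before, after) pair per dimension, mirroring [0, 0, top, bottom, left, right, 0, 0].
padded = np.pad(x,
                pad_width=[(0, 0), (top, bottom), (left, right), (0, 0)],
                mode="constant", constant_values=0.0)

assert padded.shape == (1, 37, 37, 192)  # matches tensor<1x37x37x192xf32> in the test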

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 8 deletions
@@ -3532,7 +3532,6 @@
     "AtenSymConstrainRangeForSize_basic",
     "AtenSymConstrainRange_basic",
     "Aten_AssertScalar_basic",
-    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
     "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
@@ -3655,21 +3654,14 @@
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
     "Aten_EmbeddingBagExample_basic",
-    "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
-    "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool1dNoPadCeilPadNotIncluded_basic",
     "AvgPool1dPadCeilPadNotIncluded_basic",
-    "AvgPool2dCeilPaddingStridedIncludePadding_basic",
-    "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
-    "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
     "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
     "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "AvgPool2dDivisorOverrideModule_basic",
-    "AvgPool2dFloatModule_basic",
     "AvgPool2dIntModule_basic",
-    "AvgPool2dStaticModule_basic",
     "BernoulliFloatModule_basic",
     "BernoulliPModule_basic",
     "BernoulliTensorModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 79 additions & 33 deletions
@@ -2307,24 +2307,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2844,21 +2826,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4384,3 +4351,82 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
   %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
   return %out : !torch.vtensor<[1,0,256],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false= torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
