diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 0bc93f711ad6..0f19dad4cb1f 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1089,6 +1089,11 @@ class ConvertAtenMultipleDimsReductionOp
       for (int64_t i = 0; i < inputRank; i++)
         reduceDims.push_back(i);
     }
+    // PyTorch treats an explicit empty list the same as "reduce all dims".
+    if (reduceDims.empty()) {
+      for (int64_t i = 0; i < inputRank; i++)
+        reduceDims.push_back(i);
+    }
 
     int64_t N = reduceDims.size();
     for (unsigned i = 0; i < N; i++) {
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 02d1390ed148..444a2bdd2508 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -782,13 +782,23 @@ std::optional<Value> convertReduceOpCommon(
     }
 
     // Optionally squeeze out the reduced axes.
     if (!keep_dims) {
+      auto squeezedType =
+          RankedTensorType::get(output_shape, reduce_element_type);
       auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
-          rewriter, op->getLoc(), output_type, val,
+          rewriter, op->getLoc(), squeezedType, val,
           tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape));
       val = reshape_op.getResult();
     }
   }
 
+  // Ensure the result element type matches the expected output type.
+  if (val.getType() != output_type) {
+    auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type);
+    if (!casted)
+      return std::nullopt;
+    val = casted.value();
+  }
+
   return val;
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 81071c6ab058..efbfaf259ac2 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3434,6 +3434,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this ''
+    "ReduceSumEmptyDimListInt8ToInt32Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3846,7 +3848,6 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
-    "MeanDimEmptyDimModule_basic",
     "MlGroupNormManualModule_basic",
     "MlGroupNormModule_basic",
     "MlLayerNormManualModule_basic",
@@ -3901,7 +3902,6 @@
     "ReduceL3NormKeepDimComplexModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceSumDimIntListEmptyDimModule_basic",
     "RollModule_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 0eb0545e7f11..2e4ba9c4ccfc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[], dtype=torch.int32)
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module())
+def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
+class ReduceSumEmptyDimListInt8Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module())
+def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
 class ReduceSumElementTypeBoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index d100fe9dcfde..543dc09a65b2 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -311,6 +311,29 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !
 
 // -----
 
+// CHECK-LABEL:   func.func @test_reduce_sum_empty_dims$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.none
+// CHECK:           %[[VAL_3:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK:           %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK:           %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
+// CHECK:           %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_7:.*]] = tosa.const_shape
+// CHECK:           %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
+// CHECK:           %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<f32> -> !torch.vtensor<[],f32>
+// CHECK:           return %[[VAL_9]] : !torch.vtensor<[],f32>
+// CHECK:         }
+func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %empty = torch.prim.ListConstruct : () -> !torch.list<int>
+  %0 = torch.aten.sum.dim_IntList %arg0, %empty, %false, %none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @test_linalg_vector_norm$basic(
 // CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
 // CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>
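For context, the eager-mode PyTorch behavior that the new legalization and e2e tests rely on can be checked directly: an explicit empty `dim` list reduces over every dimension, and an explicit `dtype` fixes the element type of the result. A minimal sketch of that check, assuming only stock PyTorch (the shape and value range mirror the added tests but are otherwise arbitrary):

import torch

# Shape and value range mirror the new e2e tests; otherwise arbitrary.
a = torch.randint(-16, 16, (3, 4, 5), dtype=torch.int8)

# An explicit empty dim list behaves like "reduce all dims": the result is rank 0.
full = torch.sum(a, dim=[])
assert full.shape == torch.Size([])
assert full.dtype == torch.int64  # integer inputs accumulate to int64 by default

# With an explicit dtype, the result element type follows the dtype argument,
# which is what the new element-type cast in convertReduceOpCommon preserves.
promoted = torch.sum(a, dim=[], dtype=torch.int32)
assert promoted.shape == torch.Size([])
assert promoted.dtype == torch.int32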