@@ -2307,24 +2307,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
23072307
23082308// -----
23092309
2310- func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
2311- %int0 = torch.constant.int 0
2312- %int1 = torch.constant.int 1
2313- %int3 = torch.constant.int 3
2314- %false = torch.constant.bool false
2315- %count_include_pad = torch.constant.bool true
2316- %divisor_override = torch.constant.none
2317-
2318- %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
2319- %1 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2320- %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2321- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
2322- %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
2323- return %3 : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
2324- }
2325-
2326- // -----
2327-
23282310func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
23292311 %int0 = torch.constant.int 0
23302312 %int1 = torch.constant.int 1
@@ -2844,21 +2826,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
28442826
28452827// -----
28462828
2847- func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value (%arg0: !torch.vtensor <[1 ,512 ,10 ],f32 >) -> !torch.vtensor <[1 ,512 ,10 ],f32 > {
2848- %int1 = torch.constant.int 1
2849- %int3 = torch.constant.int 3
2850- %false = torch.constant.bool false
2851- %count_include_pad = torch.constant.bool true
2852- %0 = torch.prim.ListConstruct %int3 : (!torch.int ) -> !torch.list <int >
2853- %1 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
2854- %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
2855- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
2856- %3 = torch.aten.avg_pool1d %arg0 , %0 , %1 , %2 , %false , %count_include_pad : !torch.vtensor <[1 ,512 ,10 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool -> !torch.vtensor <[1 ,512 ,10 ],f32 >
2857- return %3 : !torch.vtensor <[1 ,512 ,10 ],f32 >
2858- }
2859-
2860- // -----
2861-
28622829// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
28632830// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
28642831// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4384,3 +4351,82 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
43844351 %out = torch.aten.empty.memory_format %2452 , %none , %none , %cpu , %false , %none : !torch.list <int >, !torch.none , !torch.none , !torch.Device , !torch.bool , !torch.none -> !torch.vtensor <[1 ,0 ,256 ],f32 >
43854352 return %out : !torch.vtensor <[1 ,0 ,256 ],f32 >
43864353}
4354+
4355+ // -----
4356+ // CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
4357+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
4358+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
4359+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
4360+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
4361+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
4362+ // CHECK: %[[VAL_5:.*]] = torch.constant.bool false
4363+ // CHECK: %[[VAL_6:.*]] = torch.constant.bool true
4364+ // CHECK: %[[VAL_7:.*]] = torch.constant.none
4365+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
4366+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
4367+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
4368+ // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
4369+ // CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4370+ // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4371+ // CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
4372+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4373+ // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4374+ // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
4375+ // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
4376+ // CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
4377+ // CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
4378+ // CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
4379+ // CHECK: }
4380+ func.func @torch.aten.avg_pool2d.count_include_pad (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
4381+ %int0 = torch.constant.int 0
4382+ %int1 = torch.constant.int 1
4383+ %int3 = torch.constant.int 3
4384+ %false = torch.constant.bool false
4385+ %count_include_pad = torch.constant.bool true
4386+ %divisor_override = torch.constant.none
4387+
4388+ %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
4389+ %1 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
4390+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
4391+ %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
4392+ return %3 : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
4393+ }
4394+
4395+ // -----
4396+ // CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
4397+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
4398+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
4399+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
4400+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 3
4401+ // CHECK: %[[VAL_4:.*]] = torch.constant.bool false
4402+ // CHECK: %[[VAL_5:.*]] = torch.constant.bool true
4403+ // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
4404+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
4405+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
4406+ // CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
4407+ // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
4408+ // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
4409+ // CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4410+ // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4411+ // CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
4412+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4413+ // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4414+ // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
4415+ // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
4416+ // CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
4417+ // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
4418+ // CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
4419+ // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
4420+ // CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
4421+ // CHECK: }
4422+ func.func @torch.aten.avg_pool1d.count_include_pad (%arg0: !torch.vtensor <[1 ,512 ,10 ],f32 >) -> !torch.vtensor <[1 ,512 ,10 ],f32 > {
4423+ %int1 = torch.constant.int 1
4424+ %int3 = torch.constant.int 3
4425+ %false = torch.constant.bool false
4426+ %count_include_pad = torch.constant.bool true
4427+ %0 = torch.prim.ListConstruct %int3 : (!torch.int ) -> !torch.list <int >
4428+ %1 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
4429+ %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
4430+ %3 = torch.aten.avg_pool1d %arg0 , %0 , %1 , %2 , %false , %count_include_pad : !torch.vtensor <[1 ,512 ,10 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool -> !torch.vtensor <[1 ,512 ,10 ],f32 >
4431+ return %3 : !torch.vtensor <[1 ,512 ,10 ],f32 >
4432+ }
0 commit comments