Skip to content

Commit 384e09c

Browse files
Merge pull request #2087 from j2kun:64x64-reduce
PiperOrigin-RevId: 794259650
2 parents c7bca25 + 2c4976a commit 384e09c

File tree

11 files changed

+167
-179
lines changed

11 files changed

+167
-179
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
func.func @roberts_cross(%img: tensor<256xi16> {secret.secret}) -> tensor<256xi16> {
2+
%c256 = arith.constant 256 : index
3+
%c16 = arith.constant 16 : index
4+
%c1 = arith.constant 1 : index
5+
%c0 = arith.constant 0 : index
6+
%c-1 = arith.constant -1 : index
7+
8+
// Each point p = img[x][y], where x is row and y is column, in the new image will equal:
9+
// (img[x-1][y-1] - img[x][y])^2 + (img[x-1][y] - img[x][y-1])^2
10+
%r = affine.for %x = 0 to 16 iter_args(%imgx = %img) -> tensor<256xi16> {
11+
%1 = affine.for %y = 0 to 16 iter_args(%imgy = %imgx) -> tensor<256xi16> {
12+
13+
// fetch img[x-1][y-1]
14+
%4 = arith.addi %x, %c-1 : index
15+
%5 = arith.muli %4, %c16 : index
16+
%6 = arith.addi %y, %c-1 : index
17+
%7 = arith.addi %5, %6 : index
18+
%8 = arith.remui %7, %c256 : index
19+
%9 = tensor.extract %img[%8] : tensor<256xi16>
20+
21+
// fetch img[x][y]
22+
%10 = arith.muli %x, %c16 : index
23+
%11 = arith.addi %10, %y : index
24+
%12 = arith.remui %11, %c256 : index
25+
%13 = tensor.extract %img[%12] : tensor<256xi16>
26+
27+
// subtract those two
28+
%14 = arith.subi %9, %13 : i16
29+
30+
// fetch img[x-1][y]
31+
%15 = arith.addi %x, %c-1 : index
32+
%16 = arith.muli %15, %c16 : index
33+
%18 = arith.addi %16, %y : index
34+
%19 = arith.remui %18, %c256 : index
35+
%20 = tensor.extract %img[%19] : tensor<256xi16>
36+
37+
// fetch img[x][y-1]
38+
%21 = arith.muli %x, %c16 : index
39+
%22 = arith.addi %y, %c-1 : index
40+
%23 = arith.addi %21, %22 : index
41+
%24 = arith.remui %23, %c256 : index
42+
%25 = tensor.extract %img[%24] : tensor<256xi16>
43+
44+
// subtract those two
45+
%26 = arith.subi %20, %25 : i16
46+
47+
// square each difference
48+
%27 = arith.muli %14, %14 : i16
49+
%28 = arith.muli %26, %26 : i16
50+
51+
// add the squares
52+
%29 = arith.addi %27, %28 : i16
53+
54+
// save to result[x][y]
55+
%30 = tensor.insert %29 into %imgy[%12] : tensor<256xi16>
56+
affine.yield %30: tensor<256xi16>
57+
}
58+
affine.yield %1 : tensor<256xi16>
59+
}
60+
return %r : tensor<256xi16>
61+
}

tests/Examples/common/roberts_cross_64x64.mlir

Lines changed: 0 additions & 61 deletions
This file was deleted.

tests/Examples/lattigo/bgv/roberts_cross/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ heir_lattigo_lib(
1313
"--mlir-to-bgv=ciphertext-degree=4096 plaintext-modulus=536903681",
1414
"--scheme-to-lattigo",
1515
],
16-
mlir_src = "@heir//tests/Examples/common:roberts_cross_64x64.mlir",
16+
mlir_src = "@heir//tests/Examples/common:roberts_cross_16x16.mlir",
1717
)
1818

1919
# For Google-internal reasons we must separate the go_test rules from the macro

tests/Examples/lattigo/bgv/roberts_cross/roberts_cross_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,34 @@ import (
77
func TestBinops(t *testing.T) {
88
evaluator, params, ecd, enc, dec := roberts_cross__configure()
99

10-
input := make([]int16, 4096)
11-
expected := make([]int16, 4096)
10+
input := make([]int16, 256)
11+
expected := make([]int16, 256)
1212

13-
for i := 0; i < 4096; i++ {
13+
for i := 0; i < 256; i++ {
1414
input[i] = int16(i)
1515
}
1616

17-
for row := 0; row < 64; row++ {
18-
for col := 0; col < 64; col++ {
19-
xY := (row*64 + col) % 4096
20-
xYm1 := (row*64 + col - 1) % 4096
21-
xm1Y := ((row-1)*64 + col) % 4096
22-
xm1Ym1 := ((row-1)*64 + col - 1) % 4096
17+
for row := 0; row < 16; row++ {
18+
for col := 0; col < 16; col++ {
19+
xY := (row*16 + col) % 256
20+
xYm1 := (row*16 + col - 1) % 256
21+
xm1Y := ((row-1)*16 + col) % 256
22+
xm1Ym1 := ((row-1)*16 + col - 1) % 256
2323

2424
if xYm1 < 0 {
25-
xYm1 += 4096
25+
xYm1 += 256
2626
}
2727
if xm1Y < 0 {
28-
xm1Y += 4096
28+
xm1Y += 256
2929
}
3030
if xm1Ym1 < 0 {
31-
xm1Ym1 += 4096
31+
xm1Ym1 += 256
3232
}
3333

3434
v1 := input[xm1Ym1] - input[xY]
3535
v2 := input[xm1Y] - input[xYm1]
3636
sum := v1*v1 + v2*v2
37-
expected[row*64+col] = sum
37+
expected[row*16+col] = sum
3838
}
3939
}
4040

@@ -44,7 +44,7 @@ func TestBinops(t *testing.T) {
4444

4545
result := roberts_cross__decrypt__result0(evaluator, params, ecd, dec, resultCt)
4646

47-
for i := 0; i < 4096; i++ {
47+
for i := 0; i < 256; i++ {
4848
if result[i] != expected[i] {
4949
t.Errorf("Decryption error at %d: %d != %d", i, result[i], expected[i])
5050
}

tests/Examples/openfhe/bgv/roberts_cross/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test")
55
package(default_applicable_licenses = ["@heir//:license"])
66

77
openfhe_end_to_end_test(
8-
name = "roberts_cross_64x64_test",
9-
generated_lib_header = "roberts_cross_64x64_lib.h",
8+
name = "roberts_cross_16x16_test",
9+
generated_lib_header = "roberts_cross_16x16_lib.h",
1010
heir_opt_flags = [
1111
"--annotate-module=backend=openfhe scheme=bgv",
1212
"--mlir-to-bgv=ciphertext-degree=4096 plaintext-modulus=536903681",
1313
"--scheme-to-openfhe",
1414
],
15-
mlir_src = "@heir//tests/Examples/common:roberts_cross_64x64.mlir",
15+
mlir_src = "@heir//tests/Examples/common:roberts_cross_16x16.mlir",
1616
tags = ["notap"],
1717
test_src = "roberts_cross_test.cpp",
1818
)

tests/Examples/openfhe/bgv/roberts_cross/roberts_cross_test.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "gtest/gtest.h" // from @googletest
66

77
// Generated headers (block clang-format from messing up order)
8-
#include "tests/Examples/openfhe/bgv/roberts_cross/roberts_cross_64x64_lib.h"
8+
#include "tests/Examples/openfhe/bgv/roberts_cross/roberts_cross_16x16_lib.h"
99

1010
using ::testing::ContainerEq;
1111

@@ -23,24 +23,24 @@ TEST(RobertsCrossTest, TestInput1) {
2323

2424
std::vector<int16_t> input;
2525
std::vector<int16_t> expected;
26-
input.reserve(4096);
27-
expected.reserve(4096);
26+
input.reserve(256);
27+
expected.reserve(256);
2828

29-
for (int i = 0; i < 4096; ++i) {
29+
for (int i = 0; i < 256; ++i) {
3030
input.push_back(i);
3131
}
3232

33-
for (int row = 0; row < 64; ++row) {
34-
for (int col = 0; col < 64; ++col) {
33+
for (int row = 0; row < 16; ++row) {
34+
for (int col = 0; col < 16; ++col) {
3535
// (img[x-1][y-1] - img[x][y])^2 + (img[x-1][y] - img[x][y-1])^2
36-
int xY = (row * 64 + col) % 4096;
37-
int xYm1 = (row * 64 + col - 1) % 4096;
38-
int xm1Y = ((row - 1) * 64 + col) % 4096;
39-
int xm1Ym1 = ((row - 1) * 64 + col - 1) % 4096;
40-
41-
if (xYm1 < 0) xYm1 += 4096;
42-
if (xm1Y < 0) xm1Y += 4096;
43-
if (xm1Ym1 < 0) xm1Ym1 += 4096;
36+
int xY = (row * 16 + col) % 256;
37+
int xYm1 = (row * 16 + col - 1) % 256;
38+
int xm1Y = ((row - 1) * 16 + col) % 256;
39+
int xm1Ym1 = ((row - 1) * 16 + col - 1) % 256;
40+
41+
if (xYm1 < 0) xYm1 += 256;
42+
if (xm1Y < 0) xm1Y += 256;
43+
if (xm1Ym1 < 0) xm1Ym1 += 256;
4444

4545
int16_t v1 = (input[xm1Ym1] - input[xY]);
4646
int16_t v2 = (input[xm1Y] - input[xYm1]);

tests/Examples/plaintext/roberts_cross/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ plaintext_test(
3333
heir_opt_flags = [
3434
"--mlir-to-plaintext-backend=plaintext-size=4096",
3535
],
36-
mlir_src = "@heir//tests/Examples/common:roberts_cross_64x64.mlir",
36+
mlir_src = "@heir//tests/Examples/common:roberts_cross_16x16.mlir",
3737
deps = [
3838
":roberts_cross_test",
3939
"@heir//tests/Examples/plaintext:memrefCopy",
@@ -46,7 +46,7 @@ plaintext_test(
4646
heir_opt_flags = [
4747
"--mlir-to-plaintext-backend=plaintext-size=4096 plaintext-modulus=536903681",
4848
],
49-
mlir_src = "@heir//tests/Examples/common:roberts_cross_64x64.mlir",
49+
mlir_src = "@heir//tests/Examples/common:roberts_cross_16x16.mlir",
5050
deps = [
5151
":roberts_cross_mod_test",
5252
"@heir//tests/Examples/plaintext:memrefCopy",
@@ -59,7 +59,7 @@ plaintext_test(
5959
heir_opt_flags = [
6060
"--mlir-to-plaintext-backend=plaintext-size=4096 plaintext-modulus=786433",
6161
],
62-
mlir_src = "@heir//tests/Examples/common:roberts_cross_64x64.mlir",
62+
mlir_src = "@heir//tests/Examples/common:roberts_cross_16x16.mlir",
6363
deps = [
6464
":roberts_cross_mod_failure_test",
6565
"@heir//tests/Examples/plaintext:memrefCopy",

tests/Transforms/heir_simd_vectorizer/BUILD

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,5 @@ glob_lit_tests(
99
name = "all_tests",
1010
data = ["@heir//tests:test_utilities"],
1111
driver = "@heir//tests:run_lit.sh",
12-
size_override = {
13-
"box_blur_64x64.mlir": "large",
14-
"roberts_cross_64x64.mlir": "enormous",
15-
"gx_kernel_64x64.mlir": "large",
16-
},
17-
tags_override = {
18-
"gx_kernel_64x64.mlir": [
19-
"nofastbuild",
20-
"notap",
21-
"manual",
22-
],
23-
},
2412
test_file_exts = ["mlir"],
2513
)

tests/Transforms/heir_simd_vectorizer/box_blur_64x64.mlir renamed to tests/Transforms/heir_simd_vectorizer/box_blur_16x16.mlir

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,60 +3,60 @@
33

44
module {
55
// CHECK: @box_blur
6-
// CHECK-SAME: %[[arg0:.*]]: !secret.secret<tensor<4096xi16>>) -> !secret.secret<tensor<4096xi16>> {
7-
// CHECK-DAG: %[[c127:.*]] = arith.constant 127 : index
8-
// CHECK-DAG: %[[c3968:.*]] = arith.constant 3968 : index
9-
// CHECK-DAG: %[[c4032:.*]] = arith.constant 4032 : index
10-
// CHECK-DAG: %[[c63:.*]] = arith.constant 63 : index
11-
// CHECK-DAG: %[[c65:.*]] = arith.constant 65 : index
12-
// CHECK-NEXT: %[[v0:.*]] = secret.generic(%[[arg0]]: !secret.secret<tensor<4096xi16>>) {
13-
// CHECK-NEXT: ^body(%[[arg1:.*]]: tensor<4096xi16>):
14-
// CHECK-NEXT: %[[v1:.*]] = tensor_ext.rotate %[[arg1]], %[[c3968]]
15-
// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[arg1]], %[[c4032]]
6+
// CHECK-SAME: %[[arg0:.*]]: !secret.secret<tensor<256xi16>>) -> !secret.secret<tensor<256xi16>> {
7+
// CHECK-DAG: %[[c31:.*]] = arith.constant 31 : index
8+
// CHECK-DAG: %[[c240:.*]] = arith.constant 240 : index
9+
// CHECK-DAG: %[[c224:.*]] = arith.constant 224 : index
10+
// CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
11+
// CHECK-DAG: %[[c17:.*]] = arith.constant 17 : index
12+
// CHECK-NEXT: %[[v0:.*]] = secret.generic(%[[arg0]]: !secret.secret<tensor<256xi16>>) {
13+
// CHECK-NEXT: ^body(%[[arg1:.*]]: tensor<256xi16>):
14+
// CHECK-NEXT: %[[v1:.*]] = tensor_ext.rotate %[[arg1]], %[[c224]]
15+
// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[arg1]], %[[c240]]
1616
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]]
1717
// CHECK-NEXT: %[[v4:.*]] = arith.addi %[[v3]], %[[arg1]]
18-
// CHECK-NEXT: %[[v5:.*]] = tensor_ext.rotate %[[v4]], %[[c63]]
18+
// CHECK-NEXT: %[[v5:.*]] = tensor_ext.rotate %[[v4]], %[[c15]]
1919
// CHECK-NEXT: %[[v6:.*]] = arith.addi %[[v5]], %[[v2]]
2020
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v6]], %[[arg1]]
21-
// CHECK-NEXT: %[[v8:.*]] = tensor_ext.rotate %[[v7]], %[[c63]]
22-
// CHECK-NEXT: %[[v9:.*]] = tensor_ext.rotate %[[arg1]], %[[c127]]
21+
// CHECK-NEXT: %[[v8:.*]] = tensor_ext.rotate %[[v7]], %[[c15]]
22+
// CHECK-NEXT: %[[v9:.*]] = tensor_ext.rotate %[[arg1]], %[[c31]]
2323
// CHECK-NEXT: %[[v10:.*]] = arith.addi %[[v8]], %[[v9]]
2424
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v10]], %[[arg1]]
25-
// CHECK-NEXT: %[[v12:.*]] = tensor_ext.rotate %[[v11]], %[[c3968]]
25+
// CHECK-NEXT: %[[v12:.*]] = tensor_ext.rotate %[[v11]], %[[c224]]
2626
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v12]], %[[v2]]
2727
// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[v13]], %[[arg1]]
28-
// CHECK-NEXT: %[[v15:.*]] = tensor_ext.rotate %[[v14]], %[[c65]]
28+
// CHECK-NEXT: %[[v15:.*]] = tensor_ext.rotate %[[v14]], %[[c17]]
2929
// CHECK-NEXT: secret.yield %[[v15]]
30-
// CHECK-NEXT: } -> !secret.secret<tensor<4096xi16>>
30+
// CHECK-NEXT: } -> !secret.secret<tensor<256xi16>>
3131
// CHECK-NEXT: return %[[v0]]
3232

33-
func.func @box_blur(%arg0: tensor<4096xi16>) -> tensor<4096xi16> {
34-
%c4096 = arith.constant 4096 : index
35-
%c64 = arith.constant 64 : index
36-
%0 = affine.for %x = 0 to 64 iter_args(%arg0_x = %arg0) -> (tensor<4096xi16>) {
37-
%1 = affine.for %y = 0 to 64 iter_args(%arg0_y = %arg0_x) -> (tensor<4096xi16>) {
33+
func.func @box_blur(%arg0: tensor<256xi16>) -> tensor<256xi16> {
34+
%c256 = arith.constant 256 : index
35+
%c16 = arith.constant 16 : index
36+
%0 = affine.for %x = 0 to 16 iter_args(%arg0_x = %arg0) -> (tensor<256xi16>) {
37+
%1 = affine.for %y = 0 to 16 iter_args(%arg0_y = %arg0_x) -> (tensor<256xi16>) {
3838
%c0_si16 = arith.constant 0 : i16
3939
%2 = affine.for %j = -1 to 2 iter_args(%value_j = %c0_si16) -> (i16) {
4040
%6 = affine.for %i = -1 to 2 iter_args(%value_i = %value_j) -> (i16) {
4141
%7 = arith.addi %x, %i : index
42-
%8 = arith.muli %7, %c64 : index
42+
%8 = arith.muli %7, %c16 : index
4343
%9 = arith.addi %y, %j : index
4444
%10 = arith.addi %8, %9 : index
45-
%11 = arith.remui %10, %c4096 : index
46-
%12 = tensor.extract %arg0[%11] : tensor<4096xi16>
45+
%11 = arith.remui %10, %c256 : index
46+
%12 = tensor.extract %arg0[%11] : tensor<256xi16>
4747
%13 = arith.addi %value_i, %12 : i16
4848
affine.yield %13 : i16
4949
}
5050
affine.yield %6 : i16
5151
}
52-
%3 = arith.muli %c64, %x : index
52+
%3 = arith.muli %c16, %x : index
5353
%4 = arith.addi %3, %y : index
54-
%5 = arith.remui %4, %c4096 : index
55-
%6 = tensor.insert %2 into %arg0_y[%5] : tensor<4096xi16>
56-
affine.yield %6 : tensor<4096xi16>
54+
%5 = arith.remui %4, %c256 : index
55+
%6 = tensor.insert %2 into %arg0_y[%5] : tensor<256xi16>
56+
affine.yield %6 : tensor<256xi16>
5757
}
58-
affine.yield %1 : tensor<4096xi16>
58+
affine.yield %1 : tensor<256xi16>
5959
}
60-
return %0 : tensor<4096xi16>
60+
return %0 : tensor<256xi16>
6161
}
6262
}

0 commit comments

Comments
 (0)