Skip to content

Commit 1281797

Browse files
j2kuncopybara-github
authored andcommitted
Support EvalFastRotate in OpenFHE
- Adds new type for the precomputed digit decomposition - Adds new ops for the fast_rotation_precompute and fast_rotation - Adds a new pass openfhe-fast-rotation-precompute - Emits the new ops in the pke emitter Does not yet add this new pass to the pipeline by default (it was causing some failures and I have to investigate) Only supports the case for non-hybrid key switching. Lattigo's API is quite different from OpenFHE's so I don't think there will be a simple way to share code here. Maybe the step where we identify which rotations to hoist can be shared, but it seems worth waiting to decide how to do that until we have some meaningful analysis there. Fixes #1924 PiperOrigin-RevId: 801012821
1 parent c109c9e commit 1281797

File tree

16 files changed

+306
-30
lines changed

16 files changed

+306
-30
lines changed

lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ FailureOr<Value> getContextualCryptoContext(Operation* op) {
5353
auto result = getContextualArgFromFunc<openfhe::CryptoContextType>(op);
5454
if (failed(result)) {
5555
return op->emitOpError()
56-
<< "Found LWE op in a function without a public key argument."
56+
<< "Found LWE op in a function without a crypto context argument."
5757
" Did the AddCryptoContextArg pattern fail to run?";
5858
}
5959
return result.value();

lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
#ifndef LIB_DIALECT_LWE_CONVERSIONS_LWETOOPENFHE_LWETOOPENFHE_H_
22
#define LIB_DIALECT_LWE_CONVERSIONS_LWETOOPENFHE_LWETOOPENFHE_H_
33

4-
#include "lib/Dialect/LWE/IR/LWEOps.h"
4+
// IWYU pragma: begin_keep
5+
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
6+
// IWYU pragma: end_keep
7+
58
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
6-
#include "lib/Utils/ConversionUtils.h"
7-
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
8-
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
9-
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
10-
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
11-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
12-
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
10+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
11+
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
1312
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
1413

1514
namespace mlir::heir::lwe {

lib/Dialect/Openfhe/IR/OpenfheOps.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def LevelReduceOp : Openfhe_UnaryTypeSwitchOp<"level_reduce"> {
283283
);
284284
}
285285

286-
def RotOp : Openfhe_Op<"rot",[
286+
def RotOp : Openfhe_Op<"rot", [
287287
Pure,
288288
AllTypesMatch<["ciphertext", "output"]>
289289
]> {
@@ -321,4 +321,26 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [
321321

322322
def BootstrapOp : Openfhe_UnaryTypeSwitchOp<"bootstrap"> { let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; }
323323

324+
def FastRotationPrecomputeOp : Openfhe_Op<"fast_rotation_precompute", [Pure]> {
325+
let arguments = (ins
326+
Openfhe_CryptoContext:$cryptoContext,
327+
LWECiphertext:$input
328+
);
329+
let results = (outs Openfhe_DigitDecomposition:$output);
330+
}
331+
332+
// TODO(#1924): support the "Ext" variant for hybrid key switching.
333+
def FastRotationOp : Openfhe_Op<"fast_rotation", [Pure,
334+
AllTypesMatch<["input", "output"]>
335+
]> {
336+
let arguments = (ins
337+
Openfhe_CryptoContext:$cryptoContext,
338+
LWECiphertext:$input,
339+
IndexAttr:$index,
340+
IndexAttr:$cyclotomicOrder,
341+
Openfhe_DigitDecomposition:$precomputedDigitDecomp
342+
);
343+
let results = (outs LWECiphertext:$output);
344+
}
345+
324346
#endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_

lib/Dialect/Openfhe/IR/OpenfheTypes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,11 @@ def Openfhe_CryptoContext : Openfhe_Type<"CryptoContext", "crypto_context"> {
5555
let asmName = "cc";
5656
}
5757

58+
// The otherwise unnamed type std::shared_ptr<std::vector<Element>> which is
59+
// the return type of EvalFastRotationPrecompute.
60+
def Openfhe_DigitDecomposition : Openfhe_Type<"DigitDecomposition", "digit_decomp"> {
61+
let summary = "A precomputed digit decomposition for for EvalFastRotation";
62+
let asmName = "digit_decomp";
63+
}
64+
5865
#endif // LIB_DIALECT_OPENFHE_IR_OPENFHETYPES_TD_

lib/Dialect/Openfhe/Transforms/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cc_library(
1414
deps = [
1515
":ConfigureCryptoContext",
1616
":CountAddAndKeySwitch",
17+
":FastRotationPrecompute",
1718
":pass_inc_gen",
1819
"@heir//lib/Dialect/Openfhe/IR:Dialect",
1920
],
@@ -61,6 +62,25 @@ cc_library(
6162
],
6263
)
6364

65+
cc_library(
66+
name = "FastRotationPrecompute",
67+
srcs = ["FastRotationPrecompute.cpp"],
68+
hdrs = [
69+
"FastRotationPrecompute.h",
70+
],
71+
deps = [
72+
":pass_inc_gen",
73+
"@heir//lib/Dialect/LWE/IR:Dialect",
74+
"@heir//lib/Dialect/Openfhe/IR:Dialect",
75+
"@heir//lib/Utils:ConversionUtils",
76+
"@llvm-project//llvm:Support",
77+
"@llvm-project//mlir:FuncDialect",
78+
"@llvm-project//mlir:IR",
79+
"@llvm-project//mlir:Pass",
80+
"@llvm-project//mlir:Support",
81+
],
82+
)
83+
6484
add_heir_transforms(
6585
header_filename = "Passes.h.inc",
6686
pass_name = "Openfhe",
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.h"
2+
3+
#include <cstdint>
4+
5+
#include "lib/Dialect/LWE/IR/LWETypes.h"
6+
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
7+
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
8+
#include "lib/Utils/ConversionUtils.h"
9+
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
10+
#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project
11+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
12+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
13+
#include "llvm/include/llvm/Support/DebugLog.h" // from @llvm-project
14+
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
16+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
18+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
19+
#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project
20+
21+
#define DEBUG_TYPE "fast-rotation-precompute"
22+
23+
namespace mlir {
24+
namespace heir {
25+
namespace openfhe {
26+
27+
#define GEN_PASS_DEF_FASTROTATIONPRECOMPUTE
28+
#include "lib/Dialect/Openfhe/Transforms/Passes.h.inc"
29+
30+
void processFunc(func::FuncOp funcOp, Value cryptoContext) {
31+
IRRewriter builder(funcOp->getContext());
32+
llvm::DenseMap<Value, llvm::SmallVector<RotOp>> ciphertextToRotateOps;
33+
llvm::DenseMap<Value, llvm::SmallDenseSet<int64_t>>
34+
ciphertextToDistinctRotations;
35+
funcOp->walk([&](RotOp op) {
36+
ciphertextToRotateOps[op.getCiphertext()].push_back(op);
37+
ciphertextToDistinctRotations[op.getCiphertext()].insert(
38+
op.getIndex().getValue().getZExtValue());
39+
});
40+
41+
for (auto const& [ciphertext, rots] : ciphertextToDistinctRotations) {
42+
// TODO(#744): is there a meaningful tradeoff for fast precompute?
43+
if (rots.size() < 2) {
44+
continue;
45+
}
46+
LLVM_DEBUG(llvm::dbgs() << "Found ciphertext with " << rots.size()
47+
<< " distinct rotations: " << ciphertext << "\n");
48+
49+
// Insert the precomputation op right after the ciphertext is defined. If
50+
// the ciphertext is a block argument, the precomputation op is inserted at
51+
// the beginning of the block.
52+
if (auto* definingOp = ciphertext.getDefiningOp()) {
53+
builder.setInsertionPointAfter(definingOp);
54+
} else {
55+
builder.setInsertionPointToStart(
56+
cast<BlockArgument>(ciphertext).getOwner());
57+
}
58+
59+
auto precomputeOp = FastRotationPrecomputeOp::create(
60+
builder, ciphertext.getLoc(), cryptoContext, ciphertext);
61+
62+
for (RotOp op : ciphertextToRotateOps[ciphertext]) {
63+
builder.setInsertionPoint(op);
64+
// cyclotomic order is 2*N where polynomial modulus is x^N + 1
65+
int cyclotomicOrder =
66+
2 * cast<lwe::LWECiphertextType>(ciphertext.getType())
67+
.getCiphertextSpace()
68+
.getRing()
69+
.getPolynomialModulus()
70+
.getPolynomial()
71+
.getDegree();
72+
auto fastRot = FastRotationOp::create(
73+
builder, op->getLoc(), op.getType(), op.getCryptoContext(),
74+
op.getCiphertext(), op.getIndex(),
75+
builder.getIndexAttr(cyclotomicOrder), precomputeOp.getResult());
76+
builder.replaceOp(op, fastRot);
77+
}
78+
}
79+
}
80+
81+
struct FastRotationPrecompute
82+
: impl::FastRotationPrecomputeBase<FastRotationPrecompute> {
83+
using FastRotationPrecomputeBase::FastRotationPrecomputeBase;
84+
85+
void runOnOperation() override {
86+
// We must process funcs separately so that rotations are not attempted to
87+
// be batched across function boundaries.
88+
getOperation()->walk([&](func::FuncOp op) -> WalkResult {
89+
auto result = getContextualArgFromFunc<openfhe::CryptoContextType>(op);
90+
if (failed(result)) {
91+
LDBG() << "Skipping func with no cryptocontex arg: " << op.getSymName()
92+
<< "\n";
93+
return WalkResult::advance();
94+
}
95+
Value cryptoContext = result.value();
96+
processFunc(op, cryptoContext);
97+
return WalkResult::advance();
98+
});
99+
}
100+
};
101+
} // namespace openfhe
102+
} // namespace heir
103+
} // namespace mlir
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef LIB_DIALECT_OPENFHE_TRANSFORMS_FASTROTATIONPRECOMPUTE_H_
2+
#define LIB_DIALECT_OPENFHE_TRANSFORMS_FASTROTATIONPRECOMPUTE_H_
3+
4+
// IWYU pragma: begin_keep
5+
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
6+
// IWYU pragma: end_keep
7+
8+
namespace mlir {
9+
namespace heir {
10+
namespace openfhe {
11+
12+
#define GEN_PASS_DECL_FASTROTATIONPRECOMPUTE
13+
#include "lib/Dialect/Openfhe/Transforms/Passes.h.inc"
14+
15+
} // namespace openfhe
16+
} // namespace heir
17+
} // namespace mlir
18+
19+
#endif // LIB_DIALECT_OPENFHE_TRANSFORMS_FASTROTATIONPRECOMPUTE_H_

lib/Dialect/Openfhe/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#ifndef LIB_DIALECT_OPENFHE_TRANSFORMS_PASSES_H_
22
#define LIB_DIALECT_OPENFHE_TRANSFORMS_PASSES_H_
33

4+
// IWYU pragma: begin_keep
45
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h"
56
#include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h"
67
#include "lib/Dialect/Openfhe/Transforms/CountAddAndKeySwitch.h"
8+
#include "lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.h"
9+
// IWYU pragma: end_keep
710

811
namespace mlir {
912
namespace heir {

lib/Dialect/Openfhe/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,13 @@ def CountAddAndKeySwitch : Pass<"openfhe-count-add-and-key-switch"> {
7272
}];
7373
}
7474

75+
def FastRotationPrecompute : Pass<"openfhe-fast-rotation-precompute"> {
76+
let summary = "Identify and apply EvalFastRotation when possible.";
77+
let description = [{
78+
This pass identifies when a ciphertext is rotated by multiple different
79+
shifts, and replaces the `EvalRot` ops with `EvalFastRotationPrecompute`
80+
followed by `EvalFastRotate`.
81+
}];
82+
}
83+
7584
#endif // LIB_DIALECT_OPENFHE_TRANSFORMS_PASSES_TD_

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ BackendPipelineBuilder toOpenFhePipelineBuilder() {
405405
configureCryptoContextOptions.entryFunction = options.entryFunction;
406406
pm.addPass(
407407
openfhe::createConfigureCryptoContext(configureCryptoContextOptions));
408+
409+
// Hoist repeated rotations into EvalFastRotation(Precompute)
410+
// TODO(#1924): enable openfhe-fast-rotation-precompute in the pipeline
411+
// pm.addPass(openfhe::createFastRotationPrecompute());
408412
};
409413
}
410414

0 commit comments

Comments
 (0)