|
| 1 | +#include "lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.h" |
| 2 | + |
| 3 | +#include <cstdint> |
| 4 | + |
| 5 | +#include "lib/Dialect/LWE/IR/LWETypes.h" |
| 6 | +#include "lib/Dialect/Openfhe/IR/OpenfheOps.h" |
| 7 | +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" |
| 8 | +#include "lib/Utils/ConversionUtils.h" |
| 9 | +#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project |
| 10 | +#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project |
| 11 | +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project |
| 12 | +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project |
| 13 | +#include "llvm/include/llvm/Support/DebugLog.h" // from @llvm-project |
| 14 | +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| 15 | +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project |
| 16 | +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project |
| 17 | +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project |
| 18 | +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project |
| 19 | +#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project |
| 20 | + |
| 21 | +#define DEBUG_TYPE "fast-rotation-precompute" |
| 22 | + |
| 23 | +namespace mlir { |
| 24 | +namespace heir { |
| 25 | +namespace openfhe { |
| 26 | + |
| 27 | +#define GEN_PASS_DEF_FASTROTATIONPRECOMPUTE |
| 28 | +#include "lib/Dialect/Openfhe/Transforms/Passes.h.inc" |
| 29 | + |
| 30 | +void processFunc(func::FuncOp funcOp, Value cryptoContext) { |
| 31 | + IRRewriter builder(funcOp->getContext()); |
| 32 | + llvm::DenseMap<Value, llvm::SmallVector<RotOp>> ciphertextToRotateOps; |
| 33 | + llvm::DenseMap<Value, llvm::SmallDenseSet<int64_t>> |
| 34 | + ciphertextToDistinctRotations; |
| 35 | + funcOp->walk([&](RotOp op) { |
| 36 | + ciphertextToRotateOps[op.getCiphertext()].push_back(op); |
| 37 | + ciphertextToDistinctRotations[op.getCiphertext()].insert( |
| 38 | + op.getIndex().getValue().getZExtValue()); |
| 39 | + }); |
| 40 | + |
| 41 | + for (auto const& [ciphertext, rots] : ciphertextToDistinctRotations) { |
| 42 | + // TODO(#744): is there a meaningful tradeoff for fast precompute? |
| 43 | + if (rots.size() < 2) { |
| 44 | + continue; |
| 45 | + } |
| 46 | + LLVM_DEBUG(llvm::dbgs() << "Found ciphertext with " << rots.size() |
| 47 | + << " distinct rotations: " << ciphertext << "\n"); |
| 48 | + |
| 49 | + // Insert the precomputation op right after the ciphertext is defined. If |
| 50 | + // the ciphertext is a block argument, the precomputation op is inserted at |
| 51 | + // the beginning of the block. |
| 52 | + if (auto* definingOp = ciphertext.getDefiningOp()) { |
| 53 | + builder.setInsertionPointAfter(definingOp); |
| 54 | + } else { |
| 55 | + builder.setInsertionPointToStart( |
| 56 | + cast<BlockArgument>(ciphertext).getOwner()); |
| 57 | + } |
| 58 | + |
| 59 | + auto precomputeOp = FastRotationPrecomputeOp::create( |
| 60 | + builder, ciphertext.getLoc(), cryptoContext, ciphertext); |
| 61 | + |
| 62 | + for (RotOp op : ciphertextToRotateOps[ciphertext]) { |
| 63 | + builder.setInsertionPoint(op); |
| 64 | + // cyclotomic order is 2*N where polynomial modulus is x^N + 1 |
| 65 | + int cyclotomicOrder = |
| 66 | + 2 * cast<lwe::LWECiphertextType>(ciphertext.getType()) |
| 67 | + .getCiphertextSpace() |
| 68 | + .getRing() |
| 69 | + .getPolynomialModulus() |
| 70 | + .getPolynomial() |
| 71 | + .getDegree(); |
| 72 | + auto fastRot = FastRotationOp::create( |
| 73 | + builder, op->getLoc(), op.getType(), op.getCryptoContext(), |
| 74 | + op.getCiphertext(), op.getIndex(), |
| 75 | + builder.getIndexAttr(cyclotomicOrder), precomputeOp.getResult()); |
| 76 | + builder.replaceOp(op, fastRot); |
| 77 | + } |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +struct FastRotationPrecompute |
| 82 | + : impl::FastRotationPrecomputeBase<FastRotationPrecompute> { |
| 83 | + using FastRotationPrecomputeBase::FastRotationPrecomputeBase; |
| 84 | + |
| 85 | + void runOnOperation() override { |
| 86 | + // We must process funcs separately so that rotations are not attempted to |
| 87 | + // be batched across function boundaries. |
| 88 | + getOperation()->walk([&](func::FuncOp op) -> WalkResult { |
| 89 | + auto result = getContextualArgFromFunc<openfhe::CryptoContextType>(op); |
| 90 | + if (failed(result)) { |
| 91 | + LDBG() << "Skipping func with no cryptocontex arg: " << op.getSymName() |
| 92 | + << "\n"; |
| 93 | + return WalkResult::advance(); |
| 94 | + } |
| 95 | + Value cryptoContext = result.value(); |
| 96 | + processFunc(op, cryptoContext); |
| 97 | + return WalkResult::advance(); |
| 98 | + }); |
| 99 | + } |
| 100 | +}; |
| 101 | +} // namespace openfhe |
| 102 | +} // namespace heir |
| 103 | +} // namespace mlir |
0 commit comments