Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions lib/Conversion/CGGIToOpenfhe/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "CGGIToOpenfhe",
srcs = ["CGGIToOpenfhe.cpp"],
hdrs = ["CGGIToOpenfhe.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Utils:ConversionUtils",
# "@heir//lib/Conversion:Utils",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/Comb/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=CGGIToOpenfhe",
],
"CGGIToOpenfhe.h.inc",
),
(
["-gen-pass-doc"],
"CGGIToOpenfhe.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "CGGIToOpenfhe.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
245 changes: 245 additions & 0 deletions lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#include "lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.h"

#include <iostream>
#include <numeric>

// #include "lib/Conversion/Utils.h"
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "cggi-to-openfhe"

namespace mlir::heir {

#define GEN_PASS_DEF_CGGITOOPENFHE
#include "lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.h.inc"

// Remove this class if no type conversions are necessary
class CGGIToOpenfheTypeConverter : public TypeConverter {
public:
CGGIToOpenfheTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
}
};

// Commented this out bc it throws a linker error since there's another one in
// CGGI -> TFHE Rust bool
static bool containsCGGIOps2(func::FuncOp func) {
auto walkResult = func.walk([&](Operation *op) {
if (llvm::isa<cggi::CGGIDialect>(op->getDialect()))
return WalkResult::interrupt();
return WalkResult::advance();
});
return walkResult.wasInterrupted();
}

// FIXME: I stole these two from the BGVToOpenfhe conversion; is there a better
// way to share code?
struct AddCryptoContextParam : public OpConversionPattern<func::FuncOp> {
AddCryptoContextParam(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs()
<< "Attempting to add context param to: " << op << "\n");

if (!containsCGGIOps2(op)) {
LLVM_DEBUG(llvm::dbgs() << "No crypto ops, skipping...\n");
return failure();
}

auto cryptoContextType = openfhe::BinFHEContextType::get(getContext());
FunctionType originalType = op.getFunctionType();
llvm::SmallVector<Type, 4> newTypes;
newTypes.reserve(originalType.getNumInputs() + 1);
newTypes.push_back(cryptoContextType);
for (auto t : originalType.getInputs()) {
newTypes.push_back(t);
}
auto newFuncType =
FunctionType::get(getContext(), newTypes, originalType.getResults());
rewriter.modifyOpInPlace(op, [&] {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), cryptoContextType,
op.getLoc());
});

return success();
}
};

namespace {
FailureOr<Value> getContextualCryptoContext(Operation *op) {
Value cryptoContext = op->getParentOfType<func::FuncOp>()
.getBody()
.getBlocks()
.front()
.getArguments()
.front();
if (!mlir::isa<openfhe::BinFHEContextType>(cryptoContext.getType())) {
return op->emitOpError()
<< "Found CGGI op in a function without a public "
"key argument. Did the AddCryptoContextArg pattern fail to run?";
}
return cryptoContext;
}
} // namespace

struct AddCryptoContextArg : public OpConversionPattern<func::CallOp> {
AddCryptoContextArg(mlir::MLIRContext *context)
: OpConversionPattern<func::CallOp>(context, /* benefit= */ 2) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
func::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.modifyOpInPlace(op, [&] {
auto result = getContextualCryptoContext(op);
if (failed(result)) return;
auto context = result.value();
op->insertOperands(0, {context});
});

return success();
}
};

struct ConvertLutLincombOp : public OpConversionPattern<cggi::LutLinCombOp> {
ConvertLutLincombOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::LutLinCombOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::LutLinCombOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto result = getContextualCryptoContext(op);
if (failed(result)) return result;
auto cryptoContext = result.value();

auto inputs = op.getInputs();
auto coefficients = op.getCoefficients();

llvm::SmallVector<openfhe::LWEMulConstOp, 4> preppedInputs;
preppedInputs.reserve(coefficients.size());

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto scheme = b.create<openfhe::GetLWESchemeOp>(cryptoContext).getResult();

for (int i = 0; i < coefficients.size(); i++) {
preppedInputs.push_back(b.create<openfhe::LWEMulConstOp>(
scheme, inputs[i],
b.create<arith::ConstantOp>(b.getI64Type(),
b.getI64IntegerAttr(coefficients[i]))
.getResult()));
}

mlir::Value lutInput;

if (preppedInputs.size() > 1) {
openfhe::LWEAddOp sum = b.create<openfhe::LWEAddOp>(
scheme, preppedInputs[0].getResult(), preppedInputs[1].getResult());

for (int i = 2; i < preppedInputs.size(); i++) {
sum = b.create<openfhe::LWEAddOp>(scheme, sum, preppedInputs[i]);
}

lutInput = sum.getResult();
} else {
lutInput = preppedInputs[0].getResult();
}

// now create the LUT
// llvm::SmallSetVector<int, 8> lutBits;
// auto lutAttr = op.getLookupTableAttr();
// int width = coefficients.size();
// for (int i = 0; i < (1 << width); i++) {
// int index = 0;
// for (int j = 0; j < width; j++) {
// index += ((i >> (width - 1 - j)) & 1) * coefficients[j];
// }
// while (index < 0) index += 8;
// if ((lutAttr.getValue().getZExtValue() >> i) & 1) lutBits.insert(index
// % 8);
// }

// Extract message indices from the integer LUT mask, LSb-first.
llvm::SmallVector<int, 8> lutBits;
auto lutAttr = op.getLookupTableAttr();
uint64_t lutValue = lutAttr.getValue().getZExtValue();
for (int i = 0; i < 8; i++) {
if ((lutValue >> i) & 1ULL) {
lutBits.push_back(i);
}
}

auto makeLut = b.create<openfhe::MakeLutOp>(
cryptoContext, b.getDenseI32ArrayAttr(lutBits));
auto evalFunc = b.create<openfhe::EvalFuncOp>(
lutInput.getType(), cryptoContext, makeLut.getResult(), lutInput);
rewriter.replaceOp(op, evalFunc);
return success();
}
};

struct CGGIToOpenfhe : public impl::CGGIToOpenfheBase<CGGIToOpenfhe> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();
CGGIToOpenfheTypeConverter typeConverter(context);

RewritePatternSet patterns(context);
ConversionTarget target(*context);

addStructuralConversionPatterns(typeConverter, patterns, target);

target.addLegalDialect<openfhe::OpenfheDialect, memref::MemRefDialect,
lwe::LWEDialect>();

target.addIllegalOp<cggi::LutLinCombOp>();
target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp func) {
bool hasCryptoContext = func.getFunctionType().getNumInputs() > 0 &&
mlir::isa<openfhe::BinFHEContextType>(
*func.getFunctionType().getInputs().begin());
return hasCryptoContext || func.getName().starts_with("internal_generic");
});

target.addDynamicallyLegalOp<func::CallOp>([](func::CallOp call) {
bool hasCryptoContext = !call.getArgOperands().empty() &&
mlir::isa<openfhe::BinFHEContextType>(
*call.getArgOperands().getType().begin());
return hasCryptoContext;
});

// target.addIllegalDialect<cggi::CGGIDialect>();
patterns
.add<AddCryptoContextParam, AddCryptoContextArg, ConvertLutLincombOp>(
typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace mlir::heir
16 changes: 16 additions & 0 deletions lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_H_
#define LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {

#define GEN_PASS_DECL
#include "lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.h.inc"

} // namespace mlir::heir

#endif // LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_H_
14 changes: 14 additions & 0 deletions lib/Conversion/CGGIToOpenfhe/CGGIToOpenfhe.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_TD_
#define LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_TD_

include "mlir/Pass/PassBase.td"

def CGGIToOpenfhe : Pass<"cggi-to-openfhe"> {
let summary = "Lower `cggi` to `openfhe` dialect.";
let dependentDialects = [
"mlir::heir::cggi::CGGIDialect",
"mlir::heir::openfhe::OpenfheDialect",
];
}

#endif // LIB_CONVERSION_CGGITOOPENFHE_CGGITOOPENFHE_TD_
56 changes: 56 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,60 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [

def BootstrapOp : Openfhe_UnaryTypeSwitchOp<"bootstrap"> { let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; }

// BinFHE operations

class Openfhe_lwe_BinaryOp<string mnemonic, list<Trait> traits = []>
: Openfhe_Op<mnemonic, traits # [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>,
]>{
let arguments = (ins
Openfhe_LWEScheme:$cryptoContext,
LWECiphertext:$lhs,
LWECiphertext:$rhs
);
let results = (outs LWECiphertext:$output);
}

def MakeLutOp : Openfhe_Op<"make_lut", [Pure]> {
let arguments = (ins
Openfhe_BinFHEContext:$cryptoContext,
DenseI32ArrayAttr:$values
);
let results = (outs Openfhe_LUT:$output);
}

def EvalFuncOp : Openfhe_Op<"eval_func", [Pure]> {
let arguments = (ins
Openfhe_BinFHEContext:$cryptoContext,
Openfhe_LUT:$lut,
LWECiphertext:$input
);
let results = (outs LWECiphertext:$output);
}

def GetLWESchemeOp : Openfhe_Op<"get_lwe_scheme", [Pure]> {
let summary = "Gets pointer to underlying LWE scheme.";
let arguments = (ins
Openfhe_BinFHEContext:$cryptoContext
);
let results = (outs Openfhe_LWEScheme:$scheme);
}

def LWEAddOp : Openfhe_lwe_BinaryOp<"lwe_add"> { let summary = "OpenFHE add operation of two LWE ciphertexts."; }
def LWESubOp : Openfhe_lwe_BinaryOp<"lwe_sub"> { let summary = "OpenFHE sub operation of two LWE ciphertexts."; }

def LWEMulConstOp : Openfhe_Op<"lwe_mul_const",[
Pure,
AllTypesMatch<["ciphertext", "output"]>
]> {
let summary = "OpenFHE mul operation of an LWE ciphertext and a constant.";
let arguments = (ins
Openfhe_LWEScheme:$cryptoContext,
LWECiphertext:$ciphertext,
I64:$constant
);
let results = (outs LWECiphertext:$output);
}

#endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_
Loading
Loading