Skip to content

Commit 72cd120

Browse files
asraacopybara-github
authored andcommitted
cleanup: remove old LWE types
Part of #1199 PiperOrigin-RevId: 786339637
1 parent 292aa35 commit 72cd120

File tree

27 files changed

+629
-996
lines changed

27 files changed

+629
-996
lines changed

lib/Dialect/CGGI/IR/CGGIAttributes.td

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,6 @@ class CGGI_Attr<string name, string attrMnemonic, list<Trait> traits = []>
1313
let assemblyFormat = "`<` struct(params) `>`";
1414
}
1515

16-
def CGGI_CGGIParams : CGGI_Attr<"CGGIParams", "cggi_params"> {
17-
// TODO(#276): migrate the gadget params
18-
// to lwe dialect?
19-
let parameters = (ins
20-
"::mlir::heir::lwe::RLWEParamsAttr": $rlweParams,
21-
"unsigned": $bsk_noise_variance,
22-
"unsigned": $bsk_gadget_base_log,
23-
"unsigned": $bsk_gadget_num_levels,
24-
"unsigned": $ksk_noise_variance,
25-
"unsigned": $ksk_gadget_base_log,
26-
"unsigned": $ksk_gadget_num_levels
27-
);
28-
}
29-
3016
def CGGI_CGGIBoolGates : CGGI_Attr<"CGGIBoolGates", "cggi_bool_gates"> {
3117
let summary = "An attribute containing an array of strings to store bool gates";
3218

lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,6 @@ namespace mlir::heir::lwe {
4141

4242
ToOpenfheTypeConverter::ToOpenfheTypeConverter(MLIRContext *ctx) {
4343
addConversion([](Type type) { return type; });
44-
addConversion([ctx](lwe::RLWEPublicKeyType type) -> Type {
45-
return openfhe::PublicKeyType::get(ctx);
46-
});
47-
addConversion([ctx](lwe::RLWESecretKeyType type) -> Type {
48-
return openfhe::PrivateKeyType::get(ctx);
49-
});
5044
addConversion([ctx](lwe::NewLWEPublicKeyType type) -> Type {
5145
return openfhe::PublicKeyType::get(ctx);
5246
});

lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.cpp

Lines changed: 53 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,16 @@ class CiphertextTypeConverter : public TypeConverter {
6161
auto ring = type.getRing();
6262
auto polyTy = polynomial::PolynomialType::get(ctx, ring);
6363

64-
return RankedTensorType::get({2}, polyTy);
64+
// TODO(#2045): Use size information (number of polynomials) from the LWE
65+
// key type instead of hardcoding to 1.
66+
return RankedTensorType::get({1}, polyTy);
6567
});
6668
addConversion([ctx](lwe::NewLWEPublicKeyType type) -> Type {
6769
auto ring = type.getRing();
6870
auto polyTy = polynomial::PolynomialType::get(ctx, ring);
6971

72+
// TODO(#2045): Use size information (number of polynomials) from the LWE
73+
// key type instead of hardcoding to 2.
7074
return RankedTensorType::get({2}, polyTy);
7175
});
7276
}
@@ -118,7 +122,13 @@ struct ConvertRLWEDecrypt : public OpConversionPattern<RLWEDecryptOp> {
118122
builder.create<polynomial::MulOp>(extractSecretKeyOp, extractOp0);
119123
auto plaintext = builder.create<polynomial::AddOp>(index1sk, extractOp1);
120124

121-
rewriter.replaceOp(op, plaintext);
125+
// Cast to the plaintext space types.
126+
auto plaintextPolyType = cast<polynomial::PolynomialType>(
127+
typeConverter->convertType(op.getOutput().getType()));
128+
auto plaintextMod =
129+
builder.create<polynomial::ModSwitchOp>(plaintextPolyType, plaintext);
130+
131+
rewriter.replaceOp(op, plaintextMod);
122132
return success();
123133
}
124134
};
@@ -136,57 +146,7 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
136146
auto input = adaptor.getInput();
137147
auto key = adaptor.getKey();
138148

139-
// TODO (#785): Migrate to new LWE types and plaintext modulus.
140-
auto inputT = cast<lwe::RLWEPlaintextType>(op.getInput().getType());
141-
auto inputEncoding = inputT.getEncoding();
142-
auto cleartextBitwidthOrFailure =
143-
llvm::TypeSwitch<Attribute, FailureOr<int>>(inputEncoding)
144-
.Case<lwe::BitFieldEncodingAttr,
145-
lwe::UnspecifiedBitFieldEncodingAttr>(
146-
[](auto attr) -> FailureOr<int> {
147-
return attr.getCleartextBitwidth();
148-
})
149-
.Default([](Attribute attr) -> FailureOr<int> {
150-
llvm_unreachable(
151-
"Unsupported encoding attribute for cleartext bitwidth");
152-
return failure();
153-
});
154-
auto cleartextStartOrFailure =
155-
llvm::TypeSwitch<Attribute, FailureOr<int>>(inputEncoding)
156-
.Case<lwe::BitFieldEncodingAttr>([](auto attr) -> FailureOr<int> {
157-
return attr.getCleartextStart();
158-
})
159-
.Case<lwe::UnspecifiedBitFieldEncodingAttr>(
160-
[](auto attr) -> FailureOr<int> {
161-
llvm_unreachable(
162-
"Upsecified Bit Field Encoding Attribute for cleartext "
163-
"start");
164-
return failure();
165-
})
166-
.Default([](Attribute attr) -> FailureOr<int> {
167-
llvm_unreachable(
168-
"Unsupported encoding attribute for cleartext start");
169-
return failure();
170-
});
171-
172-
// Should return failure if cleartextBitwidth or cleartextStart fail.
173-
if (failed(cleartextBitwidthOrFailure) || failed(cleartextStartOrFailure)) {
174-
return failure();
175-
}
176-
auto cleartextBitwidth = cleartextBitwidthOrFailure.value();
177-
auto cleartextStart = cleartextStartOrFailure.value();
178-
179-
// Check that cleartext_start = cleartext_bitwidth for BGV encryption.
180-
if (cleartextBitwidth != cleartextStart) {
181-
// TODO (#882): Add support for other encryption schemes besides BGV. Left
182-
// as future work.
183-
op.emitError() << "`lwe.rlwe_encrypt` expects BGV encryption"
184-
<< " with cleartext_start = cleartext_bitwidth, but"
185-
<< " found cleartext_start = " << cleartextStart
186-
<< " and cleartext_bitwidth = " << cleartextBitwidth
187-
<< ".";
188-
return failure();
189-
}
149+
auto outputT = cast<lwe::NewLWECiphertextType>(op.getOutput().getType());
190150

191151
auto isPublicKey =
192152
llvm::TypeSwitch<Type, bool>(op.getKey().getType())
@@ -200,22 +160,16 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
200160
return false;
201161
});
202162

163+
// TODO (#882): Add support for other encryption schemes besides BGV. Left
164+
// as future work.
203165
ImplicitLocOpBuilder builder(loc, rewriter);
204166

205167
auto index0 = builder.create<arith::ConstantIndexOp>(0);
206-
auto dimension =
207-
inputT.getRing().getPolynomialModulus().getPolynomial().getDegree();
208168

209-
auto coefficientType = inputT.getRing().getCoefficientType();
210-
auto modArithType = dyn_cast<mod_arith::ModArithType>(coefficientType);
211-
if (!modArithType) {
212-
op.emitError() << "Unsupported coefficient type: " << coefficientType;
213-
return failure();
214-
}
215-
216-
Type tensorEltTy = modArithType.getModulus().getType();
217-
auto tensorParams = RankedTensorType::get({dimension}, tensorEltTy);
218-
auto modArithTensorType = RankedTensorType::get({dimension}, modArithType);
169+
auto outputTy = cast<RankedTensorType>(
170+
typeConverter->convertType(op.getOutput().getType()));
171+
auto outputPolyTy =
172+
cast<polynomial::PolynomialType>(outputTy.getElementType());
219173

220174
// TODO (#881): Add pass options to change the seed (which is currently
221175
// hardcoded to 0 with index).
@@ -237,13 +191,8 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
237191
builder.getI32IntegerAttr(-1), builder.getI32IntegerAttr(2));
238192

239193
// Generate random u polynomial from uniform random ternary distribution
240-
auto uTensor =
241-
builder.create<random::SampleOp>(tensorParams, uniformDistribution);
242-
// Convert the tensor of ints to a tensor of mod_arith, then a polynomial
243-
auto modArithUTensor =
244-
builder.create<mod_arith::EncapsulateOp>(modArithTensorType, uTensor);
245-
auto u = builder.create<polynomial::FromTensorOp>(modArithUTensor,
246-
inputT.getRing());
194+
auto u =
195+
builder.create<random::SampleOp>(outputPolyTy, uniformDistribution);
247196

248197
// Create a discrete Gaussian distribution
249198
auto discreteGaussianDistributionType = random::DistributionType::get(
@@ -263,38 +212,42 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
263212
tensor::ExtractOp publicKey1 =
264213
builder.create<tensor::ExtractOp>(key, ValueRange{index1});
265214

266-
// constantT is 2**(cleartextBitwidth), and is used for scalar
267-
// multiplication.
268-
// TODO(#876): Migrate to using the plaintext modulus of the encoding info
269-
// attributes.
270-
auto constantT = builder.create<mod_arith::ConstantOp>(
271-
modArithType, IntegerAttr::get(modArithType.getModulus().getType(),
272-
1 << cleartextBitwidth));
215+
// get plaintext modulus T in the output ring space. We assume the
216+
// plaintext type is a mod arith type.
217+
auto plaintextCoeffType =
218+
outputT.getPlaintextSpace().getRing().getCoefficientType();
219+
auto plaintextModArithType =
220+
dyn_cast<mod_arith::ModArithType>(plaintextCoeffType);
221+
if (!plaintextModArithType) {
222+
op.emitError() << "Unsupported plaintext coefficient type: "
223+
<< plaintextCoeffType;
224+
return failure();
225+
}
273226

274-
// generate random e0 polynomial from discrete gaussian distribution
275-
auto e0Tensor = builder.create<random::SampleOp>(
276-
tensorParams, discreteGaussianDistribution);
277-
auto modArithE0Tensor = builder.create<mod_arith::EncapsulateOp>(
278-
modArithTensorType, e0Tensor);
279-
auto e0 = builder.create<polynomial::FromTensorOp>(modArithE0Tensor,
280-
inputT.getRing());
227+
// create scalar constant T in the output coefficient space
228+
auto plaintextT = builder.create<mod_arith::ConstantOp>(
229+
plaintextModArithType, plaintextModArithType.getModulus());
230+
auto constantT = builder.create<mod_arith::ModSwitchOp>(
231+
outputPolyTy.getRing().getCoefficientType(), plaintextT);
281232

233+
// generate random e0 polynomial from discrete gaussian distribution
234+
auto e0 = builder.create<random::SampleOp>(outputPolyTy,
235+
discreteGaussianDistribution);
282236
// generate random e1 polynomial from discrete gaussian distribution
283-
auto e1Tensor = builder.create<random::SampleOp>(
284-
tensorParams, discreteGaussianDistribution);
285-
auto modArithE1Tensor = builder.create<mod_arith::EncapsulateOp>(
286-
modArithTensorType, e1Tensor);
287-
auto e1 = builder.create<polynomial::FromTensorOp>(modArithE1Tensor,
288-
inputT.getRing());
237+
auto e1 = builder.create<random::SampleOp>(outputPolyTy,
238+
discreteGaussianDistribution);
289239

290240
// TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
291241
// noise or key differently. Add support for those cases.
292242
// Computing ciphertext0 = publicKey0 * u + e0 *
293-
// constantT + input
243+
// constantT + cast(input)
294244
auto publicKey0U = builder.create<polynomial::MulOp>(publicKey0, u);
295245
auto tE0 = builder.create<polynomial::MulScalarOp>(e0, constantT);
296246
auto pK0UtE0 = builder.create<polynomial::AddOp>(publicKey0U, tE0);
297-
auto ciphertext0 = builder.create<polynomial::AddOp>(pK0UtE0, input);
247+
// cast from plaintext space to ciphertext space
248+
auto castInput =
249+
builder.create<polynomial::ModSwitchOp>(outputPolyTy, input);
250+
auto ciphertext0 = builder.create<polynomial::AddOp>(pK0UtE0, castInput);
298251

299252
// Computing ciphertext1 = publicKey1 * u + e1 * constantT
300253
auto publicKey1U = builder.create<polynomial::MulOp>(publicKey1, u);
@@ -317,20 +270,19 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
317270
}
318271

319272
// Generate random e polynomial from discrete gaussian distribution
320-
auto eTensor = builder.create<random::SampleOp>(
321-
tensorParams, discreteGaussianDistribution);
322-
auto modArithETensor =
323-
builder.create<mod_arith::EncapsulateOp>(modArithTensorType, eTensor);
324-
auto e = builder.create<polynomial::FromTensorOp>(modArithETensor,
325-
inputT.getRing());
273+
auto e = builder.create<random::SampleOp>(outputPolyTy,
274+
discreteGaussianDistribution);
326275

327276
// TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
328277
// noise or key differently. Add support for those cases.
329278
// ciphertext0 = u
330279
// Compute ciphertext1 = <u,s> + m + e
331280
auto keyPoly = builder.create<tensor::ExtractOp>(key, ValueRange{index0});
332281
auto us = builder.create<polynomial::MulOp>(u, keyPoly);
333-
auto usM = builder.create<polynomial::AddOp>(us, input);
282+
// cast from plaintext space to ciphertext space
283+
auto castInput =
284+
builder.create<polynomial::ModSwitchOp>(outputPolyTy, input);
285+
auto usM = builder.create<polynomial::AddOp>(us, castInput);
334286
auto ciphertext1 = builder.create<polynomial::AddOp>(usM, e);
335287

336288
// ciphertext = (u, ciphertext0)
@@ -530,11 +482,6 @@ struct ConvertRMulPlain : public OpConversionPattern<RMulPlainOp> {
530482

531483
struct LWEToPolynomial : public impl::LWEToPolynomialBase<LWEToPolynomial> {
532484
void runOnOperation() override {
533-
// TODO(#1199): Remove this emitError once the pass is fixed.
534-
getOperation()->emitError(
535-
"LWEToPolynomial conversion pass is broken. See #1199.");
536-
return;
537-
538485
MLIRContext *context = &getContext();
539486
auto *module = getOperation();
540487
CiphertextTypeConverter typeConverter(context);

lib/Dialect/LWE/IR/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ td_library(
6464
"LWEOps.td",
6565
"LWETraits.td",
6666
"LWETypes.td",
67-
"NewLWEAttributes.td",
68-
"NewLWETypes.td",
6967
],
7068
# include from the heir-root to enable fully-qualified include-paths
7169
includes = ["../../../.."],

lib/Dialect/LWE/IR/LWEAttributes.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -122,55 +122,6 @@ PlaintextSpaceAttr inferModulusSwitchOrRescaleOpPlaintextSpaceAttr(
122122
// Attribute Verification
123123
//===----------------------------------------------------------------------===//
124124

125-
LogicalResult BitFieldEncodingAttr::verifyEncoding(
126-
ArrayRef<int64_t> shape, Type elementType,
127-
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
128-
if (!elementType.isSignlessInteger()) {
129-
return emitError() << "Tensors with a bit_field_encoding must have "
130-
<< "signless integer element type, but found "
131-
<< elementType;
132-
}
133-
134-
unsigned plaintextBitwidth = elementType.getIntOrFloatBitWidth();
135-
unsigned cleartextBitwidth = getCleartextBitwidth();
136-
if (plaintextBitwidth < cleartextBitwidth)
137-
return emitError() << "The tensor element type's bitwidth "
138-
<< plaintextBitwidth
139-
<< " is too small to store the cleartext, "
140-
<< "which has bit width " << cleartextBitwidth << "";
141-
142-
auto cleartextStart = getCleartextStart();
143-
if (cleartextStart < 0 || cleartextStart >= plaintextBitwidth)
144-
return emitError() << "Attribute's cleartext starting bit index ("
145-
<< cleartextStart << ") is outside the legal range [0, "
146-
<< plaintextBitwidth - 1 << "]";
147-
148-
// It may be worth adding some sort of warning notification if the attribute
149-
// allocates no bits for noise, since this would be effectively useless for
150-
// FHE.
151-
return success();
152-
}
153-
154-
LogicalResult UnspecifiedBitFieldEncodingAttr::verifyEncoding(
155-
ArrayRef<int64_t> shape, Type elementType,
156-
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
157-
if (!elementType.isSignlessInteger()) {
158-
return emitError() << "Tensors with a bit_field_encoding must have "
159-
<< "signless integer element type, but found "
160-
<< elementType;
161-
}
162-
163-
unsigned plaintextBitwidth = elementType.getIntOrFloatBitWidth();
164-
unsigned cleartextBitwidth = getCleartextBitwidth();
165-
if (plaintextBitwidth < cleartextBitwidth)
166-
return emitError() << "The tensor element type's bitwidth "
167-
<< plaintextBitwidth
168-
<< " is too small to store the cleartext, "
169-
<< "which has bit width " << cleartextBitwidth << "";
170-
171-
return success();
172-
}
173-
174125
LogicalResult ApplicationDataAttr::verify(
175126
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
176127
mlir::Type messageType, Attribute overflow) {

0 commit comments

Comments
 (0)