@@ -61,12 +61,16 @@ class CiphertextTypeConverter : public TypeConverter {
6161 auto ring = type.getRing ();
6262 auto polyTy = polynomial::PolynomialType::get (ctx, ring);
6363
64- return RankedTensorType::get ({2 }, polyTy);
64+ // TODO(#2045): Use size information (number of polynomials) from the LWE
65+ // key type instead of hardcoding to 1.
66+ return RankedTensorType::get ({1 }, polyTy);
6567 });
6668 addConversion ([ctx](lwe::NewLWEPublicKeyType type) -> Type {
6769 auto ring = type.getRing ();
6870 auto polyTy = polynomial::PolynomialType::get (ctx, ring);
6971
72+ // TODO(#2045): Use size information (number of polynomials) from the LWE
73+ // key type instead of hardcoding to 2.
7074 return RankedTensorType::get ({2 }, polyTy);
7175 });
7276 }
@@ -118,7 +122,13 @@ struct ConvertRLWEDecrypt : public OpConversionPattern<RLWEDecryptOp> {
118122 builder.create <polynomial::MulOp>(extractSecretKeyOp, extractOp0);
119123 auto plaintext = builder.create <polynomial::AddOp>(index1sk, extractOp1);
120124
121- rewriter.replaceOp (op, plaintext);
125+ // Cast to the plaintext space types.
126+ auto plaintextPolyType = cast<polynomial::PolynomialType>(
127+ typeConverter->convertType (op.getOutput ().getType ()));
128+ auto plaintextMod =
129+ builder.create <polynomial::ModSwitchOp>(plaintextPolyType, plaintext);
130+
131+ rewriter.replaceOp (op, plaintextMod);
122132 return success ();
123133 }
124134};
@@ -136,57 +146,7 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
136146 auto input = adaptor.getInput ();
137147 auto key = adaptor.getKey ();
138148
139- // TODO (#785): Migrate to new LWE types and plaintext modulus.
140- auto inputT = cast<lwe::RLWEPlaintextType>(op.getInput ().getType ());
141- auto inputEncoding = inputT.getEncoding ();
142- auto cleartextBitwidthOrFailure =
143- llvm::TypeSwitch<Attribute, FailureOr<int >>(inputEncoding)
144- .Case <lwe::BitFieldEncodingAttr,
145- lwe::UnspecifiedBitFieldEncodingAttr>(
146- [](auto attr) -> FailureOr<int > {
147- return attr.getCleartextBitwidth ();
148- })
149- .Default ([](Attribute attr) -> FailureOr<int > {
150- llvm_unreachable (
151- " Unsupported encoding attribute for cleartext bitwidth" );
152- return failure ();
153- });
154- auto cleartextStartOrFailure =
155- llvm::TypeSwitch<Attribute, FailureOr<int >>(inputEncoding)
156- .Case <lwe::BitFieldEncodingAttr>([](auto attr) -> FailureOr<int > {
157- return attr.getCleartextStart ();
158- })
159- .Case <lwe::UnspecifiedBitFieldEncodingAttr>(
160- [](auto attr) -> FailureOr<int > {
161- llvm_unreachable (
162- " Upsecified Bit Field Encoding Attribute for cleartext "
163- " start" );
164- return failure ();
165- })
166- .Default ([](Attribute attr) -> FailureOr<int > {
167- llvm_unreachable (
168- " Unsupported encoding attribute for cleartext start" );
169- return failure ();
170- });
171-
172- // Should return failure if cleartextBitwidth or cleartextStart fail.
173- if (failed (cleartextBitwidthOrFailure) || failed (cleartextStartOrFailure)) {
174- return failure ();
175- }
176- auto cleartextBitwidth = cleartextBitwidthOrFailure.value ();
177- auto cleartextStart = cleartextStartOrFailure.value ();
178-
179- // Check that cleartext_start = cleartext_bitwidth for BGV encryption.
180- if (cleartextBitwidth != cleartextStart) {
181- // TODO (#882): Add support for other encryption schemes besides BGV. Left
182- // as future work.
183- op.emitError () << " `lwe.rlwe_encrypt` expects BGV encryption"
184- << " with cleartext_start = cleartext_bitwidth, but"
185- << " found cleartext_start = " << cleartextStart
186- << " and cleartext_bitwidth = " << cleartextBitwidth
187- << " ." ;
188- return failure ();
189- }
149+ auto outputT = cast<lwe::NewLWECiphertextType>(op.getOutput ().getType ());
190150
191151 auto isPublicKey =
192152 llvm::TypeSwitch<Type, bool >(op.getKey ().getType ())
@@ -200,22 +160,16 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
200160 return false ;
201161 });
202162
163+ // TODO (#882): Add support for other encryption schemes besides BGV. Left
164+ // as future work.
203165 ImplicitLocOpBuilder builder (loc, rewriter);
204166
205167 auto index0 = builder.create <arith::ConstantIndexOp>(0 );
206- auto dimension =
207- inputT.getRing ().getPolynomialModulus ().getPolynomial ().getDegree ();
208168
209- auto coefficientType = inputT.getRing ().getCoefficientType ();
210- auto modArithType = dyn_cast<mod_arith::ModArithType>(coefficientType);
211- if (!modArithType) {
212- op.emitError () << " Unsupported coefficient type: " << coefficientType;
213- return failure ();
214- }
215-
216- Type tensorEltTy = modArithType.getModulus ().getType ();
217- auto tensorParams = RankedTensorType::get ({dimension}, tensorEltTy);
218- auto modArithTensorType = RankedTensorType::get ({dimension}, modArithType);
169+ auto outputTy = cast<RankedTensorType>(
170+ typeConverter->convertType (op.getOutput ().getType ()));
171+ auto outputPolyTy =
172+ cast<polynomial::PolynomialType>(outputTy.getElementType ());
219173
220174 // TODO (#881): Add pass options to change the seed (which is currently
221175 // hardcoded to 0 with index).
@@ -237,13 +191,8 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
237191 builder.getI32IntegerAttr (-1 ), builder.getI32IntegerAttr (2 ));
238192
239193 // Generate random u polynomial from uniform random ternary distribution
240- auto uTensor =
241- builder.create <random::SampleOp>(tensorParams, uniformDistribution);
242- // Convert the tensor of ints to a tensor of mod_arith, then a polynomial
243- auto modArithUTensor =
244- builder.create <mod_arith::EncapsulateOp>(modArithTensorType, uTensor);
245- auto u = builder.create <polynomial::FromTensorOp>(modArithUTensor,
246- inputT.getRing ());
194+ auto u =
195+ builder.create <random::SampleOp>(outputPolyTy, uniformDistribution);
247196
248197 // Create a discrete Gaussian distribution
249198 auto discreteGaussianDistributionType = random::DistributionType::get (
@@ -263,38 +212,42 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
263212 tensor::ExtractOp publicKey1 =
264213 builder.create <tensor::ExtractOp>(key, ValueRange{index1});
265214
266- // constantT is 2**(cleartextBitwidth), and is used for scalar
267- // multiplication.
268- // TODO(#876): Migrate to using the plaintext modulus of the encoding info
269- // attributes.
270- auto constantT = builder.create <mod_arith::ConstantOp>(
271- modArithType, IntegerAttr::get (modArithType.getModulus ().getType (),
272- 1 << cleartextBitwidth));
215+ // get plaintext modulus T in the output ring space. We assume the
216+ // plaintext type is a mod arith type.
217+ auto plaintextCoeffType =
218+ outputT.getPlaintextSpace ().getRing ().getCoefficientType ();
219+ auto plaintextModArithType =
220+ dyn_cast<mod_arith::ModArithType>(plaintextCoeffType);
221+ if (!plaintextModArithType) {
222+ op.emitError () << " Unsupported plaintext coefficient type: "
223+ << plaintextCoeffType;
224+ return failure ();
225+ }
273226
274- // generate random e0 polynomial from discrete gaussian distribution
275- auto e0Tensor = builder.create <random::SampleOp>(
276- tensorParams, discreteGaussianDistribution);
277- auto modArithE0Tensor = builder.create <mod_arith::EncapsulateOp>(
278- modArithTensorType, e0Tensor);
279- auto e0 = builder.create <polynomial::FromTensorOp>(modArithE0Tensor,
280- inputT.getRing ());
227+ // create scalar constant T in the output coefficient space
228+ auto plaintextT = builder.create <mod_arith::ConstantOp>(
229+ plaintextModArithType, plaintextModArithType.getModulus ());
230+ auto constantT = builder.create <mod_arith::ModSwitchOp>(
231+ outputPolyTy.getRing ().getCoefficientType (), plaintextT);
281232
233+ // generate random e0 polynomial from discrete gaussian distribution
234+ auto e0 = builder.create <random::SampleOp>(outputPolyTy,
235+ discreteGaussianDistribution);
282236 // generate random e1 polynomial from discrete gaussian distribution
283- auto e1Tensor = builder.create <random::SampleOp>(
284- tensorParams, discreteGaussianDistribution);
285- auto modArithE1Tensor = builder.create <mod_arith::EncapsulateOp>(
286- modArithTensorType, e1Tensor);
287- auto e1 = builder.create <polynomial::FromTensorOp>(modArithE1Tensor,
288- inputT.getRing ());
237+ auto e1 = builder.create <random::SampleOp>(outputPolyTy,
238+ discreteGaussianDistribution);
289239
290240 // TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
291241 // noise or key differently. Add support for those cases.
292242 // Computing ciphertext0 = publicKey0 * u + e0 *
293- // constantT + input
243+ // constantT + cast( input)
294244 auto publicKey0U = builder.create <polynomial::MulOp>(publicKey0, u);
295245 auto tE0 = builder.create <polynomial::MulScalarOp>(e0 , constantT);
296246 auto pK0UtE0 = builder.create <polynomial::AddOp>(publicKey0U, tE0);
297- auto ciphertext0 = builder.create <polynomial::AddOp>(pK0UtE0, input);
247+ // cast from plaintext space to ciphertext space
248+ auto castInput =
249+ builder.create <polynomial::ModSwitchOp>(outputPolyTy, input);
250+ auto ciphertext0 = builder.create <polynomial::AddOp>(pK0UtE0, castInput);
298251
299252 // Computing ciphertext1 = publicKey1 * u + e1 * constantT
300253 auto publicKey1U = builder.create <polynomial::MulOp>(publicKey1, u);
@@ -317,20 +270,19 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
317270 }
318271
319272 // Generate random e polynomial from discrete gaussian distribution
320- auto eTensor = builder.create <random::SampleOp>(
321- tensorParams, discreteGaussianDistribution);
322- auto modArithETensor =
323- builder.create <mod_arith::EncapsulateOp>(modArithTensorType, eTensor);
324- auto e = builder.create <polynomial::FromTensorOp>(modArithETensor,
325- inputT.getRing ());
273+ auto e = builder.create <random::SampleOp>(outputPolyTy,
274+ discreteGaussianDistribution);
326275
327276 // TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
328277 // noise or key differently. Add support for those cases.
329278 // ciphertext0 = u
330279 // Compute ciphertext1 = <u,s> + m + e
331280 auto keyPoly = builder.create <tensor::ExtractOp>(key, ValueRange{index0});
332281 auto us = builder.create <polynomial::MulOp>(u, keyPoly);
333- auto usM = builder.create <polynomial::AddOp>(us, input);
282+ // cast from plaintext space to ciphertext space
283+ auto castInput =
284+ builder.create <polynomial::ModSwitchOp>(outputPolyTy, input);
285+ auto usM = builder.create <polynomial::AddOp>(us, castInput);
334286 auto ciphertext1 = builder.create <polynomial::AddOp>(usM, e);
335287
336288 // ciphertext = (u, ciphertext0)
@@ -530,11 +482,6 @@ struct ConvertRMulPlain : public OpConversionPattern<RMulPlainOp> {
530482
531483struct LWEToPolynomial : public impl ::LWEToPolynomialBase<LWEToPolynomial> {
532484 void runOnOperation () override {
533- // TODO(#1199): Remove this emitError once the pass is fixed.
534- getOperation ()->emitError (
535- " LWEToPolynomial conversion pass is broken. See #1199." );
536- return ;
537-
538485 MLIRContext *context = &getContext ();
539486 auto *module = getOperation ();
540487 CiphertextTypeConverter typeConverter (context);
0 commit comments