@@ -1130,21 +1130,23 @@ class ConvertTensorInsertSlice
11301130 sizes.push_back (b.getIndexAttr (1 ));
11311131 sizes.push_back (b.getIndexAttr (slots));
11321132 SmallVector<OpFoldResult> strides (2 , b.getIndexAttr (1 ));
1133- Value extractedDest = tensor::ExtractSliceOp::create (
1133+ Operation* extractedDest = tensor::ExtractSliceOp::create (
11341134 b, op.getLoc (), cast<RankedTensorType>(convertedSource.getType ()),
11351135 adaptor.getDest (), ctOffsets, sizes, strides);
11361136 Operation* scalarMul = makeAppropriatelyTypedMulOp (
11371137 b, op.getLoc (), scalarMask, convertedSource, {getArithFMF (b)});
11381138 Operation* destMul = makeAppropriatelyTypedMulOp (
1139- b, op.getLoc (), destMask, extractedDest, {getArithFMF (b)});
1139+ b, op.getLoc (), destMask, extractedDest->getResult (0 ),
1140+ {getArithFMF (b)});
11401141 Operation* finalAdd =
11411142 makeAppropriatelyTypedAddOp (b, op.getLoc (), scalarMul->getResult (0 ),
11421143 destMul->getResult (0 ), {getArithFMF (b)});
11431144
11441145 // Insert the final result into the ciphertext at position ct.
11451146 Operation* insertOp = tensor::InsertSliceOp::create (
11461147 b, finalAdd->getResult (0 ), result, ctOffsets, sizes, strides);
1147- setMaterializedAttr ({scalarMul, destMul, finalAdd, insertOp});
1148+ setMaterializedAttr (
1149+ {extractedDest, scalarMul, destMul, finalAdd, insertOp});
11481150 result = insertOp->getResult (0 );
11491151 }
11501152
@@ -1472,6 +1474,123 @@ class ConvertTensorInsertLayout
14721474 }
14731475};
14741476
1477+ class ConvertTensorExtractSlice
1478+ : public ContextAwareOpConversionPattern<tensor::ExtractSliceOp> {
1479+ public:
1480+ using ContextAwareOpConversionPattern<
1481+ tensor::ExtractSliceOp>::ContextAwareOpConversionPattern;
1482+
1483+ LogicalResult secretSourceSecretResult (
1484+ tensor::ExtractSliceOp op, OpAdaptor adaptor,
1485+ ContextAwareConversionPatternRewriter& rewriter) const {
1486+ MLIRContext* ctx = op.getContext ();
1487+
1488+ FailureOr<Attribute> sourceLayoutResult =
1489+ getTypeConverter ()->getContextualAttr (adaptor.getSource ());
1490+ FailureOr<Attribute> resultLayoutResult =
1491+ getTypeConverter ()->getContextualAttr (op.getResult ());
1492+ LayoutAttr resultLayout = cast<LayoutAttr>(resultLayoutResult.value ());
1493+ IntegerRelation sourceRel =
1494+ cast<LayoutAttr>(sourceLayoutResult.value ()).getIntegerRelation ();
1495+ IntegerRelation resultRel = resultLayout.getIntegerRelation ();
1496+
1497+ // Compute the layout relation of the extract_slice operation.
1498+ auto extractSliceLayout =
1499+ getSliceExtractionRelation (op.getSourceType (), op.getResultType (),
1500+ SmallVector<int64_t >(op.getStaticOffsets ()),
1501+ SmallVector<int64_t >(op.getStaticSizes ()),
1502+ SmallVector<int64_t >(op.getStaticStrides ()));
1503+ if (failed (extractSliceLayout)) {
1504+ return op.emitError () << " failed to get layout for extract slice" ;
1505+ }
1506+
1507+ // Remap the source ciphertext semantic tensor to the result ciphertext
1508+ // semantic tensor layout. To do this, we compose the relations to traverse
1509+ // the following diagram, starting from the source ciphertext tensor to the
1510+ // extracted slice ciphertext tensor.
1511+ // Source tensor ─────────> Slice tensor
1512+ // \ │
1513+ // /│\ │
1514+ // │ \│ /
1515+ // │ \/
1516+ // Source ciphertext Slice ciphertext
1517+ // (ct, slot) (ct, slot)
1518+ sourceRel.inverse ();
1519+ sourceRel.compose (extractSliceLayout.value ());
1520+ sourceRel.compose (resultRel);
1521+
1522+ // tensor_ext.remap constrains its input and output types to be the same,
1523+ // i.e., remap occurs within one set of ciphertexts. The output of an
1524+ // extract_slice, however, may have a layout that has fewer ciphertexts
1525+ // in it. For example, extracting one row from a data-semantic matrix that
1526+ // is packed with one row per ciphertext would result in a single output
1527+ // ciphertext, and the expected layout of the result will reflect that.
1528+ // To bridge this gap, this kernel post-processes the remap's output to
1529+ // extract the subset ciphertexts relevant to the layout of the output
1530+ // slice.
1531+ LayoutAttr sliceLayoutAttr =
1532+ LayoutAttr::getFromIntegerRelation (ctx, sourceRel);
1533+ RankedTensorType sourceCiphertextSemanticType =
1534+ cast<RankedTensorType>(adaptor.getSource ().getType ());
1535+ auto remapSource = tensor_ext::RemapOp::create (
1536+ rewriter, op.getLoc (), sourceCiphertextSemanticType,
1537+ adaptor.getSource (), sliceLayoutAttr);
1538+
1539+ auto resultCiphertextSemanticType = cast<RankedTensorType>(
1540+ getTypeConverter ()->convertType (op.getResultType (), resultLayout));
1541+ SmallVector<OpFoldResult> strides (2 , rewriter.getIndexAttr (1 ));
1542+ SmallVector<OpFoldResult> offsets (2 , rewriter.getIndexAttr (0 ));
1543+ SmallVector<OpFoldResult> sizes;
1544+ sizes.push_back (
1545+ rewriter.getIndexAttr (resultCiphertextSemanticType.getDimSize (0 )));
1546+ sizes.push_back (
1547+ rewriter.getIndexAttr (resultCiphertextSemanticType.getDimSize (1 )));
1548+ auto extractRemap = tensor::ExtractSliceOp::create (
1549+ rewriter, op.getLoc (), resultCiphertextSemanticType,
1550+ remapSource.getResult (), offsets, sizes, strides);
1551+
1552+ setMaterializedAttr ({remapSource, extractRemap});
1553+ setAttributeAssociatedWith (extractRemap.getResult (), kLayoutAttrName ,
1554+ sliceLayoutAttr);
1555+ rewriter.replaceOp (op, extractRemap.getResult ());
1556+ return success ();
1557+ }
1558+
1559+ LogicalResult matchAndRewrite (
1560+ tensor::ExtractSliceOp op, OpAdaptor adaptor,
1561+ ContextAwareConversionPatternRewriter& rewriter) const final {
1562+ // Extract a secret slice from a secret tensor.
1563+ FailureOr<Attribute> sourceLayoutResult =
1564+ getTypeConverter ()->getContextualAttr (adaptor.getSource ());
1565+ FailureOr<Attribute> resultLayoutResult =
1566+ getTypeConverter ()->getContextualAttr (op.getResult ());
1567+
1568+ bool isSecretSource = succeeded (sourceLayoutResult);
1569+ bool isSecretResult = succeeded (resultLayoutResult);
1570+
1571+ if (isSecretSource && isSecretResult) {
1572+ return secretSourceSecretResult (op, adaptor, rewriter);
1573+ }
1574+
1575+ if (isSecretSource && !isSecretResult) {
1576+ return op.emitError ()
1577+ << " result tensor should have been assigned a layout "
1578+ " by layout-propagation" ;
1579+ }
1580+
1581+ if (!isSecretSource && isSecretResult) {
1582+ return op.emitError ()
1583+ << " source tensor should have been assigned a layout "
1584+ " by layout-propagation" ;
1585+ }
1586+
1587+ // cleartext scalar and cleartext tensor means this is a cleartext op
1588+ // that can be elided.
1589+ setMaterializedAttr (op);
1590+ return success ();
1591+ }
1592+ };
1593+
14751594class ConvertCollapseShape
14761595 : public ContextAwareOpConversionPattern<tensor::CollapseShapeOp> {
14771596 public:
@@ -1680,17 +1799,18 @@ struct ConvertToCiphertextSemantics
16801799 return isa<ModuleOp>(op) || hasMaterializedAttr (op);
16811800 });
16821801
1683- patterns.add <
1684- ConvertFunc, ConvertGeneric,
1685- // tensor_ext ops
1686- ConvertConvertLayout,
1687- // linalg ops
1688- ConvertLinalgReduce, ConvertLinalgMatvecLayout, ConvertLinalgConv2D,
1689- // tensor ops
1690- ConvertTensorExtractLayout, ConvertTensorInsertLayout,
1691- ConvertCollapseShape, ConvertExpandShape, ConvertTensorInsertSlice,
1692- // default
1693- ConvertAnyAddingMaterializedAttr>(typeConverter, context);
1802+ patterns.add <ConvertFunc, ConvertGeneric,
1803+ // tensor_ext ops
1804+ ConvertConvertLayout,
1805+ // linalg ops
1806+ ConvertLinalgReduce, ConvertLinalgMatvecLayout,
1807+ ConvertLinalgConv2D,
1808+ // tensor ops
1809+ ConvertTensorExtractLayout, ConvertTensorInsertLayout,
1810+ ConvertCollapseShape, ConvertExpandShape,
1811+ ConvertTensorInsertSlice, ConvertTensorExtractSlice,
1812+ // default
1813+ ConvertAnyAddingMaterializedAttr>(typeConverter, context);
16941814 patterns.add <ConvertAssignLayout>(typeConverter, context, ciphertextSize);
16951815
16961816 ConversionConfig config;
0 commit comments