Skip to content

Commit 24c0a48

Browse files
asraacopybara-github
authored andcommitted
enable layout propagation for tensor.extract_slice
PiperOrigin-RevId: 827483403
1 parent 9e8e6f4 commit 24c0a48

File tree

8 files changed

+329
-40
lines changed

8 files changed

+329
-40
lines changed

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 134 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,21 +1130,23 @@ class ConvertTensorInsertSlice
11301130
sizes.push_back(b.getIndexAttr(1));
11311131
sizes.push_back(b.getIndexAttr(slots));
11321132
SmallVector<OpFoldResult> strides(2, b.getIndexAttr(1));
1133-
Value extractedDest = tensor::ExtractSliceOp::create(
1133+
Operation* extractedDest = tensor::ExtractSliceOp::create(
11341134
b, op.getLoc(), cast<RankedTensorType>(convertedSource.getType()),
11351135
adaptor.getDest(), ctOffsets, sizes, strides);
11361136
Operation* scalarMul = makeAppropriatelyTypedMulOp(
11371137
b, op.getLoc(), scalarMask, convertedSource, {getArithFMF(b)});
11381138
Operation* destMul = makeAppropriatelyTypedMulOp(
1139-
b, op.getLoc(), destMask, extractedDest, {getArithFMF(b)});
1139+
b, op.getLoc(), destMask, extractedDest->getResult(0),
1140+
{getArithFMF(b)});
11401141
Operation* finalAdd =
11411142
makeAppropriatelyTypedAddOp(b, op.getLoc(), scalarMul->getResult(0),
11421143
destMul->getResult(0), {getArithFMF(b)});
11431144

11441145
// Insert the final result into the ciphertext at position ct.
11451146
Operation* insertOp = tensor::InsertSliceOp::create(
11461147
b, finalAdd->getResult(0), result, ctOffsets, sizes, strides);
1147-
setMaterializedAttr({scalarMul, destMul, finalAdd, insertOp});
1148+
setMaterializedAttr(
1149+
{extractedDest, scalarMul, destMul, finalAdd, insertOp});
11481150
result = insertOp->getResult(0);
11491151
}
11501152

@@ -1472,6 +1474,123 @@ class ConvertTensorInsertLayout
14721474
}
14731475
};
14741476

1477+
class ConvertTensorExtractSlice
1478+
: public ContextAwareOpConversionPattern<tensor::ExtractSliceOp> {
1479+
public:
1480+
using ContextAwareOpConversionPattern<
1481+
tensor::ExtractSliceOp>::ContextAwareOpConversionPattern;
1482+
1483+
LogicalResult secretSourceSecretResult(
1484+
tensor::ExtractSliceOp op, OpAdaptor adaptor,
1485+
ContextAwareConversionPatternRewriter& rewriter) const {
1486+
MLIRContext* ctx = op.getContext();
1487+
1488+
FailureOr<Attribute> sourceLayoutResult =
1489+
getTypeConverter()->getContextualAttr(adaptor.getSource());
1490+
FailureOr<Attribute> resultLayoutResult =
1491+
getTypeConverter()->getContextualAttr(op.getResult());
1492+
LayoutAttr resultLayout = cast<LayoutAttr>(resultLayoutResult.value());
1493+
IntegerRelation sourceRel =
1494+
cast<LayoutAttr>(sourceLayoutResult.value()).getIntegerRelation();
1495+
IntegerRelation resultRel = resultLayout.getIntegerRelation();
1496+
1497+
// Compute the layout relation of the extract_slice operation.
1498+
auto extractSliceLayout =
1499+
getSliceExtractionRelation(op.getSourceType(), op.getResultType(),
1500+
SmallVector<int64_t>(op.getStaticOffsets()),
1501+
SmallVector<int64_t>(op.getStaticSizes()),
1502+
SmallVector<int64_t>(op.getStaticStrides()));
1503+
if (failed(extractSliceLayout)) {
1504+
return op.emitError() << "failed to get layout for extract slice";
1505+
}
1506+
1507+
// Remap the source ciphertext semantic tensor to the result ciphertext
1508+
// semantic tensor layout. To do this, we compose the relations to traverse
1509+
// the following diagram, starting from the source ciphertext tensor to the
1510+
// extracted slice ciphertext tensor.
1511+
// Source tensor ─────────> Slice tensor
1512+
// \ │
1513+
// /│\ │
1514+
// │ \│ /
1515+
// │ \/
1516+
// Source ciphertext Slice ciphertext
1517+
// (ct, slot) (ct, slot)
1518+
sourceRel.inverse();
1519+
sourceRel.compose(extractSliceLayout.value());
1520+
sourceRel.compose(resultRel);
1521+
1522+
// tensor_ext.remap constrains its input and output types to be the same,
1523+
// i.e., remap occurs within one set of ciphertexts. The output of an
1524+
// extract_slice, however, may have a layout that has fewer ciphertexts
1525+
// in it. For example, extracting one row from a data-semantic matrix that
1526+
// is packed with one row per ciphertext would result in a single output
1527+
// ciphertext, and the expected layout of the result will reflect that.
1528+
// To bridge this gap, this kernel post-processes the remap's output to
1529+
// extract the subset ciphertexts relevant to the layout of the output
1530+
// slice.
1531+
LayoutAttr sliceLayoutAttr =
1532+
LayoutAttr::getFromIntegerRelation(ctx, sourceRel);
1533+
RankedTensorType sourceCiphertextSemanticType =
1534+
cast<RankedTensorType>(adaptor.getSource().getType());
1535+
auto remapSource = tensor_ext::RemapOp::create(
1536+
rewriter, op.getLoc(), sourceCiphertextSemanticType,
1537+
adaptor.getSource(), sliceLayoutAttr);
1538+
1539+
auto resultCiphertextSemanticType = cast<RankedTensorType>(
1540+
getTypeConverter()->convertType(op.getResultType(), resultLayout));
1541+
SmallVector<OpFoldResult> strides(2, rewriter.getIndexAttr(1));
1542+
SmallVector<OpFoldResult> offsets(2, rewriter.getIndexAttr(0));
1543+
SmallVector<OpFoldResult> sizes;
1544+
sizes.push_back(
1545+
rewriter.getIndexAttr(resultCiphertextSemanticType.getDimSize(0)));
1546+
sizes.push_back(
1547+
rewriter.getIndexAttr(resultCiphertextSemanticType.getDimSize(1)));
1548+
auto extractRemap = tensor::ExtractSliceOp::create(
1549+
rewriter, op.getLoc(), resultCiphertextSemanticType,
1550+
remapSource.getResult(), offsets, sizes, strides);
1551+
1552+
setMaterializedAttr({remapSource, extractRemap});
1553+
setAttributeAssociatedWith(extractRemap.getResult(), kLayoutAttrName,
1554+
sliceLayoutAttr);
1555+
rewriter.replaceOp(op, extractRemap.getResult());
1556+
return success();
1557+
}
1558+
1559+
LogicalResult matchAndRewrite(
1560+
tensor::ExtractSliceOp op, OpAdaptor adaptor,
1561+
ContextAwareConversionPatternRewriter& rewriter) const final {
1562+
// Extract a secret slice from a secret tensor.
1563+
FailureOr<Attribute> sourceLayoutResult =
1564+
getTypeConverter()->getContextualAttr(adaptor.getSource());
1565+
FailureOr<Attribute> resultLayoutResult =
1566+
getTypeConverter()->getContextualAttr(op.getResult());
1567+
1568+
bool isSecretSource = succeeded(sourceLayoutResult);
1569+
bool isSecretResult = succeeded(resultLayoutResult);
1570+
1571+
if (isSecretSource && isSecretResult) {
1572+
return secretSourceSecretResult(op, adaptor, rewriter);
1573+
}
1574+
1575+
if (isSecretSource && !isSecretResult) {
1576+
return op.emitError()
1577+
<< "result tensor should have been assigned a layout "
1578+
"by layout-propagation";
1579+
}
1580+
1581+
if (!isSecretSource && isSecretResult) {
1582+
return op.emitError()
1583+
<< "source tensor should have been assigned a layout "
1584+
"by layout-propagation";
1585+
}
1586+
1587+
// cleartext scalar and cleartext tensor means this is a cleartext op
1588+
// that can be elided.
1589+
setMaterializedAttr(op);
1590+
return success();
1591+
}
1592+
};
1593+
14751594
class ConvertCollapseShape
14761595
: public ContextAwareOpConversionPattern<tensor::CollapseShapeOp> {
14771596
public:
@@ -1680,17 +1799,18 @@ struct ConvertToCiphertextSemantics
16801799
return isa<ModuleOp>(op) || hasMaterializedAttr(op);
16811800
});
16821801

1683-
patterns.add<
1684-
ConvertFunc, ConvertGeneric,
1685-
// tensor_ext ops
1686-
ConvertConvertLayout,
1687-
// linalg ops
1688-
ConvertLinalgReduce, ConvertLinalgMatvecLayout, ConvertLinalgConv2D,
1689-
// tensor ops
1690-
ConvertTensorExtractLayout, ConvertTensorInsertLayout,
1691-
ConvertCollapseShape, ConvertExpandShape, ConvertTensorInsertSlice,
1692-
// default
1693-
ConvertAnyAddingMaterializedAttr>(typeConverter, context);
1802+
patterns.add<ConvertFunc, ConvertGeneric,
1803+
// tensor_ext ops
1804+
ConvertConvertLayout,
1805+
// linalg ops
1806+
ConvertLinalgReduce, ConvertLinalgMatvecLayout,
1807+
ConvertLinalgConv2D,
1808+
// tensor ops
1809+
ConvertTensorExtractLayout, ConvertTensorInsertLayout,
1810+
ConvertCollapseShape, ConvertExpandShape,
1811+
ConvertTensorInsertSlice, ConvertTensorExtractSlice,
1812+
// default
1813+
ConvertAnyAddingMaterializedAttr>(typeConverter, context);
16941814
patterns.add<ConvertAssignLayout>(typeConverter, context, ciphertextSize);
16951815

16961816
ConversionConfig config;

lib/Transforms/LayoutPropagation/LayoutPropagation.cpp

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ struct LayoutPropagation : impl::LayoutPropagationBase<LayoutPropagation> {
139139
LogicalResult visitOperation(tensor::ExtractOp op);
140140
LogicalResult visitOperation(tensor::InsertOp op);
141141
LogicalResult visitOperation(tensor::InsertSliceOp op);
142+
LogicalResult visitOperation(tensor::ExtractSliceOp op);
142143

143144
// Determine if the operation arguments have compatible layouts for the
144145
// given op. If the check fails, the CompatibilityResult::compatible field
@@ -277,16 +278,8 @@ LogicalResult LayoutPropagation::visitOperation(Operation* op) {
277278
// affine ops
278279
.Case<affine::AffineForOp>([&](auto op) { return visitOperation(op); })
279280
// tensor ops
280-
.Case<tensor::ExtractOp, tensor::InsertOp, tensor::InsertSliceOp>(
281-
[&](auto op) { return visitOperation(op); })
282-
.Case<tensor::ExtractSliceOp>([&](auto op) {
283-
// TODO(#2028): Support tensor.extract_slice and tensor.insert_slice in
284-
// layout.
285-
return op->emitError()
286-
<< "Layout propagation not supported for this op";
287-
})
288-
// tensor ops
289-
.Case<CollapseShapeOp, ExpandShapeOp>(
281+
.Case<tensor::ExtractOp, tensor::InsertOp, tensor::InsertSliceOp,
282+
tensor::ExtractSliceOp, CollapseShapeOp, ExpandShapeOp>(
290283
[&](auto op) { return visitOperation(op); })
291284
// AddI, AddF, mgmt.* all pass the layout through unchanged.
292285
.Default([&](Operation* op) {
@@ -790,6 +783,53 @@ LogicalResult LayoutPropagation::visitOperation(tensor::InsertSliceOp op) {
790783
return success();
791784
}
792785

786+
LogicalResult LayoutPropagation::visitOperation(tensor::ExtractSliceOp op) {
787+
// Assign the induced layout from extracting a slice from the source tensor.
788+
if (!assignedLayouts.contains(op.getSource())) {
789+
return op->emitError() << "Source tensor has no assigned layout";
790+
}
791+
IntegerRelation sourceLayout =
792+
assignedLayouts.at(op.getSource()).getIntegerRelation();
793+
794+
FailureOr<IntegerRelation> maybeSliceExtractionLayout =
795+
getSliceExtractionRelation(op.getSourceType(), op.getResultType(),
796+
SmallVector<int64_t>(op.getStaticOffsets()),
797+
SmallVector<int64_t>(op.getStaticSizes()),
798+
SmallVector<int64_t>(op.getStaticStrides()));
799+
if (failed(maybeSliceExtractionLayout)) {
800+
return failure();
801+
}
802+
IntegerRelation sliceExtractionLayout = maybeSliceExtractionLayout.value();
803+
804+
// Compose the inverted slice extraction layout with the source layout to
805+
// get the result slice layout.
806+
sliceExtractionLayout.inverse();
807+
sliceExtractionLayout.compose(sourceLayout);
808+
// If the slice extracted was not at offset zero, then the resulting slice may
809+
// be indexed at a non-zero ciphertext. For example, imagine extracting a
810+
// slice out of the second ciphertext. Then computing the inverse of the slice
811+
// extraction layout and composing that with the source relation would mean
812+
// that the slice would map to the second ciphertext. But a slice extracted
813+
// from a tensor.extract_slice op is always indexed starting from zero.
814+
// Reindexing the the resulting relation to start from ciphertext zero.
815+
auto ctVarOffset =
816+
sliceExtractionLayout.getVarKindOffset(presburger::VarKind::Range);
817+
auto ctLowerBound = sliceExtractionLayout.getConstantBound64(
818+
presburger::BoundType::LB, ctVarOffset);
819+
if (!ctLowerBound) {
820+
return op.emitError() << "failed to get constant bound on ciphertext index";
821+
}
822+
auto zeroIndexedSliceLayout =
823+
shiftVar(sliceExtractionLayout, ctVarOffset, -ctLowerBound.value());
824+
825+
LayoutAttr outputLayout = LayoutAttr::getFromIntegerRelation(
826+
op.getContext(), zeroIndexedSliceLayout);
827+
assignedLayouts.insert({op.getResult(), outputLayout});
828+
debugAssignLayout(op.getResult(), outputLayout);
829+
setResultLayoutAttr(op);
830+
return success();
831+
}
832+
793833
CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts(
794834
Operation* op) {
795835
return TypeSwitch<Operation*, CompatibilityResult>(op)
@@ -917,7 +957,7 @@ CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts(
917957
tensor::InsertSliceOp op) {
918958
// The arguments of a tensor::InsertSliceOp are the tensors to insert and the
919959
// tensor to insert into.
920-
auto insert = op.getOperands()[0];
960+
auto insert = op.getSource();
921961
auto dest = op.getDest();
922962

923963
if (!assignedLayouts.contains(insert)) {

lib/Utils/Layout/Utils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,5 +747,53 @@ presburger::IntegerRelation shiftVar(
747747
return *shiftedRelation;
748748
}
749749

750+
FailureOr<presburger::IntegerRelation> getSliceExtractionRelation(
751+
RankedTensorType sourceType, RankedTensorType resultType,
752+
SmallVector<int64_t> offsets, SmallVector<int64_t> sizes,
753+
SmallVector<int64_t> strides) {
754+
IntegerRelation result(PresburgerSpace::getRelationSpace(
755+
sourceType.getRank(), /*numRange=*/resultType.getRank(), /*numSymbol=*/0,
756+
/*numLocals=*/0));
757+
758+
// Add bounds for the source dimensions.
759+
auto domainOffset = result.getVarKindOffset(VarKind::Domain);
760+
for (int i = 0; i < sourceType.getRank(); ++i) {
761+
addBounds(result, domainOffset + i, 0, sourceType.getDimSize(i) - 1);
762+
}
763+
764+
// Add bounds for the result dimensions.
765+
auto rangeOffset = result.getVarKindOffset(VarKind::Range);
766+
for (int i = 0; i < resultType.getRank(); ++i) {
767+
addBounds(result, rangeOffset + i, 0, resultType.getDimSize(i) - 1);
768+
}
769+
770+
// Destination tensor's dimensions (d0, d1, ...) are mapped sequentially from
771+
// the source tensor's dimensions (r0, r1, ...) for which the slice size is
772+
// greater than 1.
773+
auto constOffset = result.getNumCols() - 1;
774+
unsigned int resultDim = 0;
775+
for (auto sourceDim = 0; sourceDim < sourceType.getRank(); ++sourceDim) {
776+
if (sizes[sourceDim] > 1) {
777+
// Map to the i-th result dimension
778+
// d_j = offsets[j] + r_i * strides[j]
779+
addConstraint(result,
780+
{{domainOffset + sourceDim, -1},
781+
{constOffset, offsets[sourceDim]},
782+
{rangeOffset + resultDim, strides[sourceDim]}},
783+
/*equality=*/true);
784+
++resultDim;
785+
} else {
786+
// This is a dropped dimension, fixed at the offset
787+
// d_j = offsets[j]
788+
addConstraint(
789+
result,
790+
{{domainOffset + sourceDim, -1}, {constOffset, offsets[sourceDim]}},
791+
/*equality=*/true);
792+
}
793+
}
794+
795+
return result;
796+
}
797+
750798
} // namespace heir
751799
} // namespace mlir

lib/Utils/Layout/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ presburger::IntegerRelation shiftVar(
206206
const presburger::IntegerRelation& relation, unsigned int pos,
207207
int64_t offset);
208208

209+
// Get layout relation that corresponds to a tensor::extract_slice op.
210+
FailureOr<presburger::IntegerRelation> getSliceExtractionRelation(
211+
RankedTensorType sourceType, RankedTensorType resultType,
212+
SmallVector<int64_t> offsets, SmallVector<int64_t> sizes,
213+
SmallVector<int64_t> strides);
214+
209215
} // namespace heir
210216
} // namespace mlir
211217

lib/Utils/Layout/UtilsTest.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,34 @@ TEST(UtilsTest, TestShiftVarRangeOffset) {
519519
EXPECT_TRUE(shiftedRel.containsPointNoLocal({8, 1, 19}).has_value());
520520
}
521521

522+
TEST(UtilsTest, TestGetSliceExtractionRelation) {
523+
MLIRContext context;
524+
// Extract a 3x4 slice from a 2x1x3x4 matrix at (1, 0, 0, 0).
525+
RankedTensorType sourceType =
526+
RankedTensorType::get({2, 1, 3, 4}, IndexType::get(&context));
527+
RankedTensorType sliceType =
528+
RankedTensorType::get({3, 4}, IndexType::get(&context));
529+
SmallVector<int64_t> offsets = {1, 0, 0, 0};
530+
SmallVector<int64_t> sizes = {1, 1, 3, 4};
531+
SmallVector<int64_t> strides = {1, 1, 1, 1};
532+
533+
auto sliceRelation = getSliceExtractionRelation(sourceType, sliceType,
534+
offsets, sizes, strides);
535+
ASSERT_TRUE(succeeded(sliceRelation));
536+
537+
// Test a few points.
538+
// The relation maps from source indices to slice indices.
539+
// For example, source (1,0,0,0) maps to slice (0,0)
540+
std::vector<std::vector<int64_t>> expectedPoints = {
541+
{1, 0, 0, 0, 0, 0}, {1, 0, 0, 1, 0, 1}, {1, 0, 1, 0, 1, 0},
542+
{1, 0, 1, 1, 1, 1}, {1, 0, 2, 2, 2, 2},
543+
};
544+
for (const auto& point : expectedPoints) {
545+
auto maybeExists = sliceRelation.value().containsPointNoLocal(point);
546+
EXPECT_TRUE(maybeExists.has_value());
547+
}
548+
}
549+
522550
} // namespace
523551
} // namespace heir
524552
} // namespace mlir
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: heir-opt --convert-to-ciphertext-semantics=ciphertext-size=32 --split-input-file %s | FileCheck %s
2+
3+
#layout1 = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = i1 and (-4i2 - i3 + slot) mod 16 = 0 and 0 <= i1 <= 1 and 0 <= i2 <= 3 and 0 <= i3 <= 3 and 0 <= slot <= 31 }">
4+
#layout = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : ct = 0 and (-4i0 - i1 + slot) mod 16 = 0 and 0 <= i0 <= 3 and 0 <= i1 <= 3 and 0 <= slot <= 31 }">
5+
module {
6+
// Layouts are aligned perfectly so that extract_slice extracts a single ciphertext out of %input0
7+
// CHECK: func.func @trivial_insert
8+
// CHECK-SAME: (%[[arg0:.*]]: !secret.secret<tensor<2x32xf32>>
9+
func.func @trivial_insert(%arg0: !secret.secret<tensor<1x2x4x4xf32>> {tensor_ext.layout = #layout1}) -> (!secret.secret<tensor<4x4xf32>> {tensor_ext.layout = #layout}) {
10+
%1 = secret.generic(%arg0: !secret.secret<tensor<1x2x4x4xf32>> {tensor_ext.layout = #layout1}) {
11+
^body(%input0: tensor<1x2x4x4xf32>):
12+
// CHECK: secret.generic(%[[arg0]]: !secret.secret<tensor<2x32xf32>>)
13+
// CHECK-NEXT: ^body(%[[input0:.*]]: tensor<2x32xf32>)
14+
// CHECK: %[[v1:.*]] = tensor_ext.remap %[[input0]]
15+
// CHECK-NEXT: %[[extracted:.*]] = tensor.extract_slice %[[v1]][0, 0] [1, 32] [1, 1]
16+
// CHECK-NEXT: %[[v2:.*]] = arith.addf %[[extracted]], %[[extracted]]
17+
// CHECK-NEXT: secret.yield %[[v2]]
18+
%extract_slice = tensor.extract_slice %input0 [0, 1, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] {tensor_ext.layout = #layout}
19+
: tensor<1x2x4x4xf32> to tensor<4x4xf32>
20+
%3 = arith.addf %extract_slice, %extract_slice {tensor_ext.layout = #layout} : tensor<4x4xf32>
21+
secret.yield %3 : tensor<4x4xf32>
22+
} -> (!secret.secret<tensor<4x4xf32>> {tensor_ext.layout = #layout})
23+
return %1 : !secret.secret<tensor<4x4xf32>>
24+
}
25+
}

0 commit comments

Comments
 (0)