Skip to content

Commit b7163fd

Browse files
CKKS: Add RangeAnalysis
1 parent ca9322a commit b7163fd

File tree

8 files changed

+413
-4
lines changed

8 files changed

+413
-4
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package(
2+
default_applicable_licenses = ["@heir//:license"],
3+
default_visibility = ["//visibility:public"],
4+
)
5+
6+
cc_library(
7+
name = "CKKSRangeAnalysis",
8+
srcs = ["CKKSRangeAnalysis.cpp"],
9+
hdrs = ["CKKSRangeAnalysis.h"],
10+
deps = [
11+
"@heir//lib/Analysis:Utils",
12+
"@heir//lib/Analysis/SecretnessAnalysis",
13+
"@heir//lib/Dialect/Mgmt/IR:Dialect",
14+
"@heir//lib/Dialect/Secret/IR:Dialect",
15+
"@heir//lib/Utils",
16+
"@heir//lib/Utils:LogArithmetic",
17+
"@llvm-project//llvm:Support",
18+
"@llvm-project//mlir:Analysis",
19+
"@llvm-project//mlir:CallOpInterfaces",
20+
"@llvm-project//mlir:IR",
21+
"@llvm-project//mlir:Support",
22+
],
23+
)
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#include "lib/Analysis/CKKSRangeAnalysis/CKKSRangeAnalysis.h"
2+
3+
#include <algorithm>
4+
#include <cassert>
5+
#include <functional>
6+
#include <optional>
7+
8+
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
9+
#include "lib/Analysis/Utils.h"
10+
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
11+
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
12+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
13+
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
14+
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
15+
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
16+
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
19+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
20+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21+
22+
#define DEBUG_TYPE "CKKSRangeAnalysis"
23+
24+
namespace mlir {
25+
namespace heir {
26+
27+
//===----------------------------------------------------------------------===//
28+
// CKKSRangeAnalysis (Forward)
29+
//===----------------------------------------------------------------------===//
30+
31+
LogicalResult CKKSRangeAnalysis::visitOperation(
32+
Operation *op, ArrayRef<const CKKSRangeLattice *> operands,
33+
ArrayRef<CKKSRangeLattice *> results) {
34+
auto propagate = [&](Value value, const CKKSRangeState &state) {
35+
auto *lattice = getLatticeElement(value);
36+
LLVM_DEBUG(llvm::dbgs() << "Propagate CKKSRangeState to " << value << ": "
37+
<< state << "\n");
38+
ChangeResult changed = lattice->join(state);
39+
propagateIfChanged(lattice, changed);
40+
};
41+
42+
llvm::TypeSwitch<Operation &>(*op)
43+
.Case<arith::AddFOp, arith::AddIOp>([&](auto &op) {
44+
// condition on result secretness
45+
SmallVector<OpResult> secretResults;
46+
getSecretResults(op, secretResults);
47+
if (secretResults.empty()) {
48+
return;
49+
}
50+
51+
auto rangeResult = Log2Arithmetic::of(0);
52+
53+
for (auto &operand : op->getOpOperands()) {
54+
auto &rangeState = getLatticeElement(operand.get())->getValue();
55+
if (!rangeState.isInitialized()) {
56+
return;
57+
}
58+
rangeResult = rangeResult + rangeState.getCKKSRange();
59+
}
60+
61+
for (auto result : secretResults) {
62+
propagate(result, CKKSRangeState(rangeResult));
63+
}
64+
})
65+
.Case<arith::MulFOp, arith::MulIOp>([&](auto &op) {
66+
// condition on result secretness
67+
SmallVector<OpResult> secretResults;
68+
getSecretResults(op, secretResults);
69+
if (secretResults.empty()) {
70+
return;
71+
}
72+
73+
auto rangeResult = Log2Arithmetic::of(1);
74+
75+
for (auto &operand : op->getOpOperands()) {
76+
auto &rangeState = getLatticeElement(operand.get())->getValue();
77+
if (!rangeState.isInitialized()) {
78+
return;
79+
}
80+
rangeResult = rangeResult * rangeState.getCKKSRange();
81+
}
82+
83+
for (auto result : secretResults) {
84+
propagate(result, CKKSRangeState(rangeResult));
85+
}
86+
})
87+
.Case<arith::ConstantOp>([&](auto &op) {
88+
// For constant, the range is [constant]
89+
std::optional<Log2Arithmetic> range = std::nullopt;
90+
TypedAttr constAttr = op.getValue();
91+
llvm::TypeSwitch<Attribute>(constAttr)
92+
.Case<FloatAttr>([&](FloatAttr value) {
93+
range = Log2Arithmetic::of(std::fabs(value.getValueAsDouble()));
94+
})
95+
.template Case<IntegerAttr>([&](IntegerAttr value) {
96+
range =
97+
Log2Arithmetic::of(std::abs(value.getValue().getSExtValue()));
98+
})
99+
.template Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
100+
auto elementType = getElementTypeOrSelf(constAttr.getType());
101+
if (mlir::isa<FloatType>(elementType)) {
102+
std::optional<APFloat> maxValue;
103+
for (APFloat value : denseAttr.template getValues<APFloat>()) {
104+
value.clearSign();
105+
if (!maxValue.has_value() ||
106+
maxValue->compare(value) == APFloat::cmpLessThan) {
107+
maxValue = value;
108+
}
109+
}
110+
if (maxValue.has_value()) {
111+
range =
112+
Log2Arithmetic::of(maxValue.value().convertToDouble());
113+
}
114+
} else if (mlir::isa<IntegerType>(elementType)) {
115+
std::optional<APInt> maxValue;
116+
for (APInt value : mlir::cast<DenseElementsAttr>(constAttr)
117+
.template getValues<APInt>()) {
118+
value.clearSignBit();
119+
if (!maxValue.has_value() || maxValue->ule(value)) {
120+
maxValue = value;
121+
}
122+
}
123+
range = Log2Arithmetic::of(maxValue.value().getSExtValue());
124+
}
125+
});
126+
// We can encounter DenseResourceElementsAttr, we do not know its range
127+
if (!range.has_value()) {
128+
return;
129+
}
130+
for (auto result : op->getResults()) {
131+
propagate(result, CKKSRangeState(range.value()));
132+
}
133+
})
134+
.Case<mgmt::InitOp>([&](auto &op) {
135+
auto inputState = getLatticeElement(op->getOperand(0))->getValue();
136+
if (!inputState.isInitialized()) {
137+
return;
138+
}
139+
// For InitOp, the range is the same as the input range
140+
propagate(op->getResult(0), inputState);
141+
})
142+
.Case<tensor::InsertOp>([&](tensor::InsertOp op) {
143+
auto scalarState = getLatticeElement(op.getScalar())->getValue();
144+
auto destState = getLatticeElement(op.getDest())->getValue();
145+
if (!scalarState.isInitialized() || !destState.isInitialized()) {
146+
return;
147+
}
148+
auto resultState =
149+
CKKSRangeState::join(scalarState, destState); // Join the ranges
150+
propagate(op.getResult(), resultState);
151+
})
152+
// Rotation does not change the CKKS range
153+
.Default([&](auto &op) {
154+
// condition on result secretness
155+
SmallVector<OpResult> secretResults;
156+
getSecretResults(&op, secretResults);
157+
if (secretResults.empty()) {
158+
return;
159+
}
160+
161+
SmallVector<OpOperand *> secretOperands;
162+
getSecretOperands(&op, secretOperands);
163+
if (secretOperands.empty()) {
164+
return;
165+
}
166+
167+
// short-circuit to get range
168+
CKKSRangeState rangeState;
169+
for (auto *operand : secretOperands) {
170+
auto &operandRangeState =
171+
getLatticeElement(operand->get())->getValue();
172+
if (operandRangeState.isInitialized()) {
173+
rangeState = operandRangeState;
174+
break;
175+
}
176+
}
177+
178+
for (auto result : secretResults) {
179+
propagate(result, rangeState);
180+
}
181+
});
182+
return success();
183+
}
184+
185+
void CKKSRangeAnalysis::visitExternalCall(
186+
CallOpInterface call, ArrayRef<const CKKSRangeLattice *> argumentLattices,
187+
ArrayRef<CKKSRangeLattice *> resultLattices) {
188+
auto callback = std::bind(&CKKSRangeAnalysis::propagateIfChangedWrapper, this,
189+
std::placeholders::_1, std::placeholders::_2);
190+
::mlir::heir::visitExternalCall<CKKSRangeState, CKKSRangeLattice>(
191+
call, argumentLattices, resultLattices, callback);
192+
}
193+
194+
//===----------------------------------------------------------------------===//
195+
// Utils
196+
//===----------------------------------------------------------------------===//
197+
198+
std::optional<CKKSRangeState::CKKSRangeType> getCKKSRange(
199+
Value value, DataFlowSolver *solver) {
200+
auto *lattice = solver->lookupState<CKKSRangeLattice>(value);
201+
if (!lattice) {
202+
return std::nullopt;
203+
}
204+
if (!lattice->getValue().isInitialized()) {
205+
return std::nullopt;
206+
}
207+
return lattice->getValue().getCKKSRange();
208+
}
209+
210+
} // namespace heir
211+
} // namespace mlir
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#ifndef LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_
2+
#define LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_
3+
4+
#include <algorithm>
5+
#include <cassert>
6+
#include <optional>
7+
8+
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
9+
#include "lib/Utils/LogArithmetic.h"
10+
#include "lib/Utils/Utils.h"
11+
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
12+
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
13+
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
14+
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
15+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
16+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
17+
18+
namespace mlir {
19+
namespace heir {
20+
21+
class CKKSRangeState {
22+
public:
23+
/// Represents a CKKS range using Log2Arithmetic.
24+
/// Here we only store the bound.
25+
/// For double input in range [-1, 1], we use Log2Arithmetic::of(1) to
26+
/// represent it.
27+
using CKKSRangeType = Log2Arithmetic;
28+
29+
CKKSRangeState() : range(std::nullopt) {}
30+
explicit CKKSRangeState(CKKSRangeType range) : range(range) {}
31+
~CKKSRangeState() = default;
32+
33+
CKKSRangeType getCKKSRange() const {
34+
assert(isInitialized());
35+
return range.value();
36+
}
37+
CKKSRangeType get() const { return getCKKSRange(); }
38+
39+
bool operator==(const CKKSRangeState &rhs) const {
40+
return range == rhs.range;
41+
}
42+
43+
bool isInitialized() const { return range.has_value(); }
44+
45+
static CKKSRangeState join(const CKKSRangeState &lhs,
46+
const CKKSRangeState &rhs) {
47+
if (!lhs.isInitialized()) return rhs;
48+
if (!rhs.isInitialized()) return lhs;
49+
50+
return CKKSRangeState{std::max(lhs.getCKKSRange(), rhs.getCKKSRange())};
51+
}
52+
53+
void print(llvm::raw_ostream &os) const {
54+
if (isInitialized()) {
55+
os << "CKKSRangeState(normal: "
56+
<< doubleToString2Prec(range.value().getValue())
57+
<< ", log2: " << doubleToString2Prec(range.value().getLog2Value())
58+
<< ")";
59+
} else {
60+
os << "CKKSRangeState(uninitialized)";
61+
}
62+
}
63+
64+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
65+
const CKKSRangeState &state) {
66+
state.print(os);
67+
return os;
68+
}
69+
70+
private:
71+
std::optional<CKKSRangeType> range;
72+
};
73+
74+
class CKKSRangeLattice : public dataflow::Lattice<CKKSRangeState> {
75+
public:
76+
using Lattice::Lattice;
77+
};
78+
79+
class CKKSRangeAnalysis
80+
: public dataflow::SparseForwardDataFlowAnalysis<CKKSRangeLattice>,
81+
public SecretnessAnalysisDependent<CKKSRangeAnalysis> {
82+
public:
83+
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
84+
friend class SecretnessAnalysisDependent<CKKSRangeAnalysis>;
85+
86+
void setToEntryState(CKKSRangeLattice *lattice) override {
87+
// For double input, default range is [-1, 1]
88+
// This handles both secret input and plaintext func arg
89+
propagateIfChanged(lattice,
90+
lattice->join(CKKSRangeState({Log2Arithmetic::of(1)})));
91+
}
92+
93+
LogicalResult visitOperation(Operation *op,
94+
ArrayRef<const CKKSRangeLattice *> operands,
95+
ArrayRef<CKKSRangeLattice *> results) override;
96+
97+
void visitExternalCall(CallOpInterface call,
98+
ArrayRef<const CKKSRangeLattice *> argumentLattices,
99+
ArrayRef<CKKSRangeLattice *> resultLattices) override;
100+
101+
void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
102+
propagateIfChanged(state, changed);
103+
}
104+
};
105+
106+
std::optional<CKKSRangeState::CKKSRangeType> getCKKSRange(
107+
Value value, DataFlowSolver *solver);
108+
109+
} // namespace heir
110+
} // namespace mlir
111+
112+
#endif // LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_

lib/Analysis/DimensionAnalysis/DimensionAnalysis.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ class DimensionAnalysisBackward
126126
void visitCallOperand(OpOperand &operand) override {}
127127
};
128128

129-
// this function will assert false when Lattice does not exist or not
130-
// initialized
131129
std::optional<DimensionState::DimensionType> getDimension(
132130
Value value, DataFlowSolver *solver);
133131

lib/Transforms/GenerateParam/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
hdrs = ["GenerateParam.h"],
1616
deps = [
1717
":pass_inc_gen",
18+
"@heir//lib/Analysis/CKKSRangeAnalysis",
1819
"@heir//lib/Analysis/DimensionAnalysis",
1920
"@heir//lib/Analysis/LevelAnalysis",
2021
"@heir//lib/Analysis/NoiseAnalysis",

0 commit comments

Comments
 (0)