Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ class DimensionAnalysisBackward
void visitCallOperand(OpOperand &operand) override {}
};

// this function will assert false when Lattice does not exist or not
// initialized
std::optional<DimensionState::DimensionType> getDimension(
Value value, DataFlowSolver *solver);

Expand Down
23 changes: 23 additions & 0 deletions lib/Analysis/RangeAnalysis/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "RangeAnalysis",
srcs = ["RangeAnalysis.cpp"],
hdrs = ["RangeAnalysis.h"],
deps = [
"@heir//lib/Analysis:Utils",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:LogArithmetic",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
211 changes: 211 additions & 0 deletions lib/Analysis/RangeAnalysis/RangeAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
#include "lib/Analysis/RangeAnalysis/RangeAnalysis.h"

#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

#define DEBUG_TYPE "RangeAnalysis"

namespace mlir {
namespace heir {

//===----------------------------------------------------------------------===//
// RangeAnalysis (Forward)
//===----------------------------------------------------------------------===//

LogicalResult RangeAnalysis::visitOperation(
Operation *op, ArrayRef<const RangeLattice *> operands,
ArrayRef<RangeLattice *> results) {
auto propagate = [&](Value value, const RangeState &state) {
auto *lattice = getLatticeElement(value);
LLVM_DEBUG(llvm::dbgs()
<< "Propagate RangeState to " << value << ": " << state << "\n");
ChangeResult changed = lattice->join(state);
propagateIfChanged(lattice, changed);
};

llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddFOp, arith::AddIOp>([&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

auto rangeResult = Log2Arithmetic::of(0);

for (auto &operand : op->getOpOperands()) {
auto &rangeState = getLatticeElement(operand.get())->getValue();
if (!rangeState.isInitialized()) {
return;
}
rangeResult = rangeResult + rangeState.getRange();
}

for (auto result : secretResults) {
propagate(result, RangeState(rangeResult));
}
})
.Case<arith::MulFOp, arith::MulIOp>([&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
getSecretResults(op, secretResults);
if (secretResults.empty()) {
return;
}

auto rangeResult = Log2Arithmetic::of(1);

for (auto &operand : op->getOpOperands()) {
auto &rangeState = getLatticeElement(operand.get())->getValue();
if (!rangeState.isInitialized()) {
return;
}
rangeResult = rangeResult * rangeState.getRange();
}

for (auto result : secretResults) {
propagate(result, RangeState(rangeResult));
}
})
.Case<arith::ConstantOp>([&](auto &op) {
// For constant, the range is [constant]
std::optional<Log2Arithmetic> range = std::nullopt;
TypedAttr constAttr = op.getValue();
llvm::TypeSwitch<Attribute>(constAttr)
.Case<FloatAttr>([&](FloatAttr value) {
range = Log2Arithmetic::of(std::fabs(value.getValueAsDouble()));
})
.template Case<IntegerAttr>([&](IntegerAttr value) {
range =
Log2Arithmetic::of(std::abs(value.getValue().getSExtValue()));
})
.template Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
auto elementType = getElementTypeOrSelf(constAttr.getType());
if (mlir::isa<FloatType>(elementType)) {
std::optional<APFloat> maxValue;
for (APFloat value : denseAttr.template getValues<APFloat>()) {
value.clearSign();
if (!maxValue.has_value() ||
maxValue->compare(value) == APFloat::cmpLessThan) {
maxValue = value;
}
}
if (maxValue.has_value()) {
range =
Log2Arithmetic::of(maxValue.value().convertToDouble());
}
} else if (mlir::isa<IntegerType>(elementType)) {
std::optional<APInt> maxValue;
for (APInt value : mlir::cast<DenseElementsAttr>(constAttr)
.template getValues<APInt>()) {
value.clearSignBit();
if (!maxValue.has_value() || maxValue->ule(value)) {
maxValue = value;
}
}
range = Log2Arithmetic::of(maxValue.value().getSExtValue());
}
});
// We can encounter DenseResourceElementsAttr, we do not know its range
if (!range.has_value()) {
return;
}
for (auto result : op->getResults()) {
propagate(result, RangeState(range.value()));
}
})
.Case<mgmt::InitOp>([&](auto &op) {
auto inputState = getLatticeElement(op->getOperand(0))->getValue();
if (!inputState.isInitialized()) {
return;
}
// For InitOp, the range is the same as the input range
propagate(op->getResult(0), inputState);
})
.Case<tensor::InsertOp>([&](tensor::InsertOp op) {
auto scalarState = getLatticeElement(op.getScalar())->getValue();
auto destState = getLatticeElement(op.getDest())->getValue();
if (!scalarState.isInitialized() || !destState.isInitialized()) {
return;
}
auto resultState =
RangeState::join(scalarState, destState); // Join the ranges
propagate(op.getResult(), resultState);
})
// Rotation does not change the CKKS range
.Default([&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
getSecretResults(&op, secretResults);
if (secretResults.empty()) {
return;
}

SmallVector<OpOperand *> secretOperands;
getSecretOperands(&op, secretOperands);
if (secretOperands.empty()) {
return;
}

// short-circuit to get range
RangeState rangeState;
for (auto *operand : secretOperands) {
auto &operandRangeState =
getLatticeElement(operand->get())->getValue();
if (operandRangeState.isInitialized()) {
rangeState = operandRangeState;
break;
}
}

for (auto result : secretResults) {
propagate(result, rangeState);
}
});
return success();
}

void RangeAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const RangeLattice *> argumentLattices,
ArrayRef<RangeLattice *> resultLattices) {
auto callback = std::bind(&RangeAnalysis::propagateIfChangedWrapper, this,
std::placeholders::_1, std::placeholders::_2);
::mlir::heir::visitExternalCall<RangeState, RangeLattice>(
call, argumentLattices, resultLattices, callback);
}

//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

std::optional<RangeState::RangeType> getRange(Value value,
DataFlowSolver *solver) {
auto *lattice = solver->lookupState<RangeLattice>(value);
if (!lattice) {
return std::nullopt;
}
if (!lattice->getValue().isInitialized()) {
return std::nullopt;
}
return lattice->getValue().getRange();
}

} // namespace heir
} // namespace mlir
113 changes: 113 additions & 0 deletions lib/Analysis/RangeAnalysis/RangeAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#ifndef LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_
#define LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_

#include <algorithm>
#include <cassert>
#include <optional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Utils/LogArithmetic.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

class RangeState {
public:
/// Represents a range using Log2Arithmetic.
/// Here we only store the bound.
/// For [a, b], store Log2Arithmetic::of(max(abs(a), abs(b))).
using RangeType = Log2Arithmetic;

RangeState() : range(std::nullopt) {}
explicit RangeState(RangeType range) : range(range) {}
~RangeState() = default;

RangeType getRange() const {
assert(isInitialized());
return range.value();
}
RangeType get() const { return getRange(); }

bool operator==(const RangeState &rhs) const { return range == rhs.range; }

bool isInitialized() const { return range.has_value(); }

static RangeState join(const RangeState &lhs, const RangeState &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;

return RangeState{std::max(lhs.getRange(), rhs.getRange())};
}

void print(llvm::raw_ostream &os) const {
if (isInitialized()) {
os << "RangeState(normal: "
<< doubleToString2Prec(range.value().getValue())
<< ", log2: " << doubleToString2Prec(range.value().getLog2Value())
<< ")";
} else {
os << "RangeState(uninitialized)";
}
}

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const RangeState &state) {
state.print(os);
return os;
}

private:
std::optional<RangeType> range;
};

class RangeLattice : public dataflow::Lattice<RangeState> {
public:
using Lattice::Lattice;
};

class RangeAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<RangeLattice>,
public SecretnessAnalysisDependent<RangeAnalysis> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
friend class SecretnessAnalysisDependent<RangeAnalysis>;

RangeAnalysis(DataFlowSolver &solver, Log2Arithmetic inputRange)
: dataflow::SparseForwardDataFlowAnalysis<RangeLattice>(solver),
inputRange(inputRange) {}

void setToEntryState(RangeLattice *lattice) override {
// This handles both secret input and plaintext func arg
propagateIfChanged(lattice, lattice->join(RangeState({inputRange})));
}

LogicalResult visitOperation(Operation *op,
ArrayRef<const RangeLattice *> operands,
ArrayRef<RangeLattice *> results) override;

void visitExternalCall(CallOpInterface call,
ArrayRef<const RangeLattice *> argumentLattices,
ArrayRef<RangeLattice *> resultLattices) override;

void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) {
propagateIfChanged(state, changed);
}

private:
Log2Arithmetic inputRange;
};

std::optional<RangeState::RangeType> getRange(Value value,
DataFlowSolver *solver);

} // namespace heir
} // namespace mlir

#endif // LIB_ANALYSIS_CKKSRANGEANALYSIS_CKKSRANGEANALYSIS_H_
1 change: 1 addition & 0 deletions lib/Transforms/GenerateParam/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ cc_library(
"@heir//lib/Analysis/NoiseAnalysis/BGV:NoiseByBoundCoeffModel",
"@heir//lib/Analysis/NoiseAnalysis/BGV:NoiseByVarianceCoeffModel",
"@heir//lib/Analysis/NoiseAnalysis/BGV:NoiseCanEmbModel",
"@heir//lib/Analysis/RangeAnalysis",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect:ModuleAttributes",
"@heir//lib/Dialect/BGV/IR:Dialect",
Expand Down
3 changes: 3 additions & 0 deletions lib/Transforms/GenerateParam/GenerateParam.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def GenerateParamCKKS : Pass<"generate-param-ckks"> {
"If true, uses a public key for encryption.">,
Option<"encryptionTechniqueExtended", "encryption-technique-extended", "bool", /*default=*/"false",
"If true, uses EXTENDED encryption technique for encryption. (See https://ia.cr/2022/915)">,
Option<"inputRange", "input-range", "int",
/*default=*/"1", "The range of the plaintexts for input ciphertexts "
"for the CKKS scheme; default to [-1, 1]. For other ranges like [-D, D], use D.">,
];
}

Expand Down
Loading
Loading