diff --git a/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.cpp b/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.cpp index 63a827cfb7..78e0b1566a 100644 --- a/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.cpp +++ b/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.cpp @@ -1,8 +1,10 @@ #include "lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h" +#include #include #include #include +#include #include #include #include @@ -31,14 +33,18 @@ // while ensuring level constraints are satisfied. // // ILP Formulation: -// - Variables: level[value] for each SSA value, input_level[op] for each op, -// bootstrap[op] for each operation +// - Variables: level[value] and scale[value] for each SSA value, +// input_level[op] and input_scale[op] for each op, bootstrap[op] for each +// operation // - Constraints: // * Level bounds: 0 <= level[value] <= bootstrapWaterline (levels // 0..bootstrapWaterline) // * Operand matching: for each op, operand_i_level == input_level[op] // * Multiplication: output_level = input_level - 1 // * Non-mult: output_level = input_level +// * Scale constraints: mul input scale is the sum of operand scales, +// rescale decisions lower scale by scaleFactorBits, and bootstrap +// decisions require feasible level/scale input pairs. 
// * Bootstrap (big-M): if bootstrap[op] = 1, level_after = bootstrapWaterline // else level_after = level_before // - Objective: minimize sum of bootstrap decisions @@ -89,14 +95,26 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { llvm::DenseMap levelVars; // Map SSA values to level variables before bootstrap llvm::DenseMap beforeBootstrapVars; + // Map SSA values to scale variables (after bootstrap decision) + llvm::DenseMap scaleVars; + // Map SSA values to scale variables before bootstrap/output reduction + llvm::DenseMap beforeBootstrapScaleVars; // Input level: level at which operands are consumed llvm::DenseMap inputLevelVars; + // Input scale: scale at which operands are consumed. + llvm::DenseMap inputScaleVars; + llvm::DenseMap operandDropVars; + llvm::DenseMap outputDropVars; + SmallVector trackedOps; // Big-M constant for big-M method - // NOTE: This BIG_M does not account for "freshly encrypted" ciphertexts + // NOTE: This bigM does not account for "freshly encrypted" ciphertexts // starting at a higher level than the bootstrap waterline. This should // be addressed in future work. 
- const int BIG_M = bootstrapWaterline; + const int bigM = bootstrapWaterline; + const int scaleMax = scaleFactorBits + 2 * scaleWaterline; + const int inputScaleMax = 2 * scaleMax; + const int scaleBigM = scaleMax + scaleFactorBits * bootstrapWaterline; // Create variables for all SSA values in the body // First, handle block arguments (inputs) @@ -105,9 +123,13 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { std::stringstream ss; ss << "levelArg" << arg.getArgNumber(); - auto levelVar = - model.AddContinuousVariable(0, bootstrapWaterline, ss.str()); + auto levelVar = model.AddIntegerVariable(0, bootstrapWaterline, ss.str()); levelVars.insert(std::make_pair(arg, levelVar)); + std::stringstream ssScale; + ssScale << "scaleArg" << arg.getArgNumber(); + auto scaleVar = + model.AddContinuousVariable(scaleWaterline, scaleMax, ssScale.str()); + scaleVars.insert(std::make_pair(arg, scaleVar)); // Inputs start at maximum level (bootstrapWaterline) model.AddLinearConstraint( levelVar == bootstrapWaterline, @@ -124,6 +146,7 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { } if (isa(op)) continue; + trackedOps.push_back(&op); std::string opName = uniqueName(&op); // Create bootstrap decision variable for this operation @@ -131,9 +154,12 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { decisionVariables.insert(std::make_pair(&op, bootstrapVar)); // Create input level variable: level at which operands are consumed - auto inputLevelVar = model.AddContinuousVariable(0, bootstrapWaterline, - "inputLevel" + opName); + auto inputLevelVar = + model.AddIntegerVariable(0, bootstrapWaterline, "inputLevel" + opName); inputLevelVars.insert(std::make_pair(&op, inputLevelVar)); + auto inputScaleVar = model.AddContinuousVariable( + scaleWaterline, inputScaleMax, "inputScale" + opName); + inputScaleVars.insert(std::make_pair(&op, inputScaleVar)); // Create level variables for results for (OpResult result : op.getResults()) { @@ -141,15 +167,32 @@ LogicalResult 
ILPBootstrapPlacementAnalysis::solve() { std::stringstream ss; ss << "level" << opName << result.getResultNumber(); - auto levelVar = - model.AddContinuousVariable(0, bootstrapWaterline, ss.str()); + auto levelVar = model.AddIntegerVariable(0, bootstrapWaterline, ss.str()); levelVars.insert(std::make_pair(result, levelVar)); std::stringstream ss2; ss2 << "levelBefore" << opName << result.getResultNumber(); auto beforeVar = - model.AddContinuousVariable(0, bootstrapWaterline, ss2.str()); + model.AddIntegerVariable(0, bootstrapWaterline, ss2.str()); beforeBootstrapVars.insert(std::make_pair(result, beforeVar)); + + std::stringstream ssScale; + ssScale << "scale" << opName << result.getResultNumber(); + auto scaleVar = + model.AddContinuousVariable(scaleWaterline, scaleMax, ssScale.str()); + scaleVars.insert(std::make_pair(result, scaleVar)); + + std::stringstream ssBeforeScale; + ssBeforeScale << "scaleBefore" << opName << result.getResultNumber(); + auto beforeScaleVar = model.AddContinuousVariable( + scaleWaterline, scaleMax, ssBeforeScale.str()); + beforeBootstrapScaleVars.insert(std::make_pair(result, beforeScaleVar)); + + std::stringstream ss3; + ss3 << "outputDrop" << opName << result.getResultNumber(); + auto outputDropVar = + model.AddIntegerVariable(0, bootstrapWaterline, ss3.str()); + outputDropVars.insert(std::make_pair(result, outputDropVar)); } } @@ -157,21 +200,64 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { for (auto& [op, _] : opaqueIds) { std::string opName = uniqueName(op); - // Get secret operands - SmallVector secretOperands; - getSecretOperands(op, solver, secretOperands); - - // Operand matching: all operands must be at the same level when consumed. - // inputLevel = level at which operands are consumed. + // Operand matching: all operands must be consumed at the same level. + // Unlike the original model, a producer may stay at a higher level and a + // specific edge/use can level-reduce into this operation. 
auto inputLevelVar = inputLevelVars.at(op); - for (size_t i = 0; i < secretOperands.size(); ++i) { - Value operand = secretOperands[i]; + auto inputScaleVar = inputScaleVars.at(op); + SmallVector secretOperandScaleVars; + for (OpOperand& operandUse : op->getOpOperands()) { + Value operand = operandUse.get(); + if (!isSecret(operand, solver)) continue; if (!levelVars.contains(operand)) continue; std::stringstream ss; - ss << "operandMatch" << opName << "Op" << i; - model.AddLinearConstraint(levelVars.at(operand) == inputLevelVar, - ss.str()); + ss << "operandLevel" << opName << "Op" << operandUse.getOperandNumber(); + auto operandLevelVar = + model.AddIntegerVariable(0, bootstrapWaterline, ss.str()); + std::stringstream ssDrop; + ssDrop << "operandDrop" << opName << "Op" + << operandUse.getOperandNumber(); + auto dropVar = + model.AddIntegerVariable(0, bootstrapWaterline, ssDrop.str()); + operandDropVars.insert(std::make_pair(&operandUse, dropVar)); + + std::stringstream ssScale; + ssScale << "operandScale" << opName << "Op" + << operandUse.getOperandNumber(); + auto operandScaleVar = + model.AddContinuousVariable(scaleWaterline, scaleMax, ssScale.str()); + secretOperandScaleVars.push_back(operandScaleVar); + + model.AddLinearConstraint(operandLevelVar == inputLevelVar, + ss.str() + "MatchInput"); + model.AddLinearConstraint( + levelVars.at(operand) == operandLevelVar + dropVar, + ssDrop.str() + "FromSource"); + model.AddLinearConstraint( + operandScaleVar >= scaleVars.at(operand) - scaleFactorBits * dropVar, + ssScale.str() + "FromSource"); + model.AddLinearConstraint( + operandScaleVar <= scaleVars.at(operand) - scaleFactorBits * dropVar, + ssScale.str() + "FromSourceUpper"); + } + + if (isa(op) || isa(op)) { + math_opt::LinearExpression mulInputScale; + for (auto operandScaleVar : secretOperandScaleVars) { + mulInputScale += operandScaleVar; + } + if (!secretOperandScaleVars.empty()) { + model.AddLinearConstraint(inputScaleVar == mulInputScale, + "mulInputScale" 
+ opName); + } + } else { + for (auto [i, operandScaleVar] : + llvm::enumerate(secretOperandScaleVars)) { + model.AddLinearConstraint( + operandScaleVar == inputScaleVar, + "flowInputScale" + opName + std::to_string(i)); + } } // Output level: for multiplication, output = input - 1; else output = input @@ -189,6 +275,14 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { std::stringstream ssMin; ssMin << "mulLevelMin" << opName << result.getResultNumber(); model.AddLinearConstraint(resultBeforeVar >= 0, ssMin.str()); + + auto beforeScaleVar = beforeBootstrapScaleVars.at(result); + std::stringstream ssScale; + ssScale << "mulScaleOutput" << opName << result.getResultNumber(); + // The pass always materializes mgmt.modreduce after multiplication; + // in CKKS this is the ordinary post-mul rescale by scaleFactorBits. + model.AddLinearConstraint( + beforeScaleVar == inputScaleVar - scaleFactorBits, ssScale.str()); } } else { for (OpResult result : op->getResults()) { @@ -200,6 +294,12 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { std::stringstream ss; ss << "flowOutput" << opName << result.getResultNumber(); model.AddLinearConstraint(resultBeforeVar == inputLevelVar, ss.str()); + + auto beforeScaleVar = beforeBootstrapScaleVars.at(result); + std::stringstream ssScale; + ssScale << "flowScaleOutput" << opName << result.getResultNumber(); + model.AddLinearConstraint(beforeScaleVar == inputScaleVar, + ssScale.str()); } } @@ -211,43 +311,208 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { auto resultLevelVar = levelVars.at(result); auto resultBeforeVar = beforeBootstrapVars.at(result); + auto resultScaleVar = scaleVars.at(result); + auto resultBeforeScaleVar = beforeBootstrapScaleVars.at(result); + auto outputDropVar = outputDropVars.at(result); auto bootstrapVar = decisionVariables.at(op); std::stringstream ss; ss << "bootstrapOutput" << opName << result.getResultNumber(); // If bootstrap = 1: level_after = bootstrapWaterline - // If bootstrap 
= 0: levelAfter = levelBefore + // If bootstrap = 0: levelAfter = levelBefore - outputDrop // Using big-M: - // levelAfter <= bootstrapWaterline + BIG_M * (1 - bootstrap) - // levelAfter >= bootstrapWaterline - BIG_M * (1 - bootstrap) - // levelAfter <= levelBefore + BIG_M * bootstrap - // levelAfter >= levelBefore - BIG_M * bootstrap + // levelAfter <= bootstrapWaterline + bigM * (1 - bootstrap) + // levelAfter >= bootstrapWaterline - bigM * (1 - bootstrap) + // levelAfter <= levelBefore - outputDrop + bigM * bootstrap + // levelAfter >= levelBefore - outputDrop - bigM * bootstrap std::string cstName1 = ss.str() + "_1"; model.AddLinearConstraint( - resultLevelVar <= bootstrapWaterline + BIG_M * (1 - bootstrapVar), + resultLevelVar <= bootstrapWaterline + bigM * (1 - bootstrapVar), cstName1); std::string cstName2 = ss.str() + "_2"; model.AddLinearConstraint( - resultLevelVar >= bootstrapWaterline - BIG_M * (1 - bootstrapVar), + resultLevelVar >= bootstrapWaterline - bigM * (1 - bootstrapVar), cstName2); std::string cstName3 = ss.str() + "_3"; - model.AddLinearConstraint( - resultLevelVar <= resultBeforeVar + BIG_M * bootstrapVar, cstName3); + model.AddLinearConstraint(resultLevelVar <= resultBeforeVar - + outputDropVar + + bigM * bootstrapVar, + cstName3); std::string cstName4 = ss.str() + "_4"; + model.AddLinearConstraint(resultLevelVar >= resultBeforeVar - + outputDropVar - + bigM * bootstrapVar, + cstName4); + + // Scale-aware bootstrap feasibility: bootstrapping is only + // valid when the input scale fits the input level, and a bootstrap + // produces a ciphertext with at least the base scale. Without + // bootstrapping, output rescale decisions lower scale by scaleFactorBits + // per dropped level. 
+ std::string scaleCstName1 = ss.str() + "_scale_bts_input"; + model.AddLinearConstraint( + resultBeforeScaleVar <= + scaleFactorBits * + (resultBeforeVar - bootstrapLevelLowerBound + 1) + + scaleBigM * (1 - bootstrapVar), + scaleCstName1); + + std::string scaleCstName2 = ss.str() + "_scale_bts_output"; + model.AddLinearConstraint( + resultScaleVar >= scaleFactorBits - scaleBigM * (1 - bootstrapVar), + scaleCstName2); + + std::string scaleCstName3 = ss.str() + "_scale_nobts_output"; + model.AddLinearConstraint( + resultScaleVar >= resultBeforeScaleVar - + scaleFactorBits * outputDropVar - + scaleBigM * bootstrapVar, + scaleCstName3); + + std::string scaleCstName4 = ss.str() + "_scale_nobts_output_upper"; model.AddLinearConstraint( - resultLevelVar >= resultBeforeVar - BIG_M * bootstrapVar, cstName4); + resultScaleVar <= resultBeforeScaleVar - + scaleFactorBits * outputDropVar + + scaleBigM * bootstrapVar, + scaleCstName4); + } + } + + if (useOrbitCompression) { + llvm::DenseMap opColors; + llvm::DenseMap valueColors; + + for (BlockArgument arg : body->getArguments()) { + if (isSecret(arg, solver)) { + valueColors.insert( + std::make_pair(arg, ("arg" + std::to_string(arg.getArgNumber())))); + } + } + for (Operation* op : trackedOps) { + opColors.insert(std::make_pair(op, op->getName().getStringRef().str())); + for (OpResult result : op->getResults()) { + if (isSecret(result, solver)) valueColors[result] = opColors.lookup(op); + } + } + + auto join = [](SmallVector& parts) { + std::sort(parts.begin(), parts.end()); + std::string out; + llvm::raw_string_ostream os(out); + for (const auto& part : parts) os << part << ";"; + return os.str(); + }; + + int previousGroupCount = -1; + for (int iter = 0; iter < 32; ++iter) { + std::map descriptorToGroup; + llvm::DenseMap nextOpColors; + llvm::DenseMap nextValueColors = valueColors; + + for (Operation* op : trackedOps) { + SmallVector secretOperands; + getSecretOperands(op, solver, secretOperands); + SmallVector 
operandColors; + for (Value operand : secretOperands) { + auto it = valueColors.find(operand); + operandColors.push_back(it == valueColors.end() ? "external" + : it->second); + } + + SmallVector userColors; + for (OpResult result : op->getResults()) { + if (!isSecret(result, solver)) continue; + for (Operation* user : result.getUsers()) { + auto it = opColors.find(user); + userColors.push_back(it == opColors.end() ? "external" + : it->second); + } + } + + std::string descriptor; + llvm::raw_string_ostream os(descriptor); + os << op->getName().getStringRef() << "|in=" << join(operandColors) + << "|out=" << join(userColors) << "|results="; + for (OpResult result : op->getResults()) { + if (isSecret(result, solver)) os << result.getResultNumber() << ":"; + } + + auto [it, inserted] = descriptorToGroup.insert( + std::make_pair(os.str(), descriptorToGroup.size())); + std::string color = "orbit_group_" + std::to_string(it->second); + nextOpColors[op] = color; + for (OpResult result : op->getResults()) { + if (isSecret(result, solver)) nextValueColors[result] = color; + } + } + + opColors = std::move(nextOpColors); + valueColors = std::move(nextValueColors); + int groupCount = descriptorToGroup.size(); + if (groupCount == previousGroupCount) break; + previousGroupCount = groupCount; + } + + std::map> groups; + for (Operation* op : trackedOps) groups[opColors.lookup(op)].push_back(op); + + for (auto& [_, group] : groups) { + if (group.size() <= 1) continue; + Operation* anchor = group.front(); + for (Operation* op : llvm::drop_begin(group)) { + model.AddLinearConstraint( + decisionVariables.at(op) == decisionVariables.at(anchor), + "orbitBootstrapDecision" + uniqueName(op)); + model.AddLinearConstraint( + inputLevelVars.at(op) == inputLevelVars.at(anchor), + "orbitInputLevel" + uniqueName(op)); + model.AddLinearConstraint( + inputScaleVars.at(op) == inputScaleVars.at(anchor), + "orbitInputScale" + uniqueName(op)); + + for (auto [anchorResult, result] : + 
llvm::zip_equal(anchor->getResults(), op->getResults())) { + if (!isSecret(anchorResult, solver) || !isSecret(result, solver)) { + continue; + } + model.AddLinearConstraint( + levelVars.at(result) == levelVars.at(anchorResult), + "orbitLevel" + uniqueName(op) + + std::to_string(result.getResultNumber())); + model.AddLinearConstraint( + beforeBootstrapVars.at(result) == + beforeBootstrapVars.at(anchorResult), + "orbitBeforeLevel" + uniqueName(op) + + std::to_string(result.getResultNumber())); + model.AddLinearConstraint( + scaleVars.at(result) == scaleVars.at(anchorResult), + "orbitScale" + uniqueName(op) + + std::to_string(result.getResultNumber())); + model.AddLinearConstraint( + beforeBootstrapScaleVars.at(result) == + beforeBootstrapScaleVars.at(anchorResult), + "orbitBeforeScale" + uniqueName(op) + + std::to_string(result.getResultNumber())); + } + } } } - // Objective: minimize number of bootstraps + // Objective: minimize a weighted placement cost. math_opt::LinearExpression obj; for (auto& [op, decisionVar] : decisionVariables) { - obj += decisionVar; + obj += bootstrapCost * decisionVar; + } + for (auto& [value, dropVar] : outputDropVars) { + obj += rescaleCost * dropVar; + } + for (auto& [operand, dropVar] : operandDropVars) { + obj += rescaleCost * dropVar; } model.Minimize(obj); @@ -295,6 +560,17 @@ LogicalResult ILPBootstrapPlacementAnalysis::solve() { solutionLevelAfterBootstrap.insert( std::make_pair(value, (int)std::round(varMap[levelVar]))); } + for (auto& [operand, dropVar] : operandDropVars) { + int levelToDrop = (int)std::round(varMap[dropVar]); + if (levelToDrop <= 0) continue; + operandLevelReductions.push_back( + {operand->getOwner(), operand->getOperandNumber(), levelToDrop}); + } + for (auto& [value, dropVar] : outputDropVars) { + int levelToDrop = (int)std::round(varMap[dropVar]); + if (levelToDrop <= 0) continue; + outputLevelReductions.push_back({value, levelToDrop}); + } return success(); } diff --git 
a/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h b/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h index e52c34f186..30c636d10b 100644 --- a/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h +++ b/lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h @@ -16,9 +16,31 @@ namespace mlir { namespace heir { class ILPBootstrapPlacementAnalysis { public: + struct OperandLevelReduction { + Operation* op; + unsigned operandNumber; + int levelToDrop; + }; + + struct OutputLevelReduction { + Value value; + int levelToDrop; + }; + ILPBootstrapPlacementAnalysis(Operation* op, DataFlowSolver* solver, - int bootstrapWaterline) - : opToRunOn(op), solver(solver), bootstrapWaterline(bootstrapWaterline) {} + int bootstrapWaterline, int scaleWaterline, + int scaleFactorBits, + int bootstrapLevelLowerBound, int bootstrapCost, + int rescaleCost, bool useOrbitCompression) + : opToRunOn(op), + solver(solver), + bootstrapWaterline(bootstrapWaterline), + scaleWaterline(scaleWaterline), + scaleFactorBits(scaleFactorBits), + bootstrapLevelLowerBound(bootstrapLevelLowerBound), + bootstrapCost(bootstrapCost), + rescaleCost(rescaleCost), + useOrbitCompression(useOrbitCompression) {} ~ILPBootstrapPlacementAnalysis() = default; LogicalResult solve(); @@ -29,6 +51,17 @@ class ILPBootstrapPlacementAnalysis { // relinearize insertion. llvm::SmallVector getValuesToBootstrap() const; + // Return per-use level reductions chosen by the ILP. + llvm::SmallVector getOperandLevelReductions() + const { + return operandLevelReductions; + } + + // Return per-result level reductions chosen by the ILP. + llvm::SmallVector getOutputLevelReductions() const { + return outputLevelReductions; + } + // Return the level at the given SSA value, as determined by the // solution to the optimization problem. 
When the input value is the result // of an op, and the model solution suggests a bootstrap should be @@ -47,9 +80,17 @@ class ILPBootstrapPlacementAnalysis { Operation* opToRunOn; DataFlowSolver* solver; int bootstrapWaterline; + int scaleWaterline; + int scaleFactorBits; + int bootstrapLevelLowerBound; + int bootstrapCost; + int rescaleCost; + bool useOrbitCompression; llvm::DenseMap solution; llvm::DenseMap solutionLevelBeforeBootstrap; llvm::DenseMap solutionLevelAfterBootstrap; + llvm::SmallVector operandLevelReductions; + llvm::SmallVector outputLevelReductions; }; } // namespace heir } // namespace mlir diff --git a/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.cpp b/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.cpp index e82cbeaf25..89e92fba9a 100644 --- a/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.cpp +++ b/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.cpp @@ -1,5 +1,10 @@ #include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h" +#include +#include +#include +#include + #include "lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h" #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Dialect/Mgmt/IR/MgmtOps.h" @@ -7,6 +12,8 @@ #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Transforms/SecretInsertMgmt/Pipeline.h" #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "llvm/include/llvm/Support/JSON.h" // from @llvm-project +#include "llvm/include/llvm/Support/MemoryBuffer.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project @@ -26,20 +33,90 @@ namespace heir { #define GEN_PASS_DEF_ILPBOOTSTRAPPLACEMENT #include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h.inc" +struct OrbitCostModel { + int 
bootstrapCost; + int rescaleCost; +}; + +static std::optional averagePositiveLatency(const llvm::json::Object& root, + llvm::StringRef opName) { + const llvm::json::Object* latencyTable = root.getObject("latencyTable"); + if (!latencyTable) return std::nullopt; + + const llvm::json::Array* latencies = latencyTable->getArray(opName); + if (!latencies) return std::nullopt; + + double sum = 0; + int count = 0; + for (const llvm::json::Value& latencyValue : *latencies) { + std::optional latency = latencyValue.getAsNumber(); + if (!latency || *latency <= 0) continue; + sum += *latency; + ++count; + } + if (count == 0) return std::nullopt; + return static_cast(std::llround(sum / count)); +} + +static FailureOr loadOrbitCostModel(llvm::StringRef path) { + auto bufferOrError = llvm::MemoryBuffer::getFile(path); + if (!bufferOrError) return failure(); + + llvm::Expected parsed = + llvm::json::parse((*bufferOrError)->getBuffer()); + if (!parsed) { + llvm::consumeError(parsed.takeError()); + return failure(); + } + + const llvm::json::Object* root = parsed->getAsObject(); + if (!root) return failure(); + + std::optional parsedBootstrapCost = + averagePositiveLatency(*root, "earth.bootstrap_single"); + std::optional parsedRescaleCost = + averagePositiveLatency(*root, "earth.rescale_single"); + if (!parsedBootstrapCost || !parsedRescaleCost) return failure(); + + return OrbitCostModel{*parsedBootstrapCost, *parsedRescaleCost}; +} + struct ILPBootstrapPlacement : impl::ILPBootstrapPlacementBase { using ILPBootstrapPlacementBase::ILPBootstrapPlacementBase; LogicalResult processSecretGenericOp( secret::GenericOp genericOp, DataFlowSolver* solver, - SmallVector* valuesToBootstrap) { + SmallVector* valuesToBootstrap, + SmallVector* + outputLevelReductions, + SmallVector* + operandLevelReductions) { genericOp->walk([&](mgmt::BootstrapOp op) { op.getResult().replaceAllUsesWith(op.getOperand()); op.erase(); }); - ILPBootstrapPlacementAnalysis analysis(genericOp, solver, - 
bootstrapWaterline); + int effectiveBootstrapCost = bootstrapCost; + int effectiveRescaleCost = rescaleCost; + if (!orbitCostModel.empty()) { + FailureOr loadedCostModel = + loadOrbitCostModel(orbitCostModel); + if (failed(loadedCostModel)) { + llvm::errs() << "failed to load Orbit cost model from `" + << orbitCostModel << "`\n"; + genericOp->emitError() << "failed to load Orbit cost model from `" + << orbitCostModel << "`"; + return failure(); + } + effectiveBootstrapCost = loadedCostModel->bootstrapCost; + effectiveRescaleCost = loadedCostModel->rescaleCost; + } + + ILPBootstrapPlacementAnalysis analysis( + genericOp, solver, bootstrapWaterline, scaleWaterline, scaleFactorBits, + bootstrapLevelLowerBound, effectiveBootstrapCost, effectiveRescaleCost, + useOrbitCompression); if (failed(analysis.solve())) { genericOp->emitError( "Failed to solve the bootstrap placement optimization problem"); @@ -48,26 +125,69 @@ struct ILPBootstrapPlacement LLVM_DEBUG(analysis.printSolution(llvm::dbgs())); for (Value v : analysis.getValuesToBootstrap()) valuesToBootstrap->push_back(v); + for (auto reduction : analysis.getOutputLevelReductions()) + outputLevelReductions->push_back(reduction); + for (auto reduction : analysis.getOperandLevelReductions()) + operandLevelReductions->push_back(reduction); return success(); } + std::pair followRelinearizeModReduceChain(Value value) { + Value chainValue = value; + Operation* chainEnd = value.getDefiningOp(); + while (chainValue.hasOneUse()) { + Operation* user = *chainValue.getUsers().begin(); + if (isa(user) || isa(user)) { + chainValue = user->getResult(0); + chainEnd = user; + continue; + } + break; + } + return {chainValue, chainEnd}; + } + + void insertOutputLevelReductions( + ArrayRef + outputLevelReductions) { + OpBuilder b(&getContext()); + for (auto reduction : outputLevelReductions) { + auto [toReduce, insertAfter] = + followRelinearizeModReduceChain(reduction.value); + if (!insertAfter) continue; + + 
b.setInsertionPointAfter(insertAfter); + auto levelReduceOp = mgmt::LevelReduceOp::create( + b, insertAfter->getLoc(), toReduce, + b.getI64IntegerAttr(reduction.levelToDrop)); + toReduce.replaceAllUsesExcept(levelReduceOp.getResult(), {levelReduceOp}); + } + } + + void insertOperandLevelReductions( + ArrayRef + operandLevelReductions) { + OpBuilder b(&getContext()); + for (auto reduction : operandLevelReductions) { + Operation* op = reduction.op; + if (!op || reduction.operandNumber >= op->getNumOperands()) continue; + + Value operand = op->getOperand(reduction.operandNumber); + b.setInsertionPoint(op); + auto levelReduceOp = mgmt::LevelReduceOp::create( + b, op->getLoc(), operand, b.getI64IntegerAttr(reduction.levelToDrop)); + op->setOperand(reduction.operandNumber, levelReduceOp.getResult()); + } + } + void insertBootstrapsForValues(ArrayRef valuesToBootstrap) { OpBuilder b(&getContext()); for (Value v : valuesToBootstrap) { // After modreduce/relinearize we have mul -> relinearize -> modreduce. // Follow the chain so we bootstrap the modreduce result (correct level // refresh) and insert after it. 
- Value toBootstrap = v; - Operation* insertAfter = v.getDefiningOp(); - while (toBootstrap.hasOneUse()) { - Operation* user = *toBootstrap.getUsers().begin(); - if (isa(user) || isa(user)) { - toBootstrap = user->getResult(0); - insertAfter = user; - } else { - break; - } - } + auto [toBootstrap, insertAfter] = followRelinearizeModReduceChain(v); + if (!insertAfter) continue; b.setInsertionPointAfter(insertAfter); auto bootstrapOp = mgmt::BootstrapOp::create(b, insertAfter->getLoc(), toBootstrap); @@ -89,9 +209,14 @@ struct ILPBootstrapPlacement } SmallVector valuesToBootstrap; + SmallVector + outputLevelReductions; + SmallVector + operandLevelReductions; auto result = module->walk([&](secret::GenericOp genericOp) { - if (failed( - processSecretGenericOp(genericOp, &solver, &valuesToBootstrap))) + if (failed(processSecretGenericOp(genericOp, &solver, &valuesToBootstrap, + &outputLevelReductions, + &operandLevelReductions))) return WalkResult::interrupt(); return WalkResult::advance(); }); @@ -100,6 +225,10 @@ struct ILPBootstrapPlacement return; } + // Insert per-use level reductions before consumers, matching Orbit-style + // edge rescale placement. + insertOperandLevelReductions(operandLevelReductions); + // Modreduce after every mul. insertModReduceBeforeOrAfterMult(getOperation(), /*afterMul=*/true, /*beforeMulIncludeFirstMul=*/false, @@ -108,6 +237,10 @@ struct ILPBootstrapPlacement // Relinearize after every mul. insertRelinearizeAfterMult(getOperation(), /*includeFloats=*/true); + // Insert shared producer-output level reductions after mul management and + // before bootstraps, matching Orbit's node rescale decisions. + insertOutputLevelReductions(outputLevelReductions); + // Insert bootstraps at the Values the ILP chose. Values remain valid. 
insertBootstrapsForValues(valuesToBootstrap); diff --git a/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.td b/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.td index f24e57e8ba..b9a3aba12c 100644 --- a/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.td +++ b/lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.td @@ -6,19 +6,23 @@ include "mlir/Pass/PassBase.td" def ILPBootstrapPlacement : Pass <"ilp-bootstrap-placement"> { let summary = "Optimize placement of bootstrap ops using ILP"; let description = [{ - This pass uses an integer linear program to determine the optimal level - of each term in the MLIR, and thus the placement of bootstrap and - modreduce operations. + This pass uses an integer linear program to determine feasible levels + and CKKS-style scales for each term in the MLIR, and thus the placement + of bootstrap and level-reduction operations. The pass runs on [ciphertext-semantic](https://heir.dev/docs/design/layout/#data-semantic-and-ciphertext-semantic-tensors) IR (secret.generic with arith ops operating on pre-packed tensors). It 1) Inserts mgmt.modreduce after each level-consuming op (e.g. mul in CKKS, where level drops only at multiplications). - 2) Inserts mgmt.bootstrap at the positions chosen by the ILP. - 3) Inserts mgmt.relinearize after each mul. Resulting order is mul -> + 2) Inserts mgmt.level_reduce ops for the per-use (edge) and per-result + (node) rescale decisions chosen by the ILP. + 3) Inserts mgmt.bootstrap at the positions chosen by the ILP. + 4) Inserts mgmt.relinearize after each mul. Resulting order is mul -> relinearize -> modreduce, with bootstrap after modreduce or after the op where the ILP chose. + The ILP formulation is inspired by [Orbit](https://eprint.iacr.org/2026/213.pdf). + Note: The ILP formulation does not account for a freshly encrypted ciphertext starting at a higher level than the bootstrap waterline. This will be implemented as future work. 
@@ -32,6 +36,41 @@ def ILPBootstrapPlacement : Pass <"ilp-bootstrap-placement"> { "int", /*default=*/"3", "Bootstrap waterline (max level). Levels are 0..bootstrap-waterline (inclusive); inputs start at bootstrap-waterline.">, + Option<"scaleWaterline", + "scale-waterline", + "int", + /*default=*/"40", + "Minimum CKKS scale budget used by the Orbit-inspired scale constraints.">, + Option<"scaleFactorBits", + "scale-factor-bits", + "int", + /*default=*/"51", + "Scale bits dropped by one rescale/modreduce in the Orbit-inspired scale constraints.">, + Option<"bootstrapLevelLowerBound", + "bootstrap-level-lower-bound", + "int", + /*default=*/"0", + "Minimum input level at which bootstrap is allowed in the Orbit-inspired scale constraints.">, + Option<"orbitCostModel", + "orbit-cost-model", + "std::string", + /*default=*/"""", + "Path to an Orbit JSON cost model. When provided, bootstrap-cost and rescale-cost are loaded from latencyTable.">, + Option<"bootstrapCost", + "bootstrap-cost", + "int", + /*default=*/"69320650", + "Cost of one bootstrap in the ILP objective. Default is the positive-latency average from Orbit's profiled 64k Lattigo base cost model.">, + Option<"rescaleCost", + "rescale-cost", + "int", + /*default=*/"40988", + "Cost of one level-reduction/rescale in the ILP objective. 
Default is the positive-latency average from Orbit's profiled 64k Lattigo base cost model.">, + Option<"useOrbitCompression", + "use-orbit-compression", + "bool", + /*default=*/"true", + "When true, add Orbit-inspired structural compression constraints so equivalent ops share ILP state and bootstrap decisions.">, ]; } diff --git a/tests/Transforms/ilp_bootstrap_placement/BUILD b/tests/Transforms/ilp_bootstrap_placement/BUILD index c571e6fc6d..b442021a1c 100644 --- a/tests/Transforms/ilp_bootstrap_placement/BUILD +++ b/tests/Transforms/ilp_bootstrap_placement/BUILD @@ -2,9 +2,20 @@ load("//bazel:lit.bzl", "glob_lit_tests") package(default_applicable_licenses = ["@heir//:license"]) +filegroup( + name = "orbit_cost_model", + srcs = [ + "orbit_bad_cost_model.json", + "orbit_cost_model.json", + ], +) + glob_lit_tests( name = "all_tests", - data = ["@heir//tests:test_utilities"], + data = [ + ":orbit_cost_model", + "@heir//tests:test_utilities", + ], driver = "@heir//tests:run_lit.sh", test_file_exts = ["mlir"], ) diff --git a/tests/Transforms/ilp_bootstrap_placement/bootstrap_placement_comparison.mlir b/tests/Transforms/ilp_bootstrap_placement/bootstrap_placement_comparison.mlir index 60dd96e480..75f1be730e 100644 --- a/tests/Transforms/ilp_bootstrap_placement/bootstrap_placement_comparison.mlir +++ b/tests/Transforms/ilp_bootstrap_placement/bootstrap_placement_comparison.mlir @@ -1,4 +1,5 @@ // RUN: heir-opt --ilp-bootstrap-placement=bootstrap-waterline=3 %s | FileCheck %s --check-prefix=CHECK-ILP +// RUN: heir-opt --ilp-bootstrap-placement="bootstrap-waterline=3 orbit-cost-model=%S/orbit_cost_model.json" %s | FileCheck %s --check-prefix=CHECK-COST-MODEL // RUN: heir-opt --secret-insert-mgmt-ckks=bootstrap-waterline=3 %s | FileCheck %s --check-prefix=CHECK-GREEDY // Compare the greedy bootstrap placement against the ILP bootstrap placement @@ -21,6 +22,8 @@ // CHECK-ILP-COUNT-2: mgmt.bootstrap // CHECK-ILP-NOT: mgmt.bootstrap // CHECK-GREEDY-COUNT-3: mgmt.bootstrap +// 
CHECK-COST-MODEL-COUNT-2: mgmt.bootstrap +// CHECK-COST-MODEL-NOT: mgmt.bootstrap func.func @bootstrap_placement_test( %arg0: !secret.secret>, diff --git a/tests/Transforms/ilp_bootstrap_placement/orbit_bad_cost_model.json b/tests/Transforms/ilp_bootstrap_placement/orbit_bad_cost_model.json new file mode 100644 index 0000000000..3d2c0a7cfd --- /dev/null +++ b/tests/Transforms/ilp_bootstrap_placement/orbit_bad_cost_model.json @@ -0,0 +1 @@ +{ "latencyTable": { "earth.bootstrap_single": [ diff --git a/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model.json b/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model.json new file mode 100644 index 0000000000..0234f3a921 --- /dev/null +++ b/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model.json @@ -0,0 +1,42 @@ +{ + "bootstrapLevelLowerBound": 3, + "bootstrapLevelUpperBound": 16, + "latencyTable": { + "earth.bootstrap_single": [ + 0, + 0, + 0, + 69318421, + 69332297, + 69321224, + 69326425, + 69319004, + 69309546, + 69302417, + 69328521, + 69338641, + 69310061, + 69312692, + 69318897, + 69330304 + ], + "earth.rescale_single": [ + 12498, + 16562, + 20651, + 24715, + 28822, + 32854, + 36938, + 41005, + 45091, + 49153, + 53177, + 57229, + 61296, + 65332, + 69490 + ] + }, + "rescalingFactor": 51 +} diff --git a/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model_error.mlir b/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model_error.mlir new file mode 100644 index 0000000000..9816ede321 --- /dev/null +++ b/tests/Transforms/ilp_bootstrap_placement/orbit_cost_model_error.mlir @@ -0,0 +1,15 @@ +// RUN: not heir-opt --ilp-bootstrap-placement="orbit-cost-model=%S/does_not_exist.json" %s 2>&1 | FileCheck %s --check-prefix=MISSING +// RUN: not heir-opt --ilp-bootstrap-placement="orbit-cost-model=%S/orbit_bad_cost_model.json" %s 2>&1 | FileCheck %s --check-prefix=MALFORMED + +// MISSING: failed to load Orbit cost model +// MALFORMED: failed to load Orbit cost model + +func.func @orbit_cost_model_error( + 
%arg0: !secret.secret>) -> !secret.secret> { + %0 = secret.generic(%arg0: !secret.secret>) { + ^body(%input0: tensor<8xf32>): + %out = arith.addf %input0, %input0 : tensor<8xf32> + secret.yield %out : tensor<8xf32> + } -> !secret.secret> + return %0 : !secret.secret> +} diff --git a/tests/Transforms/ilp_bootstrap_placement/orbit_edge_rescale_placement.mlir b/tests/Transforms/ilp_bootstrap_placement/orbit_edge_rescale_placement.mlir new file mode 100644 index 0000000000..ba18baeba6 --- /dev/null +++ b/tests/Transforms/ilp_bootstrap_placement/orbit_edge_rescale_placement.mlir @@ -0,0 +1,39 @@ +// RUN: heir-opt --ilp-bootstrap-placement=bootstrap-waterline=3 %s | FileCheck %s --check-prefix=CHECK-RESCALE +// RUN: heir-opt --ilp-bootstrap-placement="bootstrap-waterline=3 bootstrap-cost=1 rescale-cost=100000000" %s | FileCheck %s --check-prefix=CHECK-BTS + +// The Orbit-style ILP can reduce a single incoming edge instead of forcing the +// producer SSA value itself to live at the consumer level. Here %input0 remains +// fresh for other uses, while the add consumes a level-reduced copy. The +// multiplication uses a plaintext constant so rescale is feasible with the +// default Orbit-style Sw=40 and scale-factor-bits=51 model. 
+ +// CHECK-RESCALE: func.func @orbit_edge_rescale_placement +// CHECK-RESCALE: %[[PLAIN:.*]] = arith.constant +// CHECK-RESCALE: secret.generic +// CHECK-RESCALE: arith.mulf %input0, %[[PLAIN]] +// CHECK-RESCALE-NEXT: mgmt.modreduce +// CHECK-RESCALE-NOT: mgmt.bootstrap +// CHECK-RESCALE: %[[REDUCED:.*]] = mgmt.level_reduce %input0 +// CHECK-RESCALE: arith.addf %[[REDUCED]], +// CHECK-RESCALE-NOT: mgmt.bootstrap + +// CHECK-BTS: func.func @orbit_edge_rescale_placement +// CHECK-BTS: %[[PLAIN:.*]] = arith.constant +// CHECK-BTS: secret.generic +// CHECK-BTS: arith.mulf %input0, %[[PLAIN]] +// CHECK-BTS-NOT: mgmt.level_reduce %input0 +// CHECK-BTS: %[[BOOTSTRAPPED:.*]] = mgmt.bootstrap +// CHECK-BTS: arith.addf %input0, %[[BOOTSTRAPPED]] + +func.func @orbit_edge_rescale_placement( + %arg0: !secret.secret>) -> !secret.secret> { + %0 = secret.generic( + %arg0: !secret.secret>) { + ^body(%input0: tensor<8xf32>): + %plain = arith.constant dense<2.0> : tensor<8xf32> + %l1 = arith.mulf %input0, %plain : tensor<8xf32> + %out = arith.addf %input0, %l1 : tensor<8xf32> + secret.yield %out : tensor<8xf32> + } -> !secret.secret> + return %0 : !secret.secret> +} diff --git a/tests/Transforms/ilp_bootstrap_placement/orbit_output_rescale_placement.mlir b/tests/Transforms/ilp_bootstrap_placement/orbit_output_rescale_placement.mlir new file mode 100644 index 0000000000..1ed6496331 --- /dev/null +++ b/tests/Transforms/ilp_bootstrap_placement/orbit_output_rescale_placement.mlir @@ -0,0 +1,35 @@ +// RUN: heir-opt --ilp-bootstrap-placement=bootstrap-waterline=3 %s | FileCheck %s + +// Orbit-style node rescale can reduce a producer once when all outgoing uses +// need the lower level. The shared add result is reduced once before its +// consumers, rather than independently at each consumer. The multiplication +// uses a plaintext constant so rescale is feasible with the default Orbit-style +// Sw=40 and scale-factor-bits=51 model. 
+ +// CHECK: func.func @orbit_output_rescale_placement +// CHECK: %[[PLAIN:.*]] = arith.constant +// CHECK: %[[SHARED:.*]] = arith.addf %input0, %input1 +// CHECK-NEXT: %[[REDUCED:.*]] = mgmt.level_reduce %[[SHARED]] +// CHECK: arith.mulf %input0, %[[PLAIN]] +// CHECK: arith.addf %[[REDUCED]], +// CHECK: arith.subf %[[REDUCED]], +// CHECK-NOT: mgmt.bootstrap + +func.func @orbit_output_rescale_placement( + %arg0: !secret.secret>, + %arg1: !secret.secret>) -> !secret.secret> { + %0 = secret.generic( + %arg0: !secret.secret>, + %arg1: !secret.secret>) { + ^body(%input0: tensor<8xf32>, + %input1: tensor<8xf32>): + %shared = arith.addf %input0, %input1 : tensor<8xf32> + %plain = arith.constant dense<2.0> : tensor<8xf32> + %l1 = arith.mulf %input0, %plain : tensor<8xf32> + %use0 = arith.addf %shared, %l1 : tensor<8xf32> + %use1 = arith.subf %shared, %l1 : tensor<8xf32> + %out = arith.addf %use0, %use1 : tensor<8xf32> + secret.yield %out : tensor<8xf32> + } -> !secret.secret> + return %0 : !secret.secret> +}