-
Notifications
You must be signed in to change notification settings - Fork 135
Orbit's Level-Scale-Aware ILP Bootstrap and Rescale Placement #2979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,12 +1,19 @@ | ||||||
| #include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h" | ||||||
|
|
||||||
| #include <cmath> | ||||||
| #include <optional> | ||||||
| #include <string> | ||||||
| #include <utility> | ||||||
|
|
||||||
| #include "lib/Analysis/ILPBootstrapPlacementAnalysis/ILPBootstrapPlacementAnalysis.h" | ||||||
| #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" | ||||||
| #include "lib/Dialect/Mgmt/IR/MgmtOps.h" | ||||||
| #include "lib/Dialect/Mgmt/Transforms/AnnotateMgmt.h" | ||||||
| #include "lib/Dialect/Secret/IR/SecretOps.h" | ||||||
| #include "lib/Transforms/SecretInsertMgmt/Pipeline.h" | ||||||
| #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||||||
| #include "llvm/include/llvm/Support/JSON.h" // from @llvm-project | ||||||
| #include "llvm/include/llvm/Support/MemoryBuffer.h" // from @llvm-project | ||||||
| #include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project | ||||||
| #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project | ||||||
| #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project | ||||||
|
|
@@ -26,20 +33,90 @@ namespace heir { | |||||
| #define GEN_PASS_DEF_ILPBOOTSTRAPPLACEMENT | ||||||
| #include "lib/Transforms/ILPBootstrapPlacement/ILPBootstrapPlacement.h.inc" | ||||||
|
|
||||||
/// Pair of scalar costs used when placing bootstraps and rescales.
///
/// Both members default to zero so that a default-constructed model never
/// holds indeterminate values. The struct remains an aggregate, so
/// `OrbitCostModel{bootstrap, rescale}` keeps working.
struct OrbitCostModel {
  /// Cost attributed to a bootstrap operation.
  int bootstrapCost = 0;
  /// Cost attributed to a rescale operation.
  int rescaleCost = 0;
};
|
|
||||||
| static std::optional<int> averagePositiveLatency(const llvm::json::Object& root, | ||||||
| llvm::StringRef opName) { | ||||||
| const llvm::json::Object* latencyTable = root.getObject("latencyTable"); | ||||||
| if (!latencyTable) return std::nullopt; | ||||||
|
|
||||||
| const llvm::json::Array* latencies = latencyTable->getArray(opName); | ||||||
| if (!latencies) return std::nullopt; | ||||||
|
|
||||||
| double sum = 0; | ||||||
| int count = 0; | ||||||
| for (const llvm::json::Value& latencyValue : *latencies) { | ||||||
| std::optional<double> latency = latencyValue.getAsNumber(); | ||||||
| if (!latency || *latency <= 0) continue; | ||||||
| sum += *latency; | ||||||
| ++count; | ||||||
| } | ||||||
| if (count == 0) return std::nullopt; | ||||||
| return static_cast<int>(std::llround(sum / count)); | ||||||
| } | ||||||
|
|
||||||
| static FailureOr<OrbitCostModel> loadOrbitCostModel(llvm::StringRef path) { | ||||||
| auto bufferOrError = llvm::MemoryBuffer::getFile(path); | ||||||
| if (!bufferOrError) return failure(); | ||||||
|
|
||||||
| llvm::Expected<llvm::json::Value> parsed = | ||||||
| llvm::json::parse((*bufferOrError)->getBuffer()); | ||||||
| if (!parsed) { | ||||||
| llvm::consumeError(parsed.takeError()); | ||||||
| return failure(); | ||||||
| } | ||||||
|
|
||||||
| const llvm::json::Object* root = parsed->getAsObject(); | ||||||
| if (!root) return failure(); | ||||||
|
|
||||||
| std::optional<int> parsedBootstrapCost = | ||||||
| averagePositiveLatency(*root, "earth.bootstrap_single"); | ||||||
| std::optional<int> parsedRescaleCost = | ||||||
| averagePositiveLatency(*root, "earth.rescale_single"); | ||||||
| if (!parsedBootstrapCost || !parsedRescaleCost) return failure(); | ||||||
|
|
||||||
| return OrbitCostModel{*parsedBootstrapCost, *parsedRescaleCost}; | ||||||
| } | ||||||
|
|
||||||
| struct ILPBootstrapPlacement | ||||||
| : impl::ILPBootstrapPlacementBase<ILPBootstrapPlacement> { | ||||||
| using ILPBootstrapPlacementBase::ILPBootstrapPlacementBase; | ||||||
|
|
||||||
| LogicalResult processSecretGenericOp( | ||||||
| secret::GenericOp genericOp, DataFlowSolver* solver, | ||||||
| SmallVector<Value, 32>* valuesToBootstrap) { | ||||||
| SmallVector<Value, 32>* valuesToBootstrap, | ||||||
| SmallVector<ILPBootstrapPlacementAnalysis::OutputLevelReduction, 32>* | ||||||
| outputLevelReductions, | ||||||
| SmallVector<ILPBootstrapPlacementAnalysis::OperandLevelReduction, 32>* | ||||||
| operandLevelReductions) { | ||||||
| genericOp->walk([&](mgmt::BootstrapOp op) { | ||||||
| op.getResult().replaceAllUsesWith(op.getOperand()); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sanity check: should you also remove any existing mgmt.modreduce and mgmt.level_reduce ops? |
||||||
| op.erase(); | ||||||
| }); | ||||||
|
|
||||||
| ILPBootstrapPlacementAnalysis analysis(genericOp, solver, | ||||||
| bootstrapWaterline); | ||||||
| int effectiveBootstrapCost = bootstrapCost; | ||||||
| int effectiveRescaleCost = rescaleCost; | ||||||
| if (!orbitCostModel.empty()) { | ||||||
| FailureOr<OrbitCostModel> loadedCostModel = | ||||||
| loadOrbitCostModel(orbitCostModel); | ||||||
| if (failed(loadedCostModel)) { | ||||||
| llvm::errs() << "failed to load Orbit cost model from `" | ||||||
| << orbitCostModel << "`\n"; | ||||||
| genericOp->emitError() << "failed to load Orbit cost model from `" | ||||||
| << orbitCostModel << "`"; | ||||||
|
Comment on lines
+108
to
+109
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Nit, unnecessary. |
||||||
| return failure(); | ||||||
| } | ||||||
| effectiveBootstrapCost = loadedCostModel->bootstrapCost; | ||||||
| effectiveRescaleCost = loadedCostModel->rescaleCost; | ||||||
| } | ||||||
|
|
||||||
| ILPBootstrapPlacementAnalysis analysis( | ||||||
| genericOp, solver, bootstrapWaterline, scaleWaterline, scaleFactorBits, | ||||||
| bootstrapLevelLowerBound, effectiveBootstrapCost, effectiveRescaleCost, | ||||||
| useOrbitCompression); | ||||||
| if (failed(analysis.solve())) { | ||||||
| genericOp->emitError( | ||||||
| "Failed to solve the bootstrap placement optimization problem"); | ||||||
|
|
@@ -48,26 +125,69 @@ struct ILPBootstrapPlacement | |||||
| LLVM_DEBUG(analysis.printSolution(llvm::dbgs())); | ||||||
| for (Value v : analysis.getValuesToBootstrap()) | ||||||
| valuesToBootstrap->push_back(v); | ||||||
| for (auto reduction : analysis.getOutputLevelReductions()) | ||||||
| outputLevelReductions->push_back(reduction); | ||||||
| for (auto reduction : analysis.getOperandLevelReductions()) | ||||||
| operandLevelReductions->push_back(reduction); | ||||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| std::pair<Value, Operation*> followRelinearizeModReduceChain(Value value) { | ||||||
| Value chainValue = value; | ||||||
| Operation* chainEnd = value.getDefiningOp(); | ||||||
| while (chainValue.hasOneUse()) { | ||||||
| Operation* user = *chainValue.getUsers().begin(); | ||||||
| if (isa<mgmt::RelinearizeOp>(user) || isa<mgmt::ModReduceOp>(user)) { | ||||||
| chainValue = user->getResult(0); | ||||||
| chainEnd = user; | ||||||
| continue; | ||||||
| } | ||||||
| break; | ||||||
| } | ||||||
| return {chainValue, chainEnd}; | ||||||
| } | ||||||
|
|
||||||
| void insertOutputLevelReductions( | ||||||
| ArrayRef<ILPBootstrapPlacementAnalysis::OutputLevelReduction> | ||||||
| outputLevelReductions) { | ||||||
| OpBuilder b(&getContext()); | ||||||
| for (auto reduction : outputLevelReductions) { | ||||||
| auto [toReduce, insertAfter] = | ||||||
| followRelinearizeModReduceChain(reduction.value); | ||||||
| if (!insertAfter) continue; | ||||||
|
|
||||||
| b.setInsertionPointAfter(insertAfter); | ||||||
| auto levelReduceOp = mgmt::LevelReduceOp::create( | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| b, insertAfter->getLoc(), toReduce, | ||||||
| b.getI64IntegerAttr(reduction.levelToDrop)); | ||||||
| toReduce.replaceAllUsesExcept(levelReduceOp.getResult(), {levelReduceOp}); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| void insertOperandLevelReductions( | ||||||
| ArrayRef<ILPBootstrapPlacementAnalysis::OperandLevelReduction> | ||||||
| operandLevelReductions) { | ||||||
| OpBuilder b(&getContext()); | ||||||
| for (auto reduction : operandLevelReductions) { | ||||||
| Operation* op = reduction.op; | ||||||
| if (!op || reduction.operandNumber >= op->getNumOperands()) continue; | ||||||
|
|
||||||
| Value operand = op->getOperand(reduction.operandNumber); | ||||||
| b.setInsertionPoint(op); | ||||||
| auto levelReduceOp = mgmt::LevelReduceOp::create( | ||||||
| b, op->getLoc(), operand, b.getI64IntegerAttr(reduction.levelToDrop)); | ||||||
| op->setOperand(reduction.operandNumber, levelReduceOp.getResult()); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| void insertBootstrapsForValues(ArrayRef<Value> valuesToBootstrap) { | ||||||
| OpBuilder b(&getContext()); | ||||||
| for (Value v : valuesToBootstrap) { | ||||||
| // After modreduce/relinearize we have mul -> relinearize -> modreduce. | ||||||
| // Follow the chain so we bootstrap the modreduce result (correct level | ||||||
| // refresh) and insert after it. | ||||||
| Value toBootstrap = v; | ||||||
| Operation* insertAfter = v.getDefiningOp(); | ||||||
| while (toBootstrap.hasOneUse()) { | ||||||
| Operation* user = *toBootstrap.getUsers().begin(); | ||||||
| if (isa<mgmt::RelinearizeOp>(user) || isa<mgmt::ModReduceOp>(user)) { | ||||||
| toBootstrap = user->getResult(0); | ||||||
| insertAfter = user; | ||||||
| } else { | ||||||
| break; | ||||||
| } | ||||||
| } | ||||||
| auto [toBootstrap, insertAfter] = followRelinearizeModReduceChain(v); | ||||||
| if (!insertAfter) continue; | ||||||
| b.setInsertionPointAfter(insertAfter); | ||||||
| auto bootstrapOp = | ||||||
| mgmt::BootstrapOp::create(b, insertAfter->getLoc(), toBootstrap); | ||||||
|
|
@@ -89,9 +209,14 @@ struct ILPBootstrapPlacement | |||||
| } | ||||||
|
|
||||||
| SmallVector<Value, 32> valuesToBootstrap; | ||||||
| SmallVector<ILPBootstrapPlacementAnalysis::OutputLevelReduction, 32> | ||||||
| outputLevelReductions; | ||||||
| SmallVector<ILPBootstrapPlacementAnalysis::OperandLevelReduction, 32> | ||||||
| operandLevelReductions; | ||||||
| auto result = module->walk([&](secret::GenericOp genericOp) { | ||||||
| if (failed( | ||||||
| processSecretGenericOp(genericOp, &solver, &valuesToBootstrap))) | ||||||
| if (failed(processSecretGenericOp(genericOp, &solver, &valuesToBootstrap, | ||||||
| &outputLevelReductions, | ||||||
| &operandLevelReductions))) | ||||||
| return WalkResult::interrupt(); | ||||||
| return WalkResult::advance(); | ||||||
| }); | ||||||
|
|
@@ -100,6 +225,10 @@ struct ILPBootstrapPlacement | |||||
| return; | ||||||
| } | ||||||
|
|
||||||
| // Insert per-use level reductions before consumers, matching Orbit-style | ||||||
| // edge rescale placement. | ||||||
| insertOperandLevelReductions(operandLevelReductions); | ||||||
|
|
||||||
| // Modreduce after every mul. | ||||||
| insertModReduceBeforeOrAfterMult(getOperation(), /*afterMul=*/true, | ||||||
| /*beforeMulIncludeFirstMul=*/false, | ||||||
|
|
@@ -108,6 +237,10 @@ struct ILPBootstrapPlacement | |||||
| // Relinearize after every mul. | ||||||
| insertRelinearizeAfterMult(getOperation(), /*includeFloats=*/true); | ||||||
|
|
||||||
| // Insert shared producer-output level reductions after mul management and | ||||||
| // before bootstraps, matching Orbit's node rescale decisions. | ||||||
| insertOutputLevelReductions(outputLevelReductions); | ||||||
|
|
||||||
| // Insert bootstraps at the Values the ILP chose. Values remain valid. | ||||||
| insertBootstrapsForValues(valuesToBootstrap); | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any hope of making this more fine-grained than a blended cost? If not, perhaps the JSON schema itself should just provide a blended latency cost for each op. I think supporting a per-level cost would require pretty dramatic changes to the model.