diff --git a/lib/Dialect/Secret/Transforms/AddDebugPort.cpp b/lib/Dialect/Secret/Transforms/AddDebugPort.cpp index fae2f10242..03a815aaee 100644 --- a/lib/Dialect/Secret/Transforms/AddDebugPort.cpp +++ b/lib/Dialect/Secret/Transforms/AddDebugPort.cpp @@ -1,16 +1,20 @@ #include "lib/Dialect/Secret/Transforms/AddDebugPort.h" +#include #include #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/FuncUtils.h" #include "lib/Dialect/Secret/IR/SecretOps.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project @@ -46,64 +50,81 @@ func::FuncOp getOrCreateExternalDebugFunc(ModuleOp module, Type valueType) { return funcOp; } -LogicalResult insertExternalCall(secret::GenericOp op, DataFlowSolver& solver) { - auto module = op->getParentOfType(); - - ImplicitLocOpBuilder b = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody()); - - auto insertCall = [&](Value value) { - Type valueType = value.getType(); - - func::CallOp::create(b, getOrCreateExternalDebugFunc(module, valueType), - ArrayRef{value}); +void insertValidationOps(secret::GenericOp op, DataFlowSolver& solver) { + int count = 0; + auto insertValidate = [&](Value value, OpBuilder& b) { + if (isSecret(value, &solver)) { + debug::ValidateOp::create(b, value.getLoc(), value, + "heir_debug_" + std::to_string(count++), + nullptr); + } }; - // insert for each argument - for (auto arg : op.getBody()->getArguments()) { - if (!isSecret(arg, &solver)) { - continue; - } - insertCall(arg); + Block* body = op.getBody(); + OpBuilder argBuilder(body, body->begin()); + for (auto arg : body->getArguments()) { + insertValidate(arg, argBuilder); } - // insert after each op - op.walk([&](Operation* op) { - if (mlir::isa(op)) { + op.walk([&](Operation* walkOp) { + if (walkOp == op.getOperation() || mlir::isa(walkOp) || + walkOp->hasTrait()) { return; } - b.setInsertionPointAfter(op); - for (Value result : op->getResults()) { - if (!isSecret(result, &solver)) { - continue; - } - insertCall(result); + OpBuilder opBuilder(walkOp->getBlock(), ++walkOp->getIterator()); + for (Value result : walkOp->getResults()) { + insertValidate(result, opBuilder); + } + }); +} + +void lowerValidationOps(secret::GenericOp op) { + auto module = op->getParentOfType(); + op.walk([&](debug::ValidateOp validateOp) { + Value value = validateOp.getInput(); + ImplicitLocOpBuilder b(validateOp.getLoc(), validateOp); + + auto callOp = b.create( + getOrCreateExternalDebugFunc(module, value.getType()), + ArrayRef{value}); + + // Transfer attributes + callOp->setAttr("debug.name", validateOp.getNameAttr()); + if (validateOp.getMetadata()) { + callOp->setAttr("debug.metadata", validateOp.getMetadataAttr()); } + + validateOp.erase(); }); - return success(); } struct AddDebugPort : impl::SecretAddDebugPortBase { using SecretAddDebugPortBase::SecretAddDebugPortBase; void runOnOperation() override { - DataFlowSolver solver; - dataflow::loadBaselineAnalyses(solver); - solver.load(); - - auto result = solver.initializeAndRun(getOperation()); - if (failed(result)) { - getOperation()->emitOpError() << "Failed to run the analysis.\n"; - signalPassFailure(); - return; + std::unique_ptr solver; + + // No need to do a secretness analysis if we're not inserting new + // debug.validate ops + if (insertDebugAfterEveryOp) { + solver = std::make_unique(); + dataflow::loadBaselineAnalyses(*solver); + solver->load(); + + auto result = solver->initializeAndRun(getOperation()); + if (failed(result)) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } } getOperation()->walk([&](secret::GenericOp genericOp) { - if (failed(insertExternalCall(genericOp, solver))) { - genericOp->emitError("Failed to add debug port for genericOp"); - signalPassFailure(); + if (insertDebugAfterEveryOp) { + insertValidationOps(genericOp, *solver); } + lowerValidationOps(genericOp); }); } }; diff --git a/lib/Dialect/Secret/Transforms/BUILD b/lib/Dialect/Secret/Transforms/BUILD index 1093ae1d38..d0b625cbe7 100644 --- a/lib/Dialect/Secret/Transforms/BUILD +++ b/lib/Dialect/Secret/Transforms/BUILD @@ -178,7 +178,9 @@ cc_library( ":pass_inc_gen", "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect:FuncUtils", + "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/Secret/Transforms/Passes.td b/lib/Dialect/Secret/Transforms/Passes.td index 15fa24be71..6b7cfe8487 100644 --- a/lib/Dialect/Secret/Transforms/Passes.td +++ b/lib/Dialect/Secret/Transforms/Passes.td @@ -103,13 +103,23 @@ def SecretExtractGenericBody : Pass<"secret-extract-generic-body"> { def SecretAddDebugPort : Pass<"secret-add-debug-port"> { let summary = "Add debug port to secret-arithmetic ops"; let description = [{ - This pass adds debug ports to secret-arithmetic ops in the IR, namely operations - wrapped by secret.generic. The debug ports are prefixed with "__heir_debug" and - are invoked after each operation in the generic body. The debug ports are - declarations and user should provide functions with the same name in their code. + This pass adds debug ports to secret-arithmetic ops in the IR, namely + operations wrapped by `secret.generic`. The debug ports are prefixed with + `"__heir_debug"` and are invoked after each operation in the generic body. + The debug ports are declarations and the user must define functions with the + same name in their application code. + + If the option `insert-debug-after-every-op` is set to true, it will insert a + call after every operation in the generic body. + + Regardless of the `insert-debug-after-every-op` option, this pass lowers all + `debug.validate` ops it encounters in the generic body to function calls. + In this way, the user can provide specific checkpoints in a program to invoke + debug code. For example, if the function is called "foo", the debug port is called after each homomorphic operation: + ```mlir // declaration of external debug function func.func private @__heir_debug_tensor_8xi16_(tensor<8xi16>) @@ -124,7 +134,18 @@ def SecretAddDebugPort : Pass<"secret-add-debug-port"> { } } ``` + + Because this pass is agnostic of the cryptosystem backend, it does not insert + secret key material into the debug handler call. See `lwe-add-debug-port` for + cryptosystem-aware analogues of this pass. This pass must remain, in + particular, to support the plaintext pipeline. }]; + + let options = [ + Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool", + /*default=*/"false", "Whether to add debug calls after every op">, + ]; + let dependentDialects = ["mlir::heir::debug::DebugDialect"]; } def SecretImportExecutionResult : Pass<"secret-import-execution-result"> { diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 1fefe5e274..255ec075ae 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -223,7 +223,9 @@ void mlirToPlaintextPipelineBuilder(OpPassManager& pm, if (options.debug) { // Insert debug handler calls - pm.addPass(secret::createSecretAddDebugPort()); + secret::SecretAddDebugPortOptions debugOptions; + debugOptions.insertDebugAfterEveryOp = true; + pm.addPass(secret::createSecretAddDebugPort(debugOptions)); } pm.addPass(secret::createSecretDistributeGeneric()); diff --git a/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir b/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir new file mode 100644 index 0000000000..b7fe825640 --- /dev/null +++ b/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir @@ -0,0 +1,16 @@ +// RUN: heir-opt --secret-add-debug-port %s | FileCheck %s + +module { + func.func @test_lower_validate(%arg0: !secret.secret) -> !secret.secret { + %0 = secret.generic(%arg0: !secret.secret) { + ^body(%arg1: i32): + debug.validate %arg1 {name = "val1"} : i32 + secret.yield %arg1 : i32 + } -> !secret.secret + return %0 : !secret.secret + } +} + +// CHECK: func.func private @__heir_debug_i32(i32) +// CHECK: func.func @test_lower_validate +// CHECK: call @__heir_debug_i32({{.*}}) {debug.name = "val1"}