Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
486 changes: 386 additions & 100 deletions lib/Dialect/LWE/Transforms/AddDebugPort.cpp

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions lib/Dialect/LWE/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect:FuncUtils",
"@heir//lib/Dialect/Debug/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Utils:TransformUtils",
"@heir//lib/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)

Expand Down
13 changes: 12 additions & 1 deletion lib/Dialect/LWE/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> {
function. The debug ports are declarations and user should provide functions with
the same name in their code.

If the option `insert-debug-after-every-op` is set to true, it will insert a `debug.validate`
op after every homomorphic operation.

Regardless of the `insert-debug-after-every-op` option, this pass will lower all
`debug.validate` ops it encounters to function calls.

For example, if the function is called "foo", the secret key is added to its
arguments, and the debug port is called after each homomorphic operation:
```mlir
Expand All @@ -29,13 +35,18 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> {
}
```
}];
let dependentDialects = ["mlir::heir::lwe::LWEDialect"];
let dependentDialects = [
"mlir::heir::lwe::LWEDialect",
"mlir::heir::debug::DebugDialect"
];
let options = [
Option<"entryFunction", "entry-function", "std::string",
/*default=*/"", "Default entry function "
"name of entry function.">,
Option<"messageSize", "message-size", "int",
/*default=*/"1", "The size of the message in the ciphertext.">,
Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool",
/*default=*/"false", "Whether to add debug ports after every op">
];
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Secret/Conversions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
hdrs = ["Patterns.h"],
deps = [
"@heir//lib/Dialect:ModuleAttributes",
"@heir//lib/Dialect/Debug/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
Expand Down
51 changes: 33 additions & 18 deletions lib/Dialect/Secret/Conversions/Patterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <cstddef>
#include <cstdint>

#include "lib/Dialect/Debug/IR/DebugDialect.h"
#include "lib/Dialect/Debug/IR/DebugOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
Expand Down Expand Up @@ -502,6 +504,20 @@ LogicalResult ConvertEmpty::matchAndRewrite(
return success();
}

// This only needs a special pattern because it has attributes that aren't
// copied over by the base SecretGenericOpConversion.
FailureOr<Operation*> ConvertDebugValidate::matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ContextAwareConversionPatternRewriter& rewriter) const {
debug::ValidateOp innerOp =
cast<debug::ValidateOp>(op.getBody()->getOperations().front());
debug::ValidateOp newOp = debug::ValidateOp::create(
rewriter, op.getLoc(), outputTypes, inputs, innerOp->getAttrs());
rewriter.replaceOp(op, newOp);
return newOp.getOperation();
}

bool hasSecretOperandsOrResults(Operation* op) {
return llvm::any_of(op->getOperands(),
[](Value operand) {
Expand All @@ -527,24 +543,23 @@ void addSecretToSchemeDefaultConversionTargetsAndPatterns(
target.markUnknownOpDynamicallyLegal(
[&](Operation* op) { return !hasSecretOperandsOrResults(op); });

patterns.add<SecretGenericOpIdentityConversion<arith::ExtUIOp>,
SecretGenericOpIdentityConversion<arith::ExtSIOp>,
SecretGenericOpIdentityConversion<arith::FPToSIOp>,
SecretGenericOpIdentityConversion<arith::FPToUIOp>,
SecretGenericOpIdentityConversion<arith::SIToFPOp>,
SecretGenericOpIdentityConversion<arith::UIToFPOp>,
SecretGenericOpConversion<tensor::EmptyOp, tensor::EmptyOp>,
SecretGenericFuncCallConversion, ConvertExtractSlice,
ConvertInsertSlice, ConvertAnyContextAware<affine::AffineForOp>,
ConvertAnyContextAware<affine::AffineIfOp>,
ConvertAnyContextAware<affine::AffineYieldOp>,
ConvertAnyContextAware<scf::ForOp>,
ConvertAnyContextAware<scf::IfOp>,
ConvertAnyContextAware<scf::YieldOp>,
ConvertAnyContextAware<tensor::ExtractOp>,
ConvertAnyContextAware<tensor::InsertOp>,
ConvertAnyContextAware<func::CallOp>>(typeConverter,
patterns.getContext());
patterns.add<
ConvertAnyContextAware<affine::AffineForOp>,
ConvertAnyContextAware<affine::AffineIfOp>,
ConvertAnyContextAware<affine::AffineYieldOp>,
ConvertAnyContextAware<func::CallOp>, ConvertAnyContextAware<scf::ForOp>,
ConvertAnyContextAware<scf::IfOp>, ConvertAnyContextAware<scf::YieldOp>,
ConvertAnyContextAware<tensor::ExtractOp>,
ConvertAnyContextAware<tensor::InsertOp>, ConvertDebugValidate,
ConvertExtractSlice, ConvertInsertSlice, SecretGenericFuncCallConversion,
SecretGenericOpConversion<tensor::EmptyOp, tensor::EmptyOp>,
SecretGenericOpIdentityConversion<arith::ExtSIOp>,
SecretGenericOpIdentityConversion<arith::ExtUIOp>,
SecretGenericOpIdentityConversion<arith::FPToSIOp>,
SecretGenericOpIdentityConversion<arith::FPToUIOp>,
SecretGenericOpIdentityConversion<arith::SIToFPOp>,
SecretGenericOpIdentityConversion<arith::UIToFPOp>>(
typeConverter, patterns.getContext());

addStructuralConversionPatterns(typeConverter, patterns, target);
}
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/Secret/Conversions/Patterns.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_DIALECT_SECRET_CONVERSIONS_PATTERNS_H_
#define LIB_DIALECT_SECRET_CONVERSIONS_PATTERNS_H_

#include "lib/Dialect/Debug/IR/DebugOps.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
Expand Down Expand Up @@ -90,6 +91,17 @@ struct ConvertInsertSlice
ContextAwareConversionPatternRewriter& rewriter) const override;
};

struct ConvertDebugValidate
: public SecretGenericOpConversion<debug::ValidateOp, debug::ValidateOp> {
using SecretGenericOpConversion<debug::ValidateOp,
debug::ValidateOp>::SecretGenericOpConversion;

FailureOr<Operation*> matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ContextAwareConversionPatternRewriter& rewriter) const override;
};

// An empty ciphertext-semantic tensor can be used as the initializer of a
// reduction. In this case, there is no containing secret.generic op, and we
// anchor on the subsequent `mgmt::InitOp` to determine how to convert it to a
Expand Down
101 changes: 61 additions & 40 deletions lib/Dialect/Secret/Transforms/AddDebugPort.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#include "lib/Dialect/Secret/Transforms/AddDebugPort.h"

#include <memory>
#include <string>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Debug/IR/DebugOps.h"
#include "lib/Dialect/FuncUtils.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
Expand Down Expand Up @@ -46,64 +50,81 @@ func::FuncOp getOrCreateExternalDebugFunc(ModuleOp module, Type valueType) {
return funcOp;
}

LogicalResult insertExternalCall(secret::GenericOp op, DataFlowSolver& solver) {
auto module = op->getParentOfType<ModuleOp>();

ImplicitLocOpBuilder b =
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody());

auto insertCall = [&](Value value) {
Type valueType = value.getType();

func::CallOp::create(b, getOrCreateExternalDebugFunc(module, valueType),
ArrayRef<Value>{value});
void insertValidationOps(secret::GenericOp op, DataFlowSolver& solver) {
int count = 0;
auto insertValidate = [&](Value value, OpBuilder& b) {
if (isSecret(value, &solver)) {
debug::ValidateOp::create(b, value.getLoc(), value,
"heir_debug_" + std::to_string(count++),
nullptr);
}
};

// insert for each argument
for (auto arg : op.getBody()->getArguments()) {
if (!isSecret(arg, &solver)) {
continue;
}
insertCall(arg);
Block* body = op.getBody();
OpBuilder argBuilder(body, body->begin());
for (auto arg : body->getArguments()) {
insertValidate(arg, argBuilder);
}

// insert after each op
op.walk([&](Operation* op) {
if (mlir::isa<secret::GenericOp>(op)) {
op.walk([&](Operation* walkOp) {
if (walkOp == op.getOperation() || mlir::isa<secret::GenericOp>(walkOp) ||
walkOp->hasTrait<OpTrait::IsTerminator>()) {
return;
}

b.setInsertionPointAfter(op);
for (Value result : op->getResults()) {
if (!isSecret(result, &solver)) {
continue;
}
insertCall(result);
OpBuilder opBuilder(walkOp->getBlock(), ++walkOp->getIterator());
for (Value result : walkOp->getResults()) {
insertValidate(result, opBuilder);
}
});
}

void lowerValidationOps(secret::GenericOp op) {
auto module = op->getParentOfType<ModuleOp>();
op.walk([&](debug::ValidateOp validateOp) {
Value value = validateOp.getInput();
ImplicitLocOpBuilder b(validateOp.getLoc(), validateOp);

auto callOp = b.create<func::CallOp>(
getOrCreateExternalDebugFunc(module, value.getType()),
ArrayRef<Value>{value});

// Transfer attributes
callOp->setAttr("debug.name", validateOp.getNameAttr());
if (validateOp.getMetadata()) {
callOp->setAttr("debug.metadata", validateOp.getMetadataAttr());
}

validateOp.erase();
});
return success();
}

struct AddDebugPort : impl::SecretAddDebugPortBase<AddDebugPort> {
using SecretAddDebugPortBase::SecretAddDebugPortBase;

void runOnOperation() override {
DataFlowSolver solver;
dataflow::loadBaselineAnalyses(solver);
solver.load<SecretnessAnalysis>();

auto result = solver.initializeAndRun(getOperation());
if (failed(result)) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
std::unique_ptr<DataFlowSolver> solver;

// No need to do a secretness analysis if we're not inserting new
// debug.validate ops
if (insertDebugAfterEveryOp) {
solver = std::make_unique<DataFlowSolver>();
dataflow::loadBaselineAnalyses(*solver);
solver->load<SecretnessAnalysis>();

auto result = solver->initializeAndRun(getOperation());
if (failed(result)) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}
}

getOperation()->walk([&](secret::GenericOp genericOp) {
if (failed(insertExternalCall(genericOp, solver))) {
genericOp->emitError("Failed to add debug port for genericOp");
signalPassFailure();
if (insertDebugAfterEveryOp) {
insertValidationOps(genericOp, *solver);
}
lowerValidationOps(genericOp);
});
}
};
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Secret/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ cc_library(
":pass_inc_gen",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect:FuncUtils",
"@heir//lib/Dialect/Debug/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
29 changes: 25 additions & 4 deletions lib/Dialect/Secret/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,23 @@ def SecretExtractGenericBody : Pass<"secret-extract-generic-body"> {
def SecretAddDebugPort : Pass<"secret-add-debug-port"> {
let summary = "Add debug port to secret-arithmetic ops";
let description = [{
This pass adds debug ports to secret-arithmetic ops in the IR, namely operations
wrapped by secret.generic. The debug ports are prefixed with "__heir_debug" and
are invoked after each operation in the generic body. The debug ports are
declarations and user should provide functions with the same name in their code.
This pass adds debug ports to secret-arithmetic ops in the IR, namely
operations wrapped by `secret.generic`. The debug ports are prefixed with
`"__heir_debug"` and are invoked after each operation in the generic body.
The debug ports are declarations and the user must define functions with the
same name in their application code.

If the option `insert-debug-after-every-op` is set to true, it will insert a
call after every operation in the generic body.

Regardless of the `insert-debug-after-every-op` option, this pass lowers all
`debug.validate` ops it encounters in the generic body to function calls.
In this way, the user can provide specific checkpoints in a program to invoke
debug code.

For example, if the function is called "foo", the debug port is called after
each homomorphic operation:

```mlir
// declaration of external debug function
func.func private @__heir_debug_tensor_8xi16_(tensor<8xi16>)
Expand All @@ -124,7 +134,18 @@ def SecretAddDebugPort : Pass<"secret-add-debug-port"> {
}
}
```

Because this pass is agnostic of the cryptosystem backend, it does not insert
secret key material into the debug handler call. See `lwe-add-debug-port` for
cryptosystem-aware analogues of this pass. This pass must remain, in
particular, to support the plaintext pipeline.
}];

let options = [
Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool",
/*default=*/"false", "Whether to add debug calls after every op">,
];
let dependentDialects = ["mlir::heir::debug::DebugDialect"];
}

def SecretImportExecutionResult : Pass<"secret-import-execution-result"> {
Expand Down
15 changes: 14 additions & 1 deletion lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,20 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) {
<< definingOp->getName() << "\";\n";
}
// Use AsmPrinter to print Value
os << debugAttrMapName << R"(["asm.result_ssa_format"] = ")" << ciphertext
std::string ssaFormat;
llvm::raw_string_ostream ss(ssaFormat);
ss << ciphertext;
std::string escaped;
for (char c : ssaFormat) {
if (c == '\n') {
escaped += "\\n";
} else if (c == '"') {
escaped += "\\\"";
} else {
escaped += c;
}
}
os << debugAttrMapName << R"(["asm.result_ssa_format"] = ")" << escaped
<< "\";\n";
}

Expand Down
Loading
Loading