diff --git a/lib/Dialect/LWE/Transforms/AddDebugPort.cpp b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp index 5444d1674f..e6c9e79872 100644 --- a/lib/Dialect/LWE/Transforms/AddDebugPort.cpp +++ b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp @@ -1,24 +1,30 @@ #include "lib/Dialect/LWE/Transforms/AddDebugPort.h" #include +#include +#include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/LWE/IR/LWETypes.h" -#include "lib/Utils/TransformUtils.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project +#include "lib/Utils/Utils.h" +#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project namespace mlir { namespace heir { @@ -28,17 +34,28 @@ namespace lwe { #include "lib/Dialect/LWE/Transforms/Passes.h.inc" FailureOr getPrivateKeyType(func::FuncOp op) { - const auto* type = llvm::find_if(op.getArgumentTypes(), [](Type type) { - return mlir::isa(getElementTypeOrSelf(type)); - }); + SmallVector ciphertextTypes; + for (Type type : op.getArgumentTypes()) { + Type elementType = getElementTypeOrSelf(type); + if (isa(elementType)) { + ciphertextTypes.push_back(elementType); + } + } - if (type == op.getArgumentTypes().end()) { - return op.emitError( - "Function does not have an argument of LWECiphertextType"); + if (ciphertextTypes.empty()) { + return failure(); } - auto lweCiphertextType = cast(getElementTypeOrSelf(*type)); + if (!llvm::all_equal(ciphertextTypes)) { + op.emitWarning( + "Conflicting ciphertext types found among function arguments"); + } + // Fallback to the first ciphertext type. This is acceptable assuming that + // all ciphertexts in the function are intended to be decrypted by the same + // key, or at least that the key derived from the first ciphertext is + // sufficient for debugging purposes. + auto lweCiphertextType = cast(ciphertextTypes[0]); auto lwePrivateKeyType = LWESecretKeyType::get(op.getContext(), lweCiphertextType.getKey(), lweCiphertextType.getCiphertextSpace().getRing()); @@ -46,17 +63,28 @@ FailureOr getPrivateKeyType(func::FuncOp op) { } func::FuncOp getOrCreateExternalDebugFunc( - ModuleOp module, Type lwePrivateKeyType, - LWECiphertextType lweCiphertextType, const DenseMap& typeToInt) { - std::string funcName = - "__heir_debug_" + std::to_string(typeToInt.at(lweCiphertextType)); + ModuleOp module, Type lwePrivateKeyType, Type valueType, + llvm::DenseMap, int>& typePairToInt) { + auto key = std::make_pair(lwePrivateKeyType, valueType); + if (typePairToInt.find(key) == typePairToInt.end()) { + typePairToInt[key] = typePairToInt.size(); + } auto* context = module.getContext(); - auto lookup = module.lookupSymbol(funcName); - if (lookup) return lookup; - auto debugFuncType = - FunctionType::get(context, {lwePrivateKeyType, lweCiphertextType}, {}); + FunctionType::get(context, {lwePrivateKeyType, valueType}, {}); + + int counter = typePairToInt[key]; + std::string funcName = "__heir_debug_" + std::to_string(counter); + + while (auto lookup = module.lookupSymbol(funcName)) { + if (lookup.getFunctionType() == debugFuncType) return lookup; + // Name conflict with different type, try next name + funcName = "__heir_debug_" + std::to_string(++counter); + } + + // Update the map with the actual counter used, to avoid searching again + typePairToInt[key] = counter; ImplicitLocOpBuilder b = ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody()); @@ -66,100 +94,358 @@ func::FuncOp getOrCreateExternalDebugFunc( return funcOp; } -LogicalResult insertExternalCall(func::FuncOp op, Type lwePrivateKeyType, - int messageSize) { - auto module = op->getParentOfType(); +void insertValidationOps(func::FuncOp op) { + int count = 0; + auto insertValidate = [&](Value value, OpBuilder& b) { + Type valueType = value.getType(); + if (isa(getElementTypeOrSelf(valueType))) { + debug::ValidateOp::create(b, value.getLoc(), value, + "heir_debug_" + std::to_string(count++), + nullptr); + } + }; - // map ciphertext type to unique int - DenseMap typeToInt; + Block& entryBlock = op.getBody().getBlocks().front(); + OpBuilder argBuilder(&entryBlock, entryBlock.begin()); + for (auto arg : op.getArguments()) { + insertValidate(arg, argBuilder); + } - // implicit assumption the first argument is private key - auto privateKey = op.getArgument(0); + op.walk([&](Operation* walkOp) { + if (walkOp == op.getOperation() || isa(walkOp) || + walkOp->hasTrait()) + return; + OpBuilder opBuilder(walkOp->getBlock(), ++walkOp->getIterator()); + for (Value result : walkOp->getResults()) { + insertValidate(result, opBuilder); + } + }); +} - ImplicitLocOpBuilder b = ImplicitLocOpBuilder::atBlockBegin( - op.getLoc(), &op.getBody().getBlocks().front()); +LogicalResult lowerValidationOps( + func::FuncOp op, Value privateKey, int messageSize, + llvm::DenseMap, int>& typePairToInt) { + auto module = op->getParentOfType(); + Type lwePrivateKeyType = privateKey.getType(); - auto insertCall = [&](Value value) { + auto walkResult = op.walk([&](debug::ValidateOp validateOp) { + Value value = validateOp.getInput(); Type valueType = value.getType(); - // NOTE: this won't work for shaped input like tensor<2x!lwe.ciphertext> - if (auto lweCiphertextType = dyn_cast(valueType)) { - // update typeToInt - if (!typeToInt.count(valueType)) { - typeToInt[valueType] = typeToInt.size(); - } - - // get attribute associated with value + if (isa(getElementTypeOrSelf(valueType))) { + ImplicitLocOpBuilder b(validateOp.getLoc(), validateOp); SmallVector attrs; - if (auto blockArg = dyn_cast(value)) { - auto* parentOp = blockArg.getOwner()->getParentOp(); - auto funcOp = dyn_cast(parentOp); - if (funcOp) { - // always dialect attr - for (auto namedAttr : funcOp.getArgAttrs(blockArg.getArgNumber())) { - attrs.push_back(namedAttr); - } - } - } else { - auto* parentOp = value.getDefiningOp(); - for (auto namedAttr : parentOp->getDialectAttrs()) { - attrs.push_back(namedAttr); - } + // Transfer metadata from validateOp to CallOp + attrs.push_back(b.getNamedAttr("debug.name", validateOp.getNameAttr())); + if (validateOp.getMetadata()) { + attrs.push_back( + b.getNamedAttr("debug.metadata", validateOp.getMetadataAttr())); } - attrs.push_back(b.getNamedAttr( "message.size", b.getStringAttr(std::to_string(messageSize)))); - func::CallOp::create( - b, - getOrCreateExternalDebugFunc(module, lwePrivateKeyType, - lweCiphertextType, typeToInt), - ArrayRef{privateKey, value}) - ->setDialectAttrs(attrs); - } - }; - - // insert for each argument - for (auto arg : op.getArguments()) { - insertCall(arg); - } + auto debugFunc = getOrCreateExternalDebugFunc(module, lwePrivateKeyType, + valueType, typePairToInt); + auto callOp = + b.create(debugFunc, ArrayRef{privateKey, value}); + callOp->setDialectAttrs(attrs); - // insert after each HE op - op.walk([&](Operation* op) { - b.setInsertionPointAfter(op); - for (Value result : op->getResults()) { - insertCall(result); + validateOp.erase(); + } else { + validateOp.emitError( + "only LWECiphertextType is supported for debug.validate"); + return WalkResult::interrupt(); } return WalkResult::advance(); }); - return success(); -} - -LogicalResult convertFunc(func::FuncOp op, int messageSize) { - auto type = getPrivateKeyType(op); - if (failed(type)) return op.emitError("failed to get private key type"); - auto lwePrivateKeyType = type.value(); - if (failed(op.insertArgument(0, lwePrivateKeyType, nullptr, op.getLoc()))) { - return op.emitError("failed to insert private key argument"); - } - if (failed(insertExternalCall(op, lwePrivateKeyType, messageSize))) { - return op.emitError("failed to insert external call"); - } - return success(); + return walkResult.wasInterrupted() ? failure() : success(); } struct AddDebugPort : impl::AddDebugPortBase { using AddDebugPortBase::AddDebugPortBase; void runOnOperation() override { - auto funcOp = - detectEntryFunction(cast(getOperation()), entryFunction); - if (funcOp && failed(convertFunc(funcOp, messageSize))) { - funcOp->emitError("Failed to configure the crypto context for func"); + ModuleOp module = cast(getOperation()); + llvm::DenseMap, int> typePairToInt; + llvm::DenseMap funcToKeyType; + SmallVector worklist; + + SymbolTable symbolTable(module); + llvm::DenseMap> calleeToCalls; + llvm::DenseSet modifiedFuncs; + + if (failed(identifyInitialTargets(module, symbolTable, funcToKeyType, + worklist))) { signalPassFailure(); + return; + } + + if (failed(propagateKeyTypes(module, funcToKeyType, worklist, + calleeToCalls))) { + signalPassFailure(); + return; + } + + if (insertDebugAfterEveryOp) { + for (auto& [func, _] : funcToKeyType) { + insertValidationOps(func); + } + } + + if (failed(addKeyArguments(funcToKeyType, modifiedFuncs))) { + signalPassFailure(); + return; + } + + if (failed(updateCallSites(modifiedFuncs, calleeToCalls, funcToKeyType))) { + signalPassFailure(); + return; + } + + if (failed(lowerAllValidationOps(module, funcToKeyType, typePairToInt))) { + signalPassFailure(); + return; } } + + private: + /// Step 1: Identify initial targets for debug port insertion. + /// + /// This function scans the module for functions that need to be processed. + /// If an entry function is specified, it only processes that function. + /// Otherwise, it processes functions that have validation ops or if the + /// `insertDebugAfterEveryOp` flag is set, functions that have at least one + /// LWE ciphertext argument. + /// + /// \param module The module to process. + /// \param symbolTable The symbol table for the module. + /// \param funcToKeyType Output map from function to its inferred key type. + /// \param worklist Output list of functions to process. + /// \return success() if successful, failure() otherwise. + LogicalResult identifyInitialTargets( + ModuleOp module, SymbolTable& symbolTable, + llvm::DenseMap& funcToKeyType, + SmallVector& worklist) { + func::FuncOp entryFunc; + if (!entryFunction.empty()) { + entryFunc = symbolTable.lookup(entryFunction); + } + + if (entryFunc) { + auto type = getPrivateKeyType(entryFunc); + if (succeeded(type)) { + funcToKeyType[entryFunc] = *type; + worklist.push_back(entryFunc); + return success(); + } + + if (containsAnyOperations(entryFunc)) { + entryFunc.emitError( + "Cannot infer LWE private key type for entry function"); + return failure(); + } + } + + for (auto funcOp : module.getOps()) { + if (funcOp.isExternal()) continue; + + bool shouldProcess = containsAnyOperations(funcOp); + if (!shouldProcess && insertDebugAfterEveryOp) { + shouldProcess = succeeded(getPrivateKeyType(funcOp)); + } + + if (shouldProcess) { + auto type = getPrivateKeyType(funcOp); + if (failed(type)) { + return funcOp.emitError( + "Cannot infer LWE private key type for function with " + "validation ops"); + } + + funcToKeyType[funcOp] = *type; + worklist.push_back(funcOp); + } + } + return success(); + } + + /// Step 2: Propagate key types up the call graph. + /// + /// This function propagates the inferred key types from callees to callers. + /// It also populates the `calleeToCalls` map to keep track of call sites. + /// + /// \param module The module to process. + /// \param funcToKeyType Map from function to its inferred key type. + /// \param worklist List of functions to process. + /// \param calleeToCalls Output map from callee to its call sites. + /// \return success() if successful, failure() otherwise. + LogicalResult propagateKeyTypes( + ModuleOp module, llvm::DenseMap& funcToKeyType, + SmallVector& worklist, + llvm::DenseMap>& + calleeToCalls) { + while (!worklist.empty()) { + func::FuncOp currentFunc = worklist.back(); + worklist.pop_back(); + Type keyType = funcToKeyType[currentFunc]; + + auto symbolUses = SymbolTable::getSymbolUses(currentFunc, module); + if (symbolUses) { + for (auto use : *symbolUses) { + Operation* user = use.getUser(); + auto callOp = dyn_cast(user); + if (!callOp) continue; + + calleeToCalls[currentFunc].push_back(callOp); + + func::FuncOp caller = callOp->getParentOfType(); + if (!caller) continue; + + if (funcToKeyType.find(caller) == funcToKeyType.end()) { + funcToKeyType[caller] = keyType; + worklist.push_back(caller); + } else if (funcToKeyType[caller] != keyType) { + caller.emitError("Conflicting LWE private key types required"); + return failure(); + } + } + } + } + return success(); + } + + /// Step 3: Add key arguments to functions. + /// + /// This function adds the LWE private key as the first argument to functions + /// that need it and don't already have it. + /// + /// \param funcToKeyType Map from function to its inferred key type. + /// \param modifiedFuncs Output set of functions that were modified. + /// \return success() if successful, failure() otherwise. + LogicalResult addKeyArguments( + llvm::DenseMap& funcToKeyType, + llvm::DenseSet& modifiedFuncs) { + for (auto& [funcOp, keyType] : funcToKeyType) { + bool hasKey = llvm::any_of(funcOp.getArguments(), [&](Value arg) { + return arg.getType() == keyType; + }); + + if (!hasKey) { + if (failed( + funcOp.insertArgument(0, keyType, nullptr, funcOp.getLoc()))) { + funcOp.emitError("failed to insert private key argument"); + return failure(); + } + modifiedFuncs.insert(funcOp); + } + } + return success(); + } + + /// Step 4: Update call sites to pass the key argument. + /// + /// This function updates the call sites of modified functions to pass the + /// required LWE private key argument. + /// + /// \param modifiedFuncs Set of functions that were modified by adding a key + /// argument. + /// \param calleeToCalls Map from callee to its call sites. + /// \param funcToKeyType Map from function to its inferred key type. + /// \return success() if successful, failure() otherwise. + LogicalResult updateCallSites( + const llvm::DenseSet& modifiedFuncs, + const llvm::DenseMap>& + calleeToCalls, + const llvm::DenseMap& funcToKeyType) { + for (auto funcOp : modifiedFuncs) { + auto it = calleeToCalls.find(funcOp); + if (it == calleeToCalls.end()) continue; + + for (auto callOp : it->second) { + auto callerFunc = callOp->getParentOfType(); + auto keyIt = funcToKeyType.find(funcOp); + if (keyIt == funcToKeyType.end()) { + return failure(); + } + Type keyType = keyIt->second; + + auto* keyToPass = + llvm::find_if(callerFunc.getArguments(), + [&](Value arg) { return arg.getType() == keyType; }); + + if (keyToPass == callerFunc.getArguments().end()) { + callOp.emitError( + "Caller does not have the required LWE private key argument"); + return failure(); + } + + SmallVector operands; + operands.push_back(*keyToPass); + operands.append(callOp.getOperands().begin(), + callOp.getOperands().end()); + + OpBuilder b(callOp); + auto newCall = + func::CallOp::create(b, callOp.getLoc(), funcOp, operands); + newCall->setDialectAttrs(callOp->getDialectAttrs()); + for (unsigned i = 0; i < callOp.getNumResults(); ++i) { + callOp.getResult(i).replaceAllUsesWith(newCall.getResult(i)); + } + callOp.erase(); + } + } + return success(); + } + + /// Step 5: Lower all validation ops to external calls. + /// + /// This function lowers all `debug.validate` ops in the module to calls to + /// external debug functions. + /// + /// \param module The module to process. + /// \param funcToKeyType Map from function to its inferred key type. + /// \param typePairToInt Map to track generated debug function names. + /// \return success() if successful, failure() otherwise. + LogicalResult lowerAllValidationOps( + ModuleOp module, const llvm::DenseMap& funcToKeyType, + llvm::DenseMap, int>& typePairToInt) { + for (auto funcOp : module.getOps()) { + if (funcOp.isExternal()) continue; + + Type keyType; + Value privateKey; + auto it = funcToKeyType.find(funcOp); + if (it != funcToKeyType.end()) { + keyType = it->second; + privateKey = *llvm::find_if(funcOp.getArguments(), [&](Value arg) { + return arg.getType() == keyType; + }); + } else { + for (auto arg : funcOp.getArguments()) { + if (isa(arg.getType())) { + keyType = arg.getType(); + privateKey = arg; + break; + } + } + } + + if (privateKey) { + if (failed(lowerValidationOps(funcOp, privateKey, messageSize, + typePairToInt))) { + funcOp.emitError("failed to lower validation ops"); + return failure(); + } + } else if (containsAnyOperations(funcOp)) { + funcOp.emitError( + "validation operations cannot be lowered without a private key"); + return failure(); + } + } + return success(); + } }; + } // namespace lwe } // namespace heir } // namespace mlir diff --git a/lib/Dialect/LWE/Transforms/BUILD b/lib/Dialect/LWE/Transforms/BUILD index 24d2bf1a34..999d1e16c3 100644 --- a/lib/Dialect/LWE/Transforms/BUILD +++ b/lib/Dialect/LWE/Transforms/BUILD @@ -27,15 +27,16 @@ cc_library( ], deps = [ ":pass_inc_gen", + "@heir//lib/Dialect:FuncUtils", + "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", - "@heir//lib/Dialect/TensorExt/IR:Dialect", - "@heir//lib/Utils:TransformUtils", + "@heir//lib/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", ], ) diff --git a/lib/Dialect/LWE/Transforms/Passes.td b/lib/Dialect/LWE/Transforms/Passes.td index 26f19be9f9..b048394efe 100644 --- a/lib/Dialect/LWE/Transforms/Passes.td +++ b/lib/Dialect/LWE/Transforms/Passes.td @@ -12,6 +12,12 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> { function. The debug ports are declarations and user should provide functions with the same name in their code. + If the option `insert-debug-after-every-op` is set to true, it will insert a `debug.validate` + op after every homomorphic operation. + + Regardless of the `insert-debug-after-every-op` option, this pass will lower all + `debug.validate` ops it encounters to function calls. + For example, if the function is called "foo", the secret key is added to its arguments, and the debug port is called after each homomorphic operation: ```mlir @@ -29,13 +35,18 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> { } ``` }]; - let dependentDialects = ["mlir::heir::lwe::LWEDialect"]; + let dependentDialects = [ + "mlir::heir::lwe::LWEDialect", + "mlir::heir::debug::DebugDialect" + ]; let options = [ Option<"entryFunction", "entry-function", "std::string", /*default=*/"", "Default entry function " "name of entry function.">, Option<"messageSize", "message-size", "int", /*default=*/"1", "The size of the message in the ciphertext.">, + Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool", + /*default=*/"false", "Whether to add debug ports after every op"> ]; } diff --git a/lib/Dialect/Secret/Conversions/BUILD b/lib/Dialect/Secret/Conversions/BUILD index e2e0d76458..2442bddd8f 100644 --- a/lib/Dialect/Secret/Conversions/BUILD +++ b/lib/Dialect/Secret/Conversions/BUILD @@ -11,6 +11,7 @@ cc_library( hdrs = ["Patterns.h"], deps = [ "@heir//lib/Dialect:ModuleAttributes", + "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:Dialect", diff --git a/lib/Dialect/Secret/Conversions/Patterns.cpp b/lib/Dialect/Secret/Conversions/Patterns.cpp index 82c6c10e9c..f16371c2e3 100644 --- a/lib/Dialect/Secret/Conversions/Patterns.cpp +++ b/lib/Dialect/Secret/Conversions/Patterns.cpp @@ -3,6 +3,8 @@ #include #include +#include "lib/Dialect/Debug/IR/DebugDialect.h" +#include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/LWE/IR/LWEAttributes.h" #include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/LWE/IR/LWEOps.h" @@ -502,6 +504,20 @@ LogicalResult ConvertEmpty::matchAndRewrite( return success(); } +// This only needs a special pattern because it has attributes that aren't +// copied over by the base SecretGenericOpConversion. +FailureOr ConvertDebugValidate::matchAndRewriteInner( + secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, + ArrayRef attributes, + ContextAwareConversionPatternRewriter& rewriter) const { + debug::ValidateOp innerOp = + cast(op.getBody()->getOperations().front()); + debug::ValidateOp newOp = debug::ValidateOp::create( + rewriter, op.getLoc(), outputTypes, inputs, innerOp->getAttrs()); + rewriter.replaceOp(op, newOp); + return newOp.getOperation(); +} + bool hasSecretOperandsOrResults(Operation* op) { return llvm::any_of(op->getOperands(), [](Value operand) { @@ -527,24 +543,23 @@ void addSecretToSchemeDefaultConversionTargetsAndPatterns( target.markUnknownOpDynamicallyLegal( [&](Operation* op) { return !hasSecretOperandsOrResults(op); }); - patterns.add, - SecretGenericOpIdentityConversion, - SecretGenericOpIdentityConversion, - SecretGenericOpIdentityConversion, - SecretGenericOpIdentityConversion, - SecretGenericOpIdentityConversion, - SecretGenericOpConversion, - SecretGenericFuncCallConversion, ConvertExtractSlice, - ConvertInsertSlice, ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware, - ConvertAnyContextAware>(typeConverter, - patterns.getContext()); + patterns.add< + ConvertAnyContextAware, + ConvertAnyContextAware, + ConvertAnyContextAware, + ConvertAnyContextAware, ConvertAnyContextAware, + ConvertAnyContextAware, ConvertAnyContextAware, + ConvertAnyContextAware, + ConvertAnyContextAware, ConvertDebugValidate, + ConvertExtractSlice, ConvertInsertSlice, SecretGenericFuncCallConversion, + SecretGenericOpConversion, + SecretGenericOpIdentityConversion, + SecretGenericOpIdentityConversion, + SecretGenericOpIdentityConversion, + SecretGenericOpIdentityConversion, + SecretGenericOpIdentityConversion, + SecretGenericOpIdentityConversion>( + typeConverter, patterns.getContext()); addStructuralConversionPatterns(typeConverter, patterns, target); } diff --git a/lib/Dialect/Secret/Conversions/Patterns.h b/lib/Dialect/Secret/Conversions/Patterns.h index f71b9faa25..b21d0c3aab 100644 --- a/lib/Dialect/Secret/Conversions/Patterns.h +++ b/lib/Dialect/Secret/Conversions/Patterns.h @@ -1,6 +1,7 @@ #ifndef LIB_DIALECT_SECRET_CONVERSIONS_PATTERNS_H_ #define LIB_DIALECT_SECRET_CONVERSIONS_PATTERNS_H_ +#include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/Mgmt/IR/MgmtOps.h" #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "lib/Dialect/Secret/IR/SecretOps.h" @@ -90,6 +91,17 @@ struct ConvertInsertSlice ContextAwareConversionPatternRewriter& rewriter) const override; }; +struct ConvertDebugValidate + : public SecretGenericOpConversion { + using SecretGenericOpConversion::SecretGenericOpConversion; + + FailureOr matchAndRewriteInner( + secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, + ArrayRef attributes, + ContextAwareConversionPatternRewriter& rewriter) const override; +}; + // An empty ciphertext-semantic tensor can be used as the initializer of a // reduction. In this case, there is no containing secret.generic op, and we // anchor on the subsequent `mgmt::InitOp` to determine how to convert it to a diff --git a/lib/Dialect/Secret/Transforms/AddDebugPort.cpp b/lib/Dialect/Secret/Transforms/AddDebugPort.cpp index fae2f10242..03a815aaee 100644 --- a/lib/Dialect/Secret/Transforms/AddDebugPort.cpp +++ b/lib/Dialect/Secret/Transforms/AddDebugPort.cpp @@ -1,16 +1,20 @@ #include "lib/Dialect/Secret/Transforms/AddDebugPort.h" +#include #include #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/FuncUtils.h" #include "lib/Dialect/Secret/IR/SecretOps.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project @@ -46,64 +50,81 @@ func::FuncOp getOrCreateExternalDebugFunc(ModuleOp module, Type valueType) { return funcOp; } -LogicalResult insertExternalCall(secret::GenericOp op, DataFlowSolver& solver) { - auto module = op->getParentOfType(); - - ImplicitLocOpBuilder b = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody()); - - auto insertCall = [&](Value value) { - Type valueType = value.getType(); - - func::CallOp::create(b, getOrCreateExternalDebugFunc(module, valueType), - ArrayRef{value}); +void insertValidationOps(secret::GenericOp op, DataFlowSolver& solver) { + int count = 0; + auto insertValidate = [&](Value value, OpBuilder& b) { + if (isSecret(value, &solver)) { + debug::ValidateOp::create(b, value.getLoc(), value, + "heir_debug_" + std::to_string(count++), + nullptr); + } }; - // insert for each argument - for (auto arg : op.getBody()->getArguments()) { - if (!isSecret(arg, &solver)) { - continue; - } - insertCall(arg); + Block* body = op.getBody(); + OpBuilder argBuilder(body, body->begin()); + for (auto arg : body->getArguments()) { + insertValidate(arg, argBuilder); } - // insert after each op - op.walk([&](Operation* op) { - if (mlir::isa(op)) { + op.walk([&](Operation* walkOp) { + if (walkOp == op.getOperation() || mlir::isa(walkOp) || + walkOp->hasTrait()) { return; } - b.setInsertionPointAfter(op); - for (Value result : op->getResults()) { - if (!isSecret(result, &solver)) { - continue; - } - insertCall(result); + OpBuilder opBuilder(walkOp->getBlock(), ++walkOp->getIterator()); + for (Value result : walkOp->getResults()) { + insertValidate(result, opBuilder); + } + }); +} + +void lowerValidationOps(secret::GenericOp op) { + auto module = op->getParentOfType(); + op.walk([&](debug::ValidateOp validateOp) { + Value value = validateOp.getInput(); + ImplicitLocOpBuilder b(validateOp.getLoc(), validateOp); + + auto callOp = b.create( + getOrCreateExternalDebugFunc(module, value.getType()), + ArrayRef{value}); + + // Transfer attributes + callOp->setAttr("debug.name", validateOp.getNameAttr()); + if (validateOp.getMetadata()) { + callOp->setAttr("debug.metadata", validateOp.getMetadataAttr()); } + + validateOp.erase(); }); - return success(); } struct AddDebugPort : impl::SecretAddDebugPortBase { using SecretAddDebugPortBase::SecretAddDebugPortBase; void runOnOperation() override { - DataFlowSolver solver; - dataflow::loadBaselineAnalyses(solver); - solver.load(); - - auto result = solver.initializeAndRun(getOperation()); - if (failed(result)) { - getOperation()->emitOpError() << "Failed to run the analysis.\n"; - signalPassFailure(); - return; + std::unique_ptr solver; + + // No need to do a secretness analysis if we're not inserting new + // debug.validate ops + if (insertDebugAfterEveryOp) { + solver = std::make_unique(); + dataflow::loadBaselineAnalyses(*solver); + solver->load(); + + auto result = solver->initializeAndRun(getOperation()); + if (failed(result)) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } } getOperation()->walk([&](secret::GenericOp genericOp) { - if (failed(insertExternalCall(genericOp, solver))) { - genericOp->emitError("Failed to add debug port for genericOp"); - signalPassFailure(); + if (insertDebugAfterEveryOp) { + insertValidationOps(genericOp, *solver); } + lowerValidationOps(genericOp); }); } }; diff --git a/lib/Dialect/Secret/Transforms/BUILD b/lib/Dialect/Secret/Transforms/BUILD index 1093ae1d38..d0b625cbe7 100644 --- a/lib/Dialect/Secret/Transforms/BUILD +++ b/lib/Dialect/Secret/Transforms/BUILD @@ -178,7 +178,9 @@ cc_library( ":pass_inc_gen", "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect:FuncUtils", + "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/Secret/Transforms/Passes.td b/lib/Dialect/Secret/Transforms/Passes.td index 15fa24be71..6b7cfe8487 100644 --- a/lib/Dialect/Secret/Transforms/Passes.td +++ b/lib/Dialect/Secret/Transforms/Passes.td @@ -103,13 +103,23 @@ def SecretExtractGenericBody : Pass<"secret-extract-generic-body"> { def SecretAddDebugPort : Pass<"secret-add-debug-port"> { let summary = "Add debug port to secret-arithmetic ops"; let description = [{ - This pass adds debug ports to secret-arithmetic ops in the IR, namely operations - wrapped by secret.generic. The debug ports are prefixed with "__heir_debug" and - are invoked after each operation in the generic body. The debug ports are - declarations and user should provide functions with the same name in their code. + This pass adds debug ports to secret-arithmetic ops in the IR, namely + operations wrapped by `secret.generic`. The debug ports are prefixed with + `"__heir_debug"` and are invoked after each operation in the generic body. + The debug ports are declarations and the user must define functions with the + same name in their application code. + + If the option `insert-debug-after-every-op` is set to true, it will insert a + call after every operation in the generic body. + + Regardless of the `insert-debug-after-every-op` option, this pass lowers all + `debug.validate` ops it encounters in the generic body to function calls. + In this way, the user can provide specific checkpoints in a program to invoke + debug code. For example, if the function is called "foo", the debug port is called after each homomorphic operation: + ```mlir // declaration of external debug function func.func private @__heir_debug_tensor_8xi16_(tensor<8xi16>) @@ -124,7 +134,18 @@ def SecretAddDebugPort : Pass<"secret-add-debug-port"> { } } ``` + + Because this pass is agnostic of the cryptosystem backend, it does not insert + secret key material into the debug handler call. See `lwe-add-debug-port` for + cryptosystem-aware analogues of this pass. This pass must remain, in + particular, to support the plaintext pipeline. }]; + + let options = [ + Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool", + /*default=*/"false", "Whether to add debug calls after every op">, + ]; + let dependentDialects = ["mlir::heir::debug::DebugDialect"]; } def SecretImportExecutionResult : Pass<"secret-import-execution-result"> { diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index c80a9196e8..5c36afa721 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -385,7 +385,20 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) { << definingOp->getName() << "\";\n"; } // Use AsmPrinter to print Value - os << debugAttrMapName << R"(["asm.result_ssa_format"] = ")" << ciphertext + std::string ssaFormat; + llvm::raw_string_ostream ss(ssaFormat); + ss << ciphertext; + std::string escaped; + for (char c : ssaFormat) { + if (c == '\n') { + escaped += "\\n"; + } else if (c == '"') { + escaped += "\\\""; + } else { + escaped += c; + } + } + os << debugAttrMapName << R"(["asm.result_ssa_format"] = ")" << escaped << "\";\n"; } diff --git a/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.cpp index 33aa3bff63..88d77a9637 100644 --- a/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.cpp @@ -10,6 +10,7 @@ #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project @@ -62,9 +63,37 @@ LogicalResult OpenFhePkePybindEmitter::printOperation(ModuleOp moduleOp) { } LogicalResult OpenFhePkePybindEmitter::printOperation(func::FuncOp funcOp) { - os << llvm::formatv(kPybindFunctionTemplate.data(), - canonicalizeDebugPort(funcOp.getName())) - << "\n"; + llvm::StringRef funcName = canonicalizeDebugPort(funcOp.getName()); + if (funcName == "__heir_debug") { + bool isVector = false; + if (funcOp.getNumArguments() >= 3) { + mlir::Type thirdArgType = funcOp.getArgumentTypes()[2]; + if (llvm::isa(thirdArgType) || + llvm::isa(thirdArgType)) { + isVector = true; + } + } + + if (isVector) { + if (!boundHeirDebugVector_) { + os << "m.def(\"__heir_debug\", py::overload_cast, const " + "std::map&>(&__heir_debug), " + "py::call_guard());\n"; + boundHeirDebugVector_ = true; + } + } else { + if (!boundHeirDebugSingle_) { + os << "m.def(\"__heir_debug\", py::overload_cast&>(&__heir_debug), " + "py::call_guard());\n"; + boundHeirDebugSingle_ = true; + } + } + } else { + os << llvm::formatv(kPybindFunctionTemplate.data(), funcName) << "\n"; + } return success(); } diff --git a/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.h b/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.h index 27df34e47a..7ff4b90efe 100644 --- a/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkePybindEmitter.h @@ -40,6 +40,9 @@ class OpenFhePkePybindEmitter { /// Output stream to emit to. raw_indented_ostream os; + bool boundHeirDebugSingle_ = false; + bool boundHeirDebugVector_ = false; + // Functions for printing individual ops LogicalResult printOperation(::mlir::ModuleOp op); LogicalResult printOperation(::mlir::func::FuncOp op); diff --git a/tests/Dialect/Debug/Transforms/lower_validate_lwe.mlir b/tests/Dialect/Debug/Transforms/lower_validate_lwe.mlir new file mode 100644 index 0000000000..2a3d7ee541 --- /dev/null +++ b/tests/Dialect/Debug/Transforms/lower_validate_lwe.mlir @@ -0,0 +1,20 @@ +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port %s | FileCheck %s + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> +#modulus_chain = #lwe.modulus_chain, current = 0> +#plaintext_space = #lwe.plaintext_space +#ciphertext_space = #lwe.ciphertext_space + +!ct_ty = !lwe.lwe_ciphertext + +module { + // CHECK: func.func private @__heir_debug_0 + func.func @test_lower_validate_lwe(%arg0: !ct_ty) -> !ct_ty { + // CHECK: func.func @test_lower_validate_lwe + // CHECK: call @__heir_debug_0 + debug.validate %arg0 {name = "lwe_val1", metadata = "lwe_meta1"} : !ct_ty + return %arg0 : !ct_ty + } +} diff --git a/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir b/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir new file mode 100644 index 0000000000..b7fe825640 --- /dev/null +++ b/tests/Dialect/Debug/Transforms/lower_validate_secret.mlir @@ -0,0 +1,16 @@ +// RUN: heir-opt --secret-add-debug-port %s | FileCheck %s + +module { + func.func @test_lower_validate(%arg0: !secret.secret) -> !secret.secret { + %0 = secret.generic(%arg0: !secret.secret) { + ^body(%arg1: i32): + debug.validate %arg1 {name = "val1"} : i32 + secret.yield %arg1 : i32 + } -> !secret.secret + return %0 : !secret.secret + } +} + +// CHECK: func.func private @__heir_debug_i32(i32) +// CHECK: func.func @test_lower_validate +// CHECK: call @__heir_debug_i32({{.*}}) {debug.name = "val1"} diff --git a/tests/Dialect/LWE/Transforms/add_debug_port.mlir b/tests/Dialect/LWE/Transforms/add_debug_port.mlir index edd2fbbbb6..77c2185510 100644 --- a/tests/Dialect/LWE/Transforms/add_debug_port.mlir +++ b/tests/Dialect/LWE/Transforms/add_debug_port.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port %s | FileCheck %s +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port=insert-debug-after-every-op=true %s | FileCheck %s !Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> !Z65537_i64_ = !mod_arith.int<65537 : i64> diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_calls.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_calls.mlir new file mode 100644 index 0000000000..99a1fd02fb --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_calls.mlir @@ -0,0 +1,37 @@ +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port %s | FileCheck %s + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty = !lwe.lwe_ciphertext + +module { + // CHECK: func @callee + // CHECK-SAME: (%[[SK:[^:]*]]: {{.*}}, %[[ARG:[^:]*]]: !lwe.lwe_ciphertext<{{.*}}>) + func.func @callee(%arg0: !ty) { + debug.validate %arg0 {name = "callee_debug"} : !ty + return + } + + // CHECK: func @caller + // CHECK-SAME: (%[[SK2:[^:]*]]: {{.*}}, %[[ARG2:[^:]*]]: !lwe.lwe_ciphertext<{{.*}}>) + func.func @caller(%arg0: !ty) { + // CHECK: call @callee(%[[SK2]], %[[ARG2]]) + func.call @callee(%arg0) : (!ty) -> () + return + } +} diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_entry.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_entry.mlir new file mode 100644 index 0000000000..342483e31e --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_entry.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port=entry-function=foo %s | FileCheck %s + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty = !lwe.lwe_ciphertext + +func.func @foo(%arg0: !ty) -> !ty { + debug.validate %arg0 {name = "v1"} : !ty + return %arg0 : !ty +} + +func.func @bar(%arg0: !ty) -> !ty { + return %arg0 : !ty +} + +// CHECK: func.func @foo(%{{.*}}: !lwe.lwe_secret_key{{.*}}, %{{.*}}: !lwe.lwe_ciphertext{{.*}}) +// CHECK: func.func @bar(%{{.*}}: !lwe.lwe_ciphertext{{.*}}) diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_error.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_error.mlir new file mode 100644 index 0000000000..75a2fca277 --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_error.mlir @@ -0,0 +1,30 @@ +// RUN: heir-opt --verify-diagnostics --lwe-add-debug-port=entry-function=foo %s + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty = !lwe.lwe_ciphertext + +func.func @foo(%arg0: !ty) -> !ty { + return %arg0 : !ty +} + +// expected-error@below {{validation operations cannot be lowered without a private key}} +func.func @bar(%arg0: i32) { + debug.validate %arg0 {name = "v1"} : i32 + return +} diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_message_size_attr.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_message_size_attr.mlir index 83ff40c1de..946bbe6478 100644 --- a/tests/Dialect/LWE/Transforms/add_debug_port_message_size_attr.mlir +++ b/tests/Dialect/LWE/Transforms/add_debug_port_message_size_attr.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --lwe-add-debug-port %s | FileCheck %s +// RUN: heir-opt --lwe-add-debug-port=insert-debug-after-every-op=true %s | FileCheck %s !Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> !Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> @@ -17,7 +17,7 @@ module attributes {scheme.bgv} { // CHECK: lwe.radd %ct_1 = lwe.radd %ct, %ct_0 : (!ct_L1_, !ct_L1_) -> !ct_L1_ // CHECK: call @__heir_debug - // CHECK-SAME: {message.size = "1"} + // CHECK-SAME: {debug.name = "heir_debug_{{.*}}", message.size = "1"} return } } diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_warning.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_warning.mlir new file mode 100644 index 0000000000..103d9da8f3 --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_warning.mlir @@ -0,0 +1,29 @@ +// RUN: heir-opt --verify-diagnostics --lwe-add-debug-port %s + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#full_crt_packing_encoding2 = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space +#plaintext_space2 = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty1 = !lwe.lwe_ciphertext +!ty2 = !lwe.lwe_ciphertext + +// expected-warning@below {{Conflicting ciphertext types found among function arguments}} +func.func @conflicting_types(%arg0: !ty1, %arg1: !ty2) -> !ty1 { + debug.validate %arg0 {name = "v1"} : !ty1 + return %arg0 : !ty1 +} diff --git a/tests/Emitter/Openfhe/emit_pybind.mlir b/tests/Emitter/Openfhe/emit_pybind.mlir index 9994416477..6ba9f3c18f 100644 --- a/tests/Emitter/Openfhe/emit_pybind.mlir +++ b/tests/Emitter/Openfhe/emit_pybind.mlir @@ -31,6 +31,9 @@ // CHECK: m.def("simple_sum", &simple_sum, py::call_guard()); // CHECK: m.def("simple_sum__encrypt", &simple_sum__encrypt, py::call_guard()); // CHECK: m.def("simple_sum__decrypt", &simple_sum__decrypt, py::call_guard()); +// CHECK: m.def("__heir_debug", py::overload_cast&>(&__heir_debug), py::call_guard()); +// CHECK: m.def("__heir_debug", py::overload_cast, const std::map&>(&__heir_debug), py::call_guard()); +// CHECK-NOT: m.def("__heir_debug" // CHECK: } !cc = !openfhe.crypto_context @@ -65,3 +68,15 @@ func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !ct, %arg2 %1 = openfhe.decode %0 : !pt -> i16 return %1 : i16 } + +func.func @__heir_debug(%arg0: !openfhe.crypto_context, %arg1: !openfhe.private_key, %arg2: !openfhe.ciphertext) { + return +} + +func.func @__heir_debug_vec(%arg0: !openfhe.crypto_context, %arg1: !openfhe.private_key, %arg2: tensor<8x!openfhe.ciphertext>) { + return +} + +func.func @__heir_debug_dup(%arg0: !openfhe.crypto_context, %arg1: !openfhe.private_key, %arg2: !openfhe.ciphertext) { + return +} diff --git a/tests/Examples/lattigo/bfv/bfv_debug.go b/tests/Examples/lattigo/bfv/bfv_debug.go index a25fbfe58e..7b2134b185 100644 --- a/tests/Examples/lattigo/bfv/bfv_debug.go +++ b/tests/Examples/lattigo/bfv/bfv_debug.go @@ -1,4 +1,5 @@ // Debug func implementation for testing +// Lattigo v6 implements BFV within the BGV package as a scale-invariant variant. package main import ( @@ -9,7 +10,21 @@ import ( "github.com/tuneinsight/lattigo/v6/schemes/bgv" ) -func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) { +func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ctObj any, debugAttrMap map[string]string) { + var ct *rlwe.Ciphertext + switch v := ctObj.(type) { + case *rlwe.Ciphertext: + ct = v + case []*rlwe.Ciphertext: + if len(v) == 0 { + fmt.Println("Empty ciphertext slice") + return + } + fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) + ct = v[0] + default: + panic(fmt.Sprintf("unexpected type %T", ctObj)) + } // print op isBlockArgument := debugAttrMap["asm.is_block_arg"] if isBlockArgument == "1" { @@ -19,10 +34,18 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E } // print the decryption result - messageSizeStr := debugAttrMap["message.size"] - messageSize, err := strconv.Atoi(messageSizeStr) - if err != nil { - panic(err) + messageSizeStr, ok := debugAttrMap["message.size"] + var messageSize int + var err error + if !ok || messageSizeStr == "" { + fmt.Println("Warning: message.size missing, defaulting to 1") + messageSize = 1 + } else { + messageSize, err = strconv.Atoi(messageSizeStr) + if err != nil { + fmt.Printf("Warning: invalid message.size %s, defaulting to 1\n", messageSizeStr) + messageSize = 1 + } } value := make([]int64, messageSize) pt := decryptor.DecryptNew(ct) diff --git a/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD b/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD index 6db6d09e62..4691ead83e 100644 --- a/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD @@ -12,6 +12,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=ciphertext-degree=1024 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD index 43a6576240..897bb9129b 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD @@ -13,6 +13,7 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_dep_16.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD index 927999e94a..1d88da97fb 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD @@ -13,6 +13,7 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_dep_8.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD index f1fc0b34a9..42d04ef85d 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD @@ -12,6 +12,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_indep_32.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD index fbc8fbf30d..6e918c9fde 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD @@ -13,6 +13,7 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_indep_8.mlir", diff --git a/tests/Examples/lattigo/bgv/bgv_debug.go b/tests/Examples/lattigo/bgv/bgv_debug.go index a238a85de1..2c8c76d98b 100644 --- a/tests/Examples/lattigo/bgv/bgv_debug.go +++ b/tests/Examples/lattigo/bgv/bgv_debug.go @@ -9,7 +9,21 @@ import ( "github.com/tuneinsight/lattigo/v6/schemes/bgv" ) -func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) { +func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ctObj any, debugAttrMap map[string]string) { + var ct *rlwe.Ciphertext + switch v := ctObj.(type) { + case *rlwe.Ciphertext: + ct = v + case []*rlwe.Ciphertext: + if len(v) == 0 { + fmt.Println("Empty ciphertext slice") + return + } + fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) + ct = v[0] + default: + panic(fmt.Sprintf("unexpected type %T", ctObj)) + } // print op isBlockArgument := debugAttrMap["asm.is_block_arg"] if isBlockArgument == "1" { @@ -19,10 +33,18 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E } // print the decryption result - messageSizeStr := debugAttrMap["message.size"] - messageSize, err := strconv.Atoi(messageSizeStr) - if err != nil { - panic(err) + messageSizeStr, ok := debugAttrMap["message.size"] + var messageSize int + var err error + if !ok || messageSizeStr == "" { + fmt.Println("Warning: message.size missing, defaulting to 1") + messageSize = 1 + } else { + messageSize, err = strconv.Atoi(messageSizeStr) + if err != nil { + fmt.Printf("Warning: invalid message.size %s, defaulting to 1\n", messageSizeStr) + messageSize = 1 + } } value := make([]int64, messageSize) pt := decryptor.DecryptNew(ct) diff --git a/tests/Examples/lattigo/bgv/cross_level/BUILD b/tests/Examples/lattigo/bgv/cross_level/BUILD index b6e381edf1..c27c00a597 100644 --- a/tests/Examples/lattigo/bgv/cross_level/BUILD +++ b/tests/Examples/lattigo/bgv/cross_level/BUILD @@ -19,6 +19,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=1024", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "cross_level.mlir", diff --git a/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go b/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go index 1b745bfdf3..41748c14a5 100644 --- a/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go +++ b/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go @@ -9,7 +9,21 @@ import ( "github.com/tuneinsight/lattigo/v6/schemes/bgv" ) -func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) { +func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.Encoder, decryptor *rlwe.Decryptor, ctObj any, debugAttrMap map[string]string) { + var ct *rlwe.Ciphertext + switch v := ctObj.(type) { + case *rlwe.Ciphertext: + ct = v + case []*rlwe.Ciphertext: + if len(v) == 0 { + fmt.Println("Empty ciphertext slice") + return + } + fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) + ct = v[0] + default: + panic(fmt.Sprintf("unexpected type %T", ctObj)) + } // print op isBlockArgument := debugAttrMap["asm.is_block_arg"] if isBlockArgument == "1" { @@ -19,10 +33,18 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E } // print the decryption result - messageSizeStr := debugAttrMap["message.size"] - messageSize, err := strconv.Atoi(messageSizeStr) - if err != nil { - panic(err) + messageSizeStr, ok := debugAttrMap["message.size"] + var messageSize int + var err error + if !ok || messageSizeStr == "" { + fmt.Println("Warning: message.size missing, defaulting to 1") + messageSize = 1 + } else { + messageSize, err = strconv.Atoi(messageSizeStr) + if err != nil { + fmt.Printf("Warning: invalid message.size %s, defaulting to 1\n", messageSizeStr) + messageSize = 1 + } } value := make([]int64, messageSize) pt := decryptor.DecryptNew(ct) diff --git a/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD b/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD index 3af96a3c77..4efb60c2b6 100644 --- a/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD +++ b/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD @@ -12,6 +12,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD b/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD index 82220d3920..c3c3d7fbf6 100644 --- a/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD +++ b/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD @@ -12,6 +12,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=8192 noise-model=bgv-noise-mono annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/ckks/ckks_debug.go b/tests/Examples/lattigo/ckks/ckks_debug.go index 90c6eb102d..1aa5842f26 100644 --- a/tests/Examples/lattigo/ckks/ckks_debug.go +++ b/tests/Examples/lattigo/ckks/ckks_debug.go @@ -11,7 +11,22 @@ import ( "github.com/tuneinsight/lattigo/v6/schemes/ckks" ) -func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) { +func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, ctObj any, debugAttrMap map[string]string) { + var ct *rlwe.Ciphertext + switch v := ctObj.(type) { + case *rlwe.Ciphertext: + ct = v + case []*rlwe.Ciphertext: + if len(v) == 0 { + fmt.Println("Empty ciphertext slice") + return + } + fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) + ct = v[0] + default: + panic(fmt.Sprintf("unexpected type %T", ctObj)) + } + // print op isBlockArgument := debugAttrMap["asm.is_block_arg"] if isBlockArgument == "1" { @@ -21,10 +36,18 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk } // print the decryption result - messageSizeStr := debugAttrMap["message.size"] - messageSize, err := strconv.Atoi(messageSizeStr) - if err != nil { - panic(err) + messageSizeStr, ok := debugAttrMap["message.size"] + var messageSize int + var err error + if !ok || messageSizeStr == "" { + fmt.Println("Warning: message.size missing, defaulting to 1") + messageSize = 1 + } else { + messageSize, err = strconv.Atoi(messageSizeStr) + if err != nil { + fmt.Printf("Warning: invalid message.size %s, defaulting to 1\n", messageSizeStr) + messageSize = 1 + } } value := make([]float64, messageSize) pt := decryptor.DecryptNew(ct) diff --git a/tests/Examples/lattigo/ckks/cross_level/BUILD b/tests/Examples/lattigo/ckks/cross_level/BUILD index cc1badbf5e..d6151b4737 100644 --- a/tests/Examples/lattigo/ckks/cross_level/BUILD +++ b/tests/Examples/lattigo/ckks/cross_level/BUILD @@ -19,6 +19,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", "--mlir-to-ckks=ciphertext-degree=4 modulus-switch-before-first-mul=true first-mod-bits=59 scaling-mod-bits=45", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "cross_level.mlir", diff --git a/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go b/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go index 6ad5020153..325c044ce5 100644 --- a/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go +++ b/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go @@ -10,7 +10,21 @@ import ( "github.com/tuneinsight/lattigo/v6/schemes/ckks" ) -func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, ct *rlwe.Ciphertext, debugAttrMap map[string]string) { +func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, ctObj any, debugAttrMap map[string]string) { + var ct *rlwe.Ciphertext + switch v := ctObj.(type) { + case *rlwe.Ciphertext: + ct = v + case []*rlwe.Ciphertext: + if len(v) == 0 { + fmt.Println("Empty ciphertext slice") + return + } + fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) + ct = v[0] + default: + panic(fmt.Sprintf("unexpected type %T", ctObj)) + } // print op isBlockArgument := debugAttrMap["asm.is_block_arg"] if isBlockArgument == "1" { @@ -20,10 +34,18 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk } // print the decryption result - messageSizeStr := debugAttrMap["message.size"] - messageSize, err := strconv.Atoi(messageSizeStr) - if err != nil { - panic(err) + messageSizeStr, ok := debugAttrMap["message.size"] + var messageSize int + var err error + if !ok || messageSizeStr == "" { + fmt.Println("Warning: message.size missing, defaulting to 1") + messageSize = 1 + } else { + messageSize, err = strconv.Atoi(messageSizeStr) + if err != nil { + fmt.Printf("Warning: invalid message.size %s, defaulting to 1\n", messageSizeStr) + messageSize = 1 + } } value := make([]float64, messageSize) pt := decryptor.DecryptNew(ct) diff --git a/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD b/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD index dce52b6ea7..ebe7f943b1 100644 --- a/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD +++ b/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD @@ -17,6 +17,7 @@ heir_lattigo_lib( "--mlir-to-ckks=ciphertext-degree=2048 \ encryption-technique-extended=true \ plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8f.mlir", diff --git a/tests/Examples/openfhe/bfv/debug_helper.cpp b/tests/Examples/openfhe/bfv/debug_helper.cpp index 54bea09083..42bb774127 100644 --- a/tests/Examples/openfhe/bfv/debug_helper.cpp +++ b/tests/Examples/openfhe/bfv/debug_helper.cpp @@ -133,3 +133,11 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, } #endif } + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap) { + if (!cts.empty()) { + __heir_debug(cc, sk, cts[0], debugAttrMap); + } +} diff --git a/tests/Examples/openfhe/bfv/debug_helper.h b/tests/Examples/openfhe/bfv/debug_helper.h index c245d71bd9..28f2d3ebfd 100644 --- a/tests/Examples/openfhe/bfv/debug_helper.h +++ b/tests/Examples/openfhe/bfv/debug_helper.h @@ -3,6 +3,7 @@ #include #include +#include #include "src/pke/include/openfhe.h" // from @openfhe @@ -12,5 +13,8 @@ using PrivateKeyT = lbcrypto::PrivateKey; void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap); #endif // TESTS_EXAMPLES_OPENFHE_BFV_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD b/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD index 36c6fe8aa6..42848f60e2 100644 --- a/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD +++ b/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD @@ -8,6 +8,7 @@ openfhe_end_to_end_test( heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=bfv", "--mlir-to-bfv=ciphertext-degree=8192 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], heir_translate_flags = [ @@ -20,13 +21,17 @@ openfhe_end_to_end_test( openfhe_end_to_end_test( name = "dot_product_8_default_debug_test", - generate_debug_helper = True, generated_lib_header = "dot_product_8_default_debug_lib.h", heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=bfv", "--mlir-to-bfv=ciphertext-degree=8192 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], + heir_translate_flags = [ + "--openfhe-debug-helper-include-path=tests/Examples/openfhe/bfv/debug_helper.h", + ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", test_src = "dot_product_8_default_debug_test.cpp", + deps = ["@heir//tests/Examples/openfhe/bfv:debug_helper"], ) diff --git a/tests/Examples/openfhe/bgv/debug_helper.cpp b/tests/Examples/openfhe/bgv/debug_helper.cpp index 5cb1109141..0ff8cbccc7 100644 --- a/tests/Examples/openfhe/bgv/debug_helper.cpp +++ b/tests/Examples/openfhe/bgv/debug_helper.cpp @@ -107,3 +107,11 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, } #endif } + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap) { + if (!cts.empty()) { + __heir_debug(cc, sk, cts[0], debugAttrMap); + } +} diff --git a/tests/Examples/openfhe/bgv/debug_helper.h b/tests/Examples/openfhe/bgv/debug_helper.h index 027e5e39d8..4466fe2aaf 100644 --- a/tests/Examples/openfhe/bgv/debug_helper.h +++ b/tests/Examples/openfhe/bgv/debug_helper.h @@ -3,6 +3,7 @@ #include #include +#include #include "src/pke/include/openfhe.h" // from @openfhe @@ -12,5 +13,8 @@ using PrivateKeyT = lbcrypto::PrivateKey; void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap); #endif // TESTS_EXAMPLES_OPENFHE_BGV_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD b/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD index fe78c46ad4..3f9dffceff 100644 --- a/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD +++ b/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD @@ -8,6 +8,7 @@ openfhe_end_to_end_test( heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=bgv", "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], heir_translate_flags = [ @@ -20,13 +21,17 @@ openfhe_end_to_end_test( openfhe_end_to_end_test( name = "dot_product_8_default_debug_test", - generate_debug_helper = True, generated_lib_header = "dot_product_8_default_debug_lib.h", heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=bgv", - "--mlir-to-bfv=ciphertext-degree=8192 annotate-noise-bound=true", + "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], + heir_translate_flags = [ + "--openfhe-debug-helper-include-path=tests/Examples/openfhe/bgv/debug_helper.h", + ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", test_src = "dot_product_8_default_debug_test.cpp", + deps = ["@heir//tests/Examples/openfhe/bgv:debug_helper"], ) diff --git a/tests/Examples/openfhe/ckks/debug_helper.cpp b/tests/Examples/openfhe/ckks/debug_helper.cpp index 42eb6b2663..cacfd15a8e 100644 --- a/tests/Examples/openfhe/ckks/debug_helper.cpp +++ b/tests/Examples/openfhe/ckks/debug_helper.cpp @@ -133,3 +133,11 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, #endif #endif } + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap) { + if (!cts.empty()) { + __heir_debug(cc, sk, cts[0], debugAttrMap); + } +} diff --git a/tests/Examples/openfhe/ckks/debug_helper.h b/tests/Examples/openfhe/ckks/debug_helper.h index ecc42e3ce6..7fe13e1c5c 100644 --- a/tests/Examples/openfhe/ckks/debug_helper.h +++ b/tests/Examples/openfhe/ckks/debug_helper.h @@ -3,6 +3,7 @@ #include #include +#include #include "src/pke/include/openfhe.h" // from @openfhe @@ -12,5 +13,8 @@ using PrivateKeyT = lbcrypto::PrivateKey; void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap); #endif // TESTS_EXAMPLES_OPENFHE_CKKS_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD b/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD index 1cb2a7c919..c6f8530521 100644 --- a/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD +++ b/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD @@ -12,6 +12,7 @@ openfhe_end_to_end_test( "--annotate-module=backend=openfhe scheme=ckks", "--mlir-to-ckks=ciphertext-degree=8 \ plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], heir_translate_flags = [ @@ -27,14 +28,18 @@ openfhe_end_to_end_test( data = [ "@heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log", ], - generate_debug_helper = True, generated_lib_header = "dot_product_8f_default_debug_lib.h", heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=ckks", "--mlir-to-ckks=ciphertext-degree=8 \ plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], + heir_translate_flags = [ + "--openfhe-debug-helper-include-path=tests/Examples/openfhe/ckks/debug_helper.h", + ], mlir_src = "@heir//tests/Examples/common:dot_product_8f.mlir", test_src = "dot_product_8f_default_debug_test.cpp", + deps = ["@heir//tests/Examples/openfhe/ckks:debug_helper"], ) diff --git a/tests/Examples/openfhe/ckks/loop_support/BUILD b/tests/Examples/openfhe/ckks/loop_support/BUILD index 348360829e..ab0eb9e5ca 100644 --- a/tests/Examples/openfhe/ckks/loop_support/BUILD +++ b/tests/Examples/openfhe/ckks/loop_support/BUILD @@ -9,6 +9,7 @@ openfhe_end_to_end_test( heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=ckks", "--mlir-to-ckks=ciphertext-degree=8 scaling-mod-bits=55 first-mod-bits=60 level-budget=3 modulus-switch-after-mul=true experimental-disable-loop-unroll=true", + "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-openfhe=insert-debug-handler-calls=true", ], heir_translate_flags = [ diff --git a/tests/Examples/openfhe/ckks/loop_support/debug_helper.cpp b/tests/Examples/openfhe/ckks/loop_support/debug_helper.cpp index ef4ddc9557..dd8a0df856 100644 --- a/tests/Examples/openfhe/ckks/loop_support/debug_helper.cpp +++ b/tests/Examples/openfhe/ckks/loop_support/debug_helper.cpp @@ -41,3 +41,13 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, std::cout << "]\n"; std::cout << " Scale: " << log2(ct->GetScalingFactor()) << std::endl; } + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap) { + std::cout << "Vector of Ciphertexts (size " << cts.size() << ")" << std::endl; + for (size_t i = 0; i < cts.size(); ++i) { + std::cout << "Element " << i << ":" << std::endl; + __heir_debug(cc, sk, cts[i], debugAttrMap); + } +} diff --git a/tests/Examples/openfhe/ckks/loop_support/debug_helper.h b/tests/Examples/openfhe/ckks/loop_support/debug_helper.h index db22d593fb..c09d735091 100644 --- a/tests/Examples/openfhe/ckks/loop_support/debug_helper.h +++ b/tests/Examples/openfhe/ckks/loop_support/debug_helper.h @@ -13,4 +13,8 @@ using PrivateKeyT = lbcrypto::PrivateKey; void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector cts, + const std::map& debugAttrMap); + #endif // TESTS_EXAMPLES_OPENFHE_CKKS_LOOP_SUPPORT_DEBUG_HELPER_H_