diff --git a/lib/Conversions/CheddarToEmitC/BUILD b/lib/Conversions/CheddarToEmitC/BUILD new file mode 100644 index 0000000000..0a1df5a402 --- /dev/null +++ b/lib/Conversions/CheddarToEmitC/BUILD @@ -0,0 +1,32 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "CheddarToEmitC", + srcs = ["CheddarToEmitC.cpp"], + hdrs = ["CheddarToEmitC.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:EmitCDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +add_heir_transforms( + header_filename = "CheddarToEmitC.h.inc", + pass_name = "CheddarToEmitC", + td_file = "CheddarToEmitC.td", +) diff --git a/lib/Conversions/CheddarToEmitC/CheddarToEmitC.cpp b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.cpp new file mode 100644 index 0000000000..7b8841f838 --- /dev/null +++ b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.cpp @@ -0,0 +1,996 @@ +#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h" + +#include + +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Utils/ConversionUtils.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/EmitC/IR/EmitC.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include 
"mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir { + +#define GEN_PASS_DEF_CHEDDARTOEMITC +#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h.inc" + +namespace { + +using ::mlir::emitc::CallOpaqueOp; +using ::mlir::emitc::LValueType; +using ::mlir::emitc::OpaqueAttr; +using ::mlir::emitc::OpaqueType; +using ::mlir::emitc::PointerType; +using ::mlir::emitc::VariableOp; +using ::mlir::emitc::VerbatimOp; + +// Returns true if `t` is an opaque type whose textual name is one of the +// move-only CHEDDAR payload types. Used by the type converter (to decide +// memref-of-cheddar -> emitc.array vs memref-of-emitc.opaque), the +// DPS-lifting post-pass, and the load-elision pass. +// +// EvaluationKey/Plaintext/Ciphertext/Constant are the move-only CHEDDAR +// payload types that the out-param pattern produces (as `T out;` locals and, +// at function boundaries, as `const T&` inputs / `T&` out-params). EvkMap and +// Encoder are also non-copy-assignable at the C++ level, but they are never +// produced as out-params -- they only ever appear as inputs -- so they are +// handled separately by `isConstRefBoundaryOpaque` (const-ref arg tightening) +// rather than here. 
+bool isMoveOnlyOpaque(Type t, StringRef& nameOut) { + auto opaqueT = dyn_cast(t); + if (!opaqueT) return false; + StringRef name = opaqueT.getValue(); + if (name == "Ciphertext" || name == "Plaintext" || + name == "Constant" || name == "EvaluationKey") { + nameOut = name; + return true; + } + return false; +} + +// EvkMap is move-only (copy deleted, no move-assignment) and Encoder is a +// non-assignable view (it holds reference members). Both must be passed by +// `const T&` at a function boundary: a by-value parameter would force an +// unnecessary move/copy at the call site (and, for EvkMap, a copy is deleted). +// Unlike the payload types they are never returned via the out-param pattern, +// so this predicate is consulted only for input tightening. +bool isConstRefBoundaryOpaque(Type t, StringRef& nameOut) { + auto opaqueT = dyn_cast(t); + if (!opaqueT) return false; + StringRef name = opaqueT.getValue(); + if (name == "EvkMap" || name == "Encoder") { + nameOut = name; + return true; + } + return false; +} + +// Returns true if `t` is an emitc.array whose element type is a move-only +// opaque. Used by the DPS-lift post-pass to recognise function returns that +// must be lifted to `std::array&` out-params (a C array cannot be +// returned by value in C++). +bool isMoveOnlyArray(Type t, StringRef& eltNameOut, int64_t& sizeOut) { + auto arrayT = dyn_cast(t); + if (!arrayT) return false; + if (arrayT.getShape().size() != 1) return false; + StringRef name; + if (!isMoveOnlyOpaque(arrayT.getElementType(), name)) return false; + eltNameOut = name; + sizeOut = arrayT.getShape()[0]; + return true; +} + +// Cheddar types map to the textual C++ type that the CHEDDAR library uses. +// Move-only types (Ciphertext/Plaintext/Constant) are *also* mapped to +// `opaque`; local variables wrap them in `lvalue` only at the point of +// declaration (via `emitc.variable`). 
+class TypeConverterImpl : public TypeConverter { + public: + explicit TypeConverterImpl(MLIRContext* ctx) { + addConversion([](Type t) { return t; }); + addConversion([ctx](cheddar::ParameterType) -> Type { + return OpaqueType::get(ctx, "Parameter"); + }); + addConversion([ctx](cheddar::ContextType) -> Type { + return PointerType::get(ctx, OpaqueType::get(ctx, "Context")); + }); + addConversion([ctx](cheddar::UserInterfaceType) -> Type { + return PointerType::get(ctx, OpaqueType::get(ctx, "UserInterface")); + }); + addConversion([ctx](cheddar::EncoderType) -> Type { + return OpaqueType::get(ctx, "Encoder"); + }); + addConversion([ctx](cheddar::EvkMapType) -> Type { + return OpaqueType::get(ctx, "EvkMap"); + }); + addConversion([ctx](cheddar::EvalKeyType) -> Type { + return OpaqueType::get(ctx, "EvaluationKey"); + }); + addConversion([ctx](cheddar::CiphertextType) -> Type { + return OpaqueType::get(ctx, "Ciphertext"); + }); + addConversion([ctx](cheddar::PlaintextType) -> Type { + return OpaqueType::get(ctx, "Plaintext"); + }); + addConversion([ctx](cheddar::ConstantType) -> Type { + return OpaqueType::get(ctx, "Constant"); + }); + // Tensor messages used by encode/decode become std::vector; the + // bitwidth choice is library-side and matches CHEDDAR's host-side API. + addConversion([ctx](RankedTensorType) -> Type { + return OpaqueType::get(ctx, "std::vector"); + }); + // memref<...x!cheddar.*> shows up after bufferization of looped kernels. + // For move-only element types we convert directly to emitc.array (a + // fixed-size C array) so the emitted C++ is `T name[N];` -- subscripting + // returns a reference and avoids any copy of the move-only element. + // + // For non-move-only element types we recursively rebuild a memref so + // downstream memref-to-emitc can lower it the standard way. 
+ addConversion([this](MemRefType type) -> Type { + Type converted = this->convertType(type.getElementType()); + if (!converted) return Type(); + StringRef name; + if (isMoveOnlyOpaque(converted, name)) { + // Only static, rank-1 shapes are supported: emitc.array can't model + // dynamic extents, and the DPS boundary lift represents these as 1-D + // `std::array` out-params/args. A dynamic-shape move-only memref + // would otherwise fall through to a memref<...x!emitc.opaque> that + // stock MemRefToEmitC lowers with descriptors/copies (invalid for + // move-only payloads), and a higher-rank one to a multi-dim emitc.array + // the boundary lift can't represent. Return a null type so the + // conversion fails loudly rather than emitting broken C++. + if (!type.hasStaticShape() || type.getRank() != 1) return Type(); + return emitc::ArrayType::get(type.getShape(), converted); + } + return MemRefType::get(type.getShape(), converted, type.getLayout(), + type.getMemorySpace()); + }); + // `index` is what `memref.load`/`memref.store`/`tensor.extract` use as + // their index operand; SCFToEmitC + ArithToEmitC convert these to + // `emitc.size_t` and leave a `builtin.unrealized_conversion_cast` at the + // boundary. Hooking `index -> emitc.size_t` here lets our memref-op + // patterns consume the converted index directly via the adaptor and the + // dialect-conversion framework reconciles the cast away. + addConversion( + [ctx](IndexType) -> Type { return emitc::SizeTType::get(ctx); }); + } +}; + +// emitc::OpaqueType doesn't implement MemRefElementTypeInterface upstream, +// which would block our type converter from forming +// `memref>>` (the natural converted form of +// `memref`). The interface is marker-only, so an empty external +// model suffices. 
+struct EmitCOpaqueAsMemRefElement + : public mlir::MemRefElementTypeInterface::ExternalModel< + EmitCOpaqueAsMemRefElement, mlir::emitc::OpaqueType> {}; + +// Generic conversion pattern for memref ops carrying non-move-only element +// types: rebuilds the op with operand/result types converted by `tc`. +// Move-only memref ops are handled by the more specific patterns below. +struct ConvertGenericMemRefOp : public ConversionPattern { + ConvertGenericMemRefOp(const TypeConverter& tc, MLIRContext* ctx) + : ConversionPattern(tc, MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + if (op->getDialect()->getNamespace() != "memref") return failure(); + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + OperationState state(op->getLoc(), op->getName(), operands, newResultTypes, + op->getAttrs(), op->getSuccessors()); + for ([[maybe_unused]] auto& r : op->getRegions()) { + return failure(); + } + Operation* newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +// memref.alloc producing a memref of move-only cheddar type -> emitc.variable +// of emitc.array type. Emits `T name[N];` -- N default-constructed elements, +// move-only-safe, RAII-cleaned at scope exit. 
+struct ConvertMemRefAllocMoveOnly + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mlir::memref::AllocOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + Type converted = getTypeConverter()->convertType(op.getType()); + auto arrayTy = dyn_cast_or_null(converted); + if (!arrayTy) return failure(); + auto var = emitc::VariableOp::create( + rewriter, op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")); + rewriter.replaceOp(op, var.getResult()); + return success(); + } +}; + +// memref.load on an emitc.array of move-only opaque -> emitc.subscript + +// emitc.load. The subscript is alwaysInline=true so emission inlines `m[i]` +// at the use site, and the load is then erased by the move-only load-elision +// post-pass for consumers that accept lvalues directly (cheddar verbatims, +// memref.store-as-verbatim, return). +struct ConvertMemRefLoadMoveOnly + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mlir::memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value buf = adaptor.getMemref(); + auto arrayTy = dyn_cast(buf.getType()); + if (!arrayTy) return failure(); + StringRef name; + if (!isMoveOnlyOpaque(arrayTy.getElementType(), name)) return failure(); + auto lvalT = emitc::LValueType::get(arrayTy.getElementType()); + auto sub = emitc::SubscriptOp::create(rewriter, op.getLoc(), lvalT, buf, + adaptor.getIndices()); + rewriter.replaceOpWithNewOp(op, arrayTy.getElementType(), + sub.getResult()); + return success(); + } +}; + +// memref.store of a move-only value into an emitc.array -> subscript + a +// verbatim that emits `arr[i] = std::move(src);`. 
The load-elision post-pass +// will (correctly) substitute the load on `src` with the underlying lvalue, +// at which point the verbatim's two operands are both lvalues and the +// emission prints two variable names. +struct ConvertMemRefStoreMoveOnly + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mlir::memref::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value buf = adaptor.getMemref(); + auto arrayTy = dyn_cast(buf.getType()); + if (!arrayTy) return failure(); + StringRef name; + if (!isMoveOnlyOpaque(arrayTy.getElementType(), name)) return failure(); + auto lvalT = emitc::LValueType::get(arrayTy.getElementType()); + auto sub = emitc::SubscriptOp::create(rewriter, op.getLoc(), lvalT, buf, + adaptor.getIndices()); + emitc::VerbatimOp::create(rewriter, op.getLoc(), "{} = std::move({});", + ValueRange{sub.getResult(), adaptor.getValue()}); + rewriter.eraseOp(op); + return success(); + } +}; + +// memref.dealloc on a move-only emitc.array is a no-op at emission (the +// scope-bound emitc.variable's destructor handles cleanup). Erase the op. +struct EraseMemRefDeallocMoveOnly + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mlir::memref::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value buf = adaptor.getMemref(); + auto arrayTy = dyn_cast(buf.getType()); + if (!arrayTy) return failure(); + StringRef name; + if (!isMoveOnlyOpaque(arrayTy.getElementType(), name)) return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +// Declare a fresh local lvalue variable of value-type `t`. Patterns then emit +// a verbatim that initializes it (out-param call or assignment from a getter) +// and finish with `loadAfter` to obtain the loaded value to feed into +// `replaceOp`, matching the type converter's value-form output. 
+Value declareLocal(OpBuilder& b, Location loc, Type t) { + return VariableOp::create(b, loc, LValueType::get(t), + OpaqueAttr::get(b.getContext(), "")); +} + +Value loadAfter(OpBuilder& b, Location loc, Type t, Value lvalue) { + return mlir::emitc::LoadOp::create(b, loc, t, lvalue); +} + +// Emit `receiver->method(out, args..., extra);`. +void emitOutParamCall(OpBuilder& b, Location loc, Value receiver, + StringRef method, Value out, ValueRange args, + StringRef extra = "") { + std::string fmt; + llvm::raw_string_ostream os(fmt); + os << "{}->" << method << "({}"; + for (size_t i = 0; i < args.size(); ++i) os << ", {}"; + if (!extra.empty()) os << ", " << extra; + os << ");"; + SmallVector v{receiver, out}; + v.append(args.begin(), args.end()); + VerbatimOp::create(b, loc, os.str(), v); +} + +std::string intLit(IntegerAttr a) { return std::to_string(a.getInt()); } + +std::string floatLit(FloatAttr a) { + return llvm::formatv("{0}", a.getValueAsDouble()).str(); +} + +// `{}` is the operand placeholder in `emitc.verbatim` format strings. A +// literal `{` must be written `{{`; a literal `}` is emitted as-is (emitc does +// NOT collapse `}}` to `}`), so the closing brace of an initializer list stays +// single -- doubling it would emit a stray `}`. +std::string i32ArrayLit(DenseI32ArrayAttr a) { + std::string s = "{{"; + for (size_t i = 0; i < a.size(); ++i) { + if (i > 0) s += ", "; + s += std::to_string(a[i]); + } + return s + "}"; +} + +std::string floatArrayLit(ArrayAttr a) { + std::string s = "{{"; + for (size_t i = 0; i < a.size(); ++i) { + if (i > 0) s += ", "; + s += + llvm::formatv("{0:f1}", cast(a[i]).getValueAsDouble()).str(); + } + return s + "}"; +} + +// Generic out-param pattern: first operand is the receiver, remaining operands +// are inputs, the op produces one cheddar value. 
+template +struct OutParamPattern : public OpConversionPattern { + OutParamPattern(const TypeConverter& tc, MLIRContext* ctx, StringRef method) + : OpConversionPattern(tc, ctx), method(method.str()) {} + + LogicalResult matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + auto operands = adaptor.getOperands(); + emitOutParamCall(rewriter, op.getLoc(), operands[0], method, out, + operands.drop_front()); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } + + std::string method; +}; + +// CreateContext: static factory, rendered as `T x = T::Create(args);`. +struct ConvertCreateContext + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::CreateContextOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + auto call = CallOpaqueOp::create( + rewriter, op.getLoc(), TypeRange{t}, + rewriter.getStringAttr("Context::Create"), + ValueRange{adaptor.getParams()}, + /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +// The getter-style setup ops -- get_evk_map, get_mult_key, get_encoder, +// create_user_interface -- are NOT supported by this lowering. Their CHEDDAR +// C++ counterparts hand back a `const&` to a move-only / non-assignable value +// (EvkMap and EvaluationKey are move-only; Encoder is a view with reference +// members; UserInterface is move-only), so the value-materialising shape a +// naive lowering emits -- `T tmp; tmp = recv->Get();` -- cannot compile. +// Supporting them needs inline-at-use emission (the way HRot already inlines +// `ui->GetRotationKey(d)`) or emitc.expression. 
Real kernels avoid these +// entirely: they take keys/maps/encoders as function arguments or look keys +// up inline. +// +// Reject them with a clear diagnostic in a pre-pass walk rather than from a +// conversion pattern: the dialect-conversion framework discards diagnostics +// emitted by a pattern that returns failure, so a pattern-based error would be +// swallowed in favour of a generic "failed to legalize". Returns true (and +// emits an error on each) if any unsupported getter is present. +bool diagnoseUnsupportedGetters(Operation* root) { + bool found = false; + root->walk([&](Operation* op) { + if (isa(op)) { + op->emitError() + << "cheddar-to-emitc: lowering of '" << op->getName().getStringRef() + << "' is not supported: it returns a const reference to a " + "move-only/non-assignable value, which cannot be materialised " + "into a local without a copy. Pass the key/map/encoder as a " + "function argument, or look it up inline at the use site."; + found = true; + } + }); + return found; +} + +struct ConvertPrepareRotKey + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::PrepareRotKeyOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + std::string extra = + intLit(op.getDistanceAttr()) + ", " + intLit(op.getMaxLevelAttr()); + VerbatimOp::create(rewriter, op.getLoc(), + "{}->PrepareRotationKey(" + extra + ");", + ValueRange{adaptor.getUi()}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertEncode : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::EncodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + std::string extra = + intLit(op.getLevelAttr()) + ", " + floatLit(op.getScaleAttr()); + 
emitOutParamCall(rewriter, op.getLoc(), adaptor.getEncoder(), "Encode", out, + ValueRange{adaptor.getMessage()}, extra); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertEncodeConstant + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::EncodeConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + std::string extra = + intLit(op.getLevelAttr()) + ", " + floatLit(op.getScaleAttr()); + emitOutParamCall(rewriter, op.getLoc(), adaptor.getEncoder(), + "EncodeConstant", out, ValueRange{adaptor.getValue()}, + extra); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertLevelDown : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::LevelDownOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + emitOutParamCall(rewriter, op.getLoc(), adaptor.getCtx(), "LevelDown", out, + ValueRange{adaptor.getInput()}, + intLit(op.getTargetLevelAttr())); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertHMult : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::HMultOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + StringRef extra = op.getRescale() ? 
"true" : "false"; + emitOutParamCall( + rewriter, op.getLoc(), adaptor.getCtx(), "HMult", out, + ValueRange{adaptor.getLhs(), adaptor.getRhs(), adaptor.getMultKey()}, + extra); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +// HRot/HRotAdd/HConj/HConjAdd: the dialect no longer carries an explicit +// key SSA operand. At emission time, we look up `ui->GetRotationKey(d)` (or +// `ui->GetConjugationKey()`) inline; the UserInterface is taken from the +// enclosing function's argument list via `getContextualArgFromFunc`. +// +// `getContextualArgFromFunc` walks the original (pre-conversion) func block +// and looks for an arg of the *converted* type; the pass runs as a partial +// dialect conversion so the func signature has already been converted by the +// structural patterns when these matchers fire. + +Value findUi(Operation* op, const TypeConverter& tc) { + Type uiType = + tc.convertType(cheddar::UserInterfaceType::get(op->getContext())); + auto r = getContextualArgFromFunc(op, uiType); + if (failed(r)) return Value{}; + return r.value(); +} + +struct ConvertHRot : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::HRotOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value ui = findUi(op, *this->typeConverter); + if (!ui) + return op.emitOpError("enclosing function is missing UserInterface arg"); + + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + + if (auto sd = op.getStaticDistanceAttr()) { + std::string d = intLit(sd); + VerbatimOp::create( + rewriter, op.getLoc(), + "{}->HRot({}, {}, {}->GetRotationKey(" + d + "), " + d + ");", + ValueRange{adaptor.getCtx(), out, adaptor.getInput(), ui}); + } else { + Value dyn = adaptor.getDynamicDistance(); + VerbatimOp::create( + rewriter, op.getLoc(), + "{}->HRot({}, {}, 
{}->GetRotationKey({}), {});", + ValueRange{adaptor.getCtx(), out, adaptor.getInput(), ui, dyn, dyn}); + } + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertHRotAdd : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::HRotAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value ui = findUi(op, *this->typeConverter); + if (!ui) + return op.emitOpError("enclosing function is missing UserInterface arg"); + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + std::string d = intLit(op.getDistanceAttr()); + VerbatimOp::create( + rewriter, op.getLoc(), + "{}->HRotAdd({}, {}, {}, {}->GetRotationKey(" + d + "), " + d + ");", + ValueRange{adaptor.getCtx(), out, adaptor.getInput(), + adaptor.getAddend(), ui}); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertHConj : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::HConjOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Value ui = findUi(op, *this->typeConverter); + if (!ui) + return op.emitOpError("enclosing function is missing UserInterface arg"); + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + VerbatimOp::create( + rewriter, op.getLoc(), "{}->HConj({}, {}, {}->GetConjugationKey());", + ValueRange{adaptor.getCtx(), out, adaptor.getInput(), ui}); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertHConjAdd : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::HConjAddOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter& rewriter) const override { + Value ui = findUi(op, *this->typeConverter); + if (!ui) + return op.emitOpError("enclosing function is missing UserInterface arg"); + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + VerbatimOp::create(rewriter, op.getLoc(), + "{}->HConjAdd({}, {}, {}, {}->GetConjugationKey());", + ValueRange{adaptor.getCtx(), out, adaptor.getInput(), + adaptor.getAddend(), ui}); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +// `ctx->MadUnsafe(acc, in, c);` is an in-place mutation: the SSA result is +// the same value as the input accumulator after the call. +struct ConvertMadUnsafe : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::MadUnsafeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + VerbatimOp::create(rewriter, op.getLoc(), "{}->MadUnsafe({}, {}, {});", + ValueRange{adaptor.getCtx(), adaptor.getAccumulator(), + adaptor.getInput(), adaptor.getConstant()}); + rewriter.replaceOp(op, adaptor.getAccumulator()); + return success(); + } +}; + +// `ctx->Boot(res, input, evk_map);`. The CHEDDAR runtime resolves whether +// `ctx` is a regular `Context` or a `BootContext`; the dialect carries only +// a single Context type today. 
+struct ConvertBoot : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::BootOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + emitOutParamCall(rewriter, op.getLoc(), adaptor.getCtx(), "Boot", out, + ValueRange{adaptor.getInput(), adaptor.getEvkMap()}); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertLinearTransform + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::LinearTransformOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + std::string extra = i32ArrayLit(op.getDiagonalIndicesAttr()) + ", " + + intLit(op.getLevelAttr()) + ", " + + intLit(op.getLogBabyStepGiantStepRatioAttr()); + emitOutParamCall(rewriter, op.getLoc(), adaptor.getCtx(), "LinearTransform", + out, + ValueRange{adaptor.getInput(), adaptor.getEvkMap(), + adaptor.getDiagonals()}, + extra); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +struct ConvertEvalPoly : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + cheddar::EvalPolyOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Type t = this->typeConverter->convertType(op.getResult().getType()); + Value out = declareLocal(rewriter, op.getLoc(), t); + std::string extra = floatArrayLit(op.getCoefficientsAttr()) + ", " + + intLit(op.getLevelAttr()); + emitOutParamCall(rewriter, op.getLoc(), adaptor.getCtx(), "EvalPoly", out, + ValueRange{adaptor.getInput(), adaptor.getEvkMap()}, + 
extra); + rewriter.replaceOp(op, loadAfter(rewriter, op.getLoc(), t, out)); + return success(); + } +}; + +// Pass driver: lowers cheddar ops to EmitC with a partial dialect conversion, +// then post-processes the converted IR (load elision for move-only opaques and +// destination-passing-style lifting of func.func signatures). The +// post-processing only runs when the conversion succeeded. +struct CheddarToEmitCPass + : public impl::CheddarToEmitCBase { + using CheddarToEmitCBase::CheddarToEmitCBase; + + void runOnOperation() override { + auto* ctx = &getContext(); + + // Reject unsupported getter-style setup ops before conversion (see + // diagnoseUnsupportedGetters). + if (diagnoseUnsupportedGetters(getOperation())) { + signalPassFailure(); + return; + } + + TypeConverterImpl tc(ctx); + + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect<::mlir::emitc::EmitCDialect>(); + target.addLegalDialect<::mlir::func::FuncDialect>(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return tc.isSignatureLegal(op.getFunctionType()) && + tc.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](Operation* op) { return tc.isLegal(op); }); + // memref ops carrying Cheddar element types after bufferization of looped + // kernels are legal only once their element types have been converted + // (Cheddar element type -> EmitC opaque). The recursive MemRefType + // converter in TypeConverterImpl plus the structural conversion patterns + // do the rewrite. + target.addDynamicallyLegalDialect<::mlir::memref::MemRefDialect>( + [&](Operation* op) { return tc.isLegal(op); }); + + RewritePatternSet patterns(ctx); + addStructuralConversionPatterns(tc, patterns, target); + // Memref ops carrying cheddar element types: move-only-aware patterns + // emit emitc.variable/subscript/verbatim directly (handling C++ move + // semantics correctly). The generic fallback rebuilds non-move-only + // memref ops with converted types for downstream memref-to-emitc to + // lower the standard way. + patterns.add(tc, + ctx); + patterns.add(tc, ctx); + + // get_evk_map / get_mult_key / get_encoder / create_user_interface are + // rejected up front by diagnoseUnsupportedGetters, so no patterns are + // registered for them here. 
+ patterns + .add(tc, ctx); + + // The remaining ops follow the uniform out-param pattern; first operand + // is the receiver, remaining operands are inputs. + patterns.add>(tc, ctx, "Add"); + patterns.add>(tc, ctx, "Sub"); + patterns.add>(tc, ctx, "Mult"); + // CHEDDAR's Context uses C++ overloading for ct+pt / ct+const variants: + // `void Add(Ct&, const Ct&, const Pt&)`, `void Add(Ct&, const Ct&, const + // Const&)`, etc. No separate `AddPlain`/`AddConst` methods exist on + // Context, so dispatch the dialect's `*_plain` / `*_const` ops to the + // base name and let the C++ compiler pick the overload by arg type. + patterns.add>(tc, ctx, "Add"); + patterns.add>(tc, ctx, "Sub"); + patterns.add>(tc, ctx, "Mult"); + patterns.add>(tc, ctx, "Add"); + patterns.add>(tc, ctx, "Mult"); + patterns.add>(tc, ctx, "Neg"); + patterns.add>(tc, ctx, "Rescale"); + patterns.add>(tc, ctx, + "Relinearize"); + patterns.add>( + tc, ctx, "RelinearizeRescale"); + patterns.add>(tc, ctx, "Decode"); + patterns.add>(tc, ctx, "Encrypt"); + patterns.add>(tc, ctx, "Decrypt"); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + // The post-passes below assume successfully converted IR; do not run + // them after a failed conversion. + return; + } + + // CHEDDAR's Ciphertext/Plaintext/Constant are move-only in C++. The + // structural conversion above leaves func.func ops with `T` (by-value) + // arg and result types, which the C++ emitter renders as + // T add_kernel(..., T a, T b) { T r; ctx->Add(r, a, b); T tmp = r; + // return tmp; } + // — both the by-value parameters and the `T tmp = r;` copy-init at the + // return rely on a copy ctor that doesn't exist. 
+ // + // Lift each such func to destination-passing style at the EmitC level: + // * move-only arg types become `const T&`, + // * move-only result types are dropped and re-appended as trailing + // `T&` out-params, + // * each `func.return %r` becomes `out = std::move(src);` (using the + // load's lvalue source when possible to avoid materialising the + // `T tmp = r;` copy) followed by a bare `func.return`. + // + // Local variables and intermediate values inside the body stay as plain + // `T` lvalues — those go through the move-assignment in `ctx->Op(out, + // ...)` which is valid. + // The patterns above use the `T tmp; ctx->Op(tmp, ...); emitc.load(tmp)` + // shape, where the `emitc.load` is intended to materialise a value-typed + // SSA def the next pattern can consume. In C++ this renders as + // `T downstream = tmp;` -- copy-init, which doesn't compile for move-only + // types. Walk the IR and elide every load of a move-only opaque, + // replacing all uses of the load's result with the source lvalue + // directly. Downstream consumers are emitc.verbatim ops (and, after the + // DPS lift below, func.return rewrites), both of which only ever print + // the operand name -- so the resulting C++ ends up referencing the + // local variable by name, binding naturally to the `const T&` parameters + // of the receiver methods. + getOperation()->walk([](emitc::LoadOp load) { + StringRef name; + if (!isMoveOnlyOpaque(load.getType(), name)) return; + load.getResult().replaceAllUsesWith(load.getOperand()); + load.erase(); + }); + + auto& ctxRef = *ctx; + getOperation()->walk([&ctxRef](func::FuncOp op) { + if (op.isExternal()) return; + auto funcType = op.getFunctionType(); + Block& entry = op.getBody().front(); + + SmallVector returns; + op.walk([&](func::ReturnOp r) { returns.push_back(r); }); + + // A move-only payload result whose return operand is an entry block + // argument means the function hands back storage it was given: e.g. 
+ // `mad_unsafe` mutates its accumulator argument in place and returns it, + // or a passthrough returns an argument unchanged. Such an argument is + // the destination -- it must be a mutable `T&` (not `const T&`) so the + // in-place mutation binds, and it needs no separate out-param. Detect + // these first so Pass 1 tightens them correctly. Only treated as + // in-place when *every* return agrees, so a signature is never left + // inconsistent. + unsigned numResults = funcType.getNumResults(); + SmallVector resultIsInout(numResults, false); + SmallVector argIsInout(funcType.getNumInputs(), false); + for (unsigned i = 0; i < numResults; ++i) { + StringRef name; + if (!isMoveOnlyOpaque(funcType.getResult(i), name)) continue; + int argIdx = -1; + bool ok = !returns.empty(); + for (auto ret : returns) { + auto ba = dyn_cast(ret.getOperand(i)); + if (!ba || ba.getOwner() != &entry) { + ok = false; + break; + } + if (argIdx < 0) + argIdx = ba.getArgNumber(); + else if (argIdx != static_cast(ba.getArgNumber())) { + ok = false; + break; + } + } + if (ok && argIdx >= 0) { + resultIsInout[i] = true; + argIsInout[argIdx] = true; + } + } + + // Pass 1: tighten input types at the C++ boundary. Move-only payload + // scalars become `const T&`, except in-place/returned args which stay + // mutable `T&`. emitc.array args (from memref boundary + // lowering of read-only inputs) become `const std::array<T, N>&`. + // EvkMap/Encoder become `const T&`. emitc.subscript on the original + // emitc.array stays valid after the swap because subscript also accepts + // EmitC_OpaqueType as its base operand. + SmallVector newInputs(funcType.getInputs().begin(), + funcType.getInputs().end()); + bool inputsChanged = false; + for (size_t i = 0; i < newInputs.size(); ++i) { + StringRef name; + int64_t arraySize = 0; + if (isMoveOnlyOpaque(newInputs[i], name)) { + std::string typeName = argIsInout[i] ? 
(name + "&").str() + : ("const " + name + "&").str(); + newInputs[i] = emitc::OpaqueType::get(&ctxRef, typeName); + entry.getArgument(i).setType(newInputs[i]); + inputsChanged = true; + } else if (isMoveOnlyArray(newInputs[i], name, arraySize)) { + std::string typeName = ("const std::array<" + name + ", " + + std::to_string(arraySize) + ">&") + .str(); + newInputs[i] = emitc::OpaqueType::get(&ctxRef, typeName); + entry.getArgument(i).setType(newInputs[i]); + inputsChanged = true; + } else if (isConstRefBoundaryOpaque(newInputs[i], name)) { + newInputs[i] = + emitc::OpaqueType::get(&ctxRef, ("const " + name + "&").str()); + entry.getArgument(i).setType(newInputs[i]); + inputsChanged = true; + } + } + + // Pass 2: lift move-only results off the return type. In-place results + // are dropped entirely (the value already lives in the mutable arg from + // Pass 1). Other move-only scalars/arrays are re-appended as trailing + // `T&` / `std::array<T, N>&` out-params, since move-only values and C + // arrays can't be returned by value in C++. 
+ enum DpsKind { kInout, kScalar, kArray }; + SmallVector dpsResultIdxs; + SmallVector dpsKind; + SmallVector dpsOutParam; // null for kInout + SmallVector retainedResults; + SmallVector appendedInputs; + for (auto [i, t] : llvm::enumerate(funcType.getResults())) { + StringRef name; + int64_t arraySize = 0; + if (resultIsInout[i]) { + dpsResultIdxs.push_back(i); + dpsKind.push_back(kInout); + dpsOutParam.push_back(Value{}); + continue; + } + if (isMoveOnlyOpaque(t, name)) { + Type refT = emitc::OpaqueType::get(&ctxRef, (name + "&").str()); + appendedInputs.push_back(refT); + dpsResultIdxs.push_back(i); + dpsKind.push_back(kScalar); + dpsOutParam.push_back(entry.addArgument(refT, op.getLoc())); + continue; + } + if (isMoveOnlyArray(t, name, arraySize)) { + std::string typeName = + ("std::array<" + name + ", " + std::to_string(arraySize) + ">&") + .str(); + Type refT = emitc::OpaqueType::get(&ctxRef, typeName); + appendedInputs.push_back(refT); + dpsResultIdxs.push_back(i); + dpsKind.push_back(kArray); + dpsOutParam.push_back(entry.addArgument(refT, op.getLoc())); + continue; + } + retainedResults.push_back(t); + } + + if (!inputsChanged && dpsResultIdxs.empty()) return; + + SmallVector finalInputs; + finalInputs.append(newInputs.begin(), newInputs.end()); + finalInputs.append(appendedInputs.begin(), appendedInputs.end()); + op.setType(FunctionType::get(&ctxRef, finalInputs, retainedResults)); + + if (dpsResultIdxs.empty()) return; + + // Rewrite each func.return: move scalar/array results into their + // out-params, drop in-place results (already in the mutable arg), retain + // the rest. 
+ SmallVector loadsToMaybeErase; + for (auto ret : returns) { + OpBuilder b(ret); + SmallVector retained; + size_t cursor = 0; + for (auto [i, val] : llvm::enumerate(ret.getOperands())) { + if (cursor < dpsResultIdxs.size() && dpsResultIdxs[cursor] == i) { + switch (dpsKind[cursor]) { + case kInout: + // Value already lives in the mutable in-place argument; the + // return just drops it. + break; + case kArray: + // Bulk move from the local emitc.array into the std::array + // out-param via `std::move(begin, end, out.begin())`. + emitc::VerbatimOp::create( + b, ret.getLoc(), + "std::move(std::begin({}), std::end({}), {}.begin());", + ValueRange{val, val, dpsOutParam[cursor]}); + break; + case kScalar: { + // Prefer the load's source lvalue (skips a spurious copy); + // fall back to the value otherwise. + Value source = val; + if (auto loadOp = val.getDefiningOp()) { + source = loadOp.getOperand(); + loadsToMaybeErase.push_back(loadOp); + } + emitc::VerbatimOp::create( + b, ret.getLoc(), "{} = std::move({});", + ValueRange{dpsOutParam[cursor], source}); + break; + } + } + ++cursor; + } else { + retained.push_back(val); + } + } + func::ReturnOp::create(b, ret.getLoc(), retained); + ret.erase(); + } + for (auto load : loadsToMaybeErase) { + if (load.use_empty()) load.erase(); + } + }); + } +}; + +} // namespace + +void registerCheddarToEmitCExternalModels(DialectRegistry& registry) { + registry.addExtension(+[](MLIRContext* ctx, mlir::emitc::EmitCDialect*) { + mlir::emitc::OpaqueType::attachInterface(*ctx); + }); +} + +} // namespace mlir::heir diff --git a/lib/Conversions/CheddarToEmitC/CheddarToEmitC.h b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.h new file mode 100644 index 0000000000..970fb0adf4 --- /dev/null +++ b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.h @@ -0,0 +1,23 @@ +#ifndef LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_ +#define LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_ + +#include "mlir/include/mlir/IR/DialectRegistry.h" // from 
@llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir { + +// Attaches MemRefElementTypeInterface as an external (marker-only) model to +// emitc::OpaqueType. Needed so that the cheddar-to-emitc type converter can +// form a memref with an `!emitc.opaque<...>` element type as the converted +// form of a memref with a Cheddar element type after bufferization. Call once +// at tool startup. +void registerCheddarToEmitCExternalModels(DialectRegistry& registry); + +#define GEN_PASS_DECL +#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h.inc" + +} // namespace mlir::heir + +#endif // LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_ diff --git a/lib/Conversions/CheddarToEmitC/CheddarToEmitC.td b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.td new file mode 100644 index 0000000000..216f72b55d --- /dev/null +++ b/lib/Conversions/CheddarToEmitC/CheddarToEmitC.td @@ -0,0 +1,22 @@ +#ifndef LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_ +#define LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_ + +include "mlir/Pass/PassBase.td" + +def CheddarToEmitC : Pass<"cheddar-to-emitc"> { + let summary = "Lower the cheddar dialect to EmitC."; + + let description = [{ + Translates each cheddar op into an out-parameter-style `Context`/ + `UserInterface` method call expressed in the EmitC dialect. + `mlir-translate --mlir-to-cpp` renders the resulting IR as host-side + C++ against the CHEDDAR library API. 
+ }]; + + let dependentDialects = [ + "::mlir::emitc::EmitCDialect", + "::mlir::func::FuncDialect", + ]; +} + +#endif // LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.td b/lib/Dialect/Cheddar/IR/CheddarTypes.td index 5f8b0b795f..526a436ed2 100644 --- a/lib/Dialect/Cheddar/IR/CheddarTypes.td +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.td @@ -4,13 +4,19 @@ include "CheddarDialect.td" include "lib/Dialect/HEIRInterfaces.td" +include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpAsmInterface.td" // A base class for all types in this dialect +// +// All Cheddar types implement MemRefElementTypeInterface so they can appear as +// element types in `tensor` / `memref`, which is +// what the bufferization pipeline produces when lowering looped CKKS kernels. class Cheddar_Type traits = []> - : TypeDef { + : TypeDef { let mnemonic = typeMnemonic; let genMnemonicAlias = 1; diff --git a/tests/Conversions/CheddarToEmitC/BUILD b/tests/Conversions/CheddarToEmitC/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Conversions/CheddarToEmitC/boundary_tightening.mlir b/tests/Conversions/CheddarToEmitC/boundary_tightening.mlir new file mode 100644 index 0000000000..f4eccbd9b8 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/boundary_tightening.mlir @@ -0,0 +1,55 @@ +// RUN: heir-opt --cheddar-to-emitc --split-input-file %s | FileCheck %s + +// The mad_unsafe accumulator comes from a function argument; it is mutated in +// place by MadUnsafe and returned, so it must be lifted to a mutable +// `Ciphertext&` 
(the `<"Ciphertext` prefix anchors the match away from +// the `const Ciphertext` inputs) and the result dropped -- no out-param, no +// `std::move` at the return. +// CHECK: func.func @mad_arg +// CHECK-SAME: !emitc.opaque<"Ciphertext&"> +// CHECK-SAME: !emitc.opaque<"const Ciphertext&"> +// CHECK-SAME: !emitc.opaque<"const Constant&"> +// CHECK: emitc.verbatim "{}->MadUnsafe({}, {}, {});" +// CHECK-NOT: std::move +func.func @mad_arg(%ctx: !cheddar.context, %acc: !cheddar.ciphertext, + %in: !cheddar.ciphertext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + %r = cheddar.mad_unsafe %ctx, %acc, %in, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// ----- + +// Returning a move-only argument unchanged lifts it to an in-place +// `Ciphertext&` out-param with no result and no copy. +// CHECK: func.func @identity +// CHECK-SAME: !emitc.opaque<"Ciphertext&"> +// CHECK-NOT: std::move +func.func @identity(%ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + return %ct : !cheddar.ciphertext +} + +// ----- + +// An EvkMap argument is move-only, so it tightens to `const EvkMap&` +// rather than staying a by-value parameter. +// CHECK: func.func @boot +// CHECK-SAME: !emitc.opaque<"const EvkMap&"> +func.func @boot(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + %0 = cheddar.boot %ctx, %ct, %evk + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %0 : !cheddar.ciphertext +} + +// ----- + +// An Encoder argument (a non-assignable view) tightens to `const Encoder&`. 
+// CHECK: func.func @encoder_arg +// CHECK-SAME: !emitc.opaque<"const Encoder&"> +func.func @encoder_arg(%enc: !cheddar.encoder, %ct: !cheddar.ciphertext) + -> !cheddar.ciphertext { + return %ct : !cheddar.ciphertext +} diff --git a/tests/Conversions/CheddarToEmitC/cheddar_to_emitc.mlir b/tests/Conversions/CheddarToEmitC/cheddar_to_emitc.mlir new file mode 100644 index 0000000000..2c3f1bcf11 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/cheddar_to_emitc.mlir @@ -0,0 +1,261 @@ +// RUN: heir-opt --cheddar-to-emitc %s | FileCheck %s + +// CreateContext is a static factory. PrepareRotKey is a void method. The +// getter-style setup ops (create_user_interface / get_encoder / get_evk_map / +// get_mult_key) are unsupported and rejected; see unsupported_getters.mlir. + +// CHECK: func.func @create_context +// CHECK: emitc.call_opaque "Context::Create" +func.func @create_context(%params: !cheddar.parameter) -> !cheddar.context { + %ctx = cheddar.create_context %params : (!cheddar.parameter) -> !cheddar.context + return %ctx : !cheddar.context +} + +// CHECK: func.func @prepare_keys +func.func @prepare_keys(%ui: !cheddar.user_interface) { + // CHECK: emitc.verbatim "{}->PrepareRotationKey(3, 5);" args %arg0 + cheddar.prepare_rot_key %ui {distance = 3 : i64, maxLevel = 5 : i64} : (!cheddar.user_interface) -> () + return +} + +// Encode/encrypt/decode/decrypt: out-param method calls. The attribute +// values for level and scale are inlined into the verbatim format string. + +// CHECK: func.func @encode_chain +func.func @encode_chain(%enc: !cheddar.encoder, %msg: tensor<4xf64>, + %ui: !cheddar.user_interface) + -> !cheddar.ciphertext { + // Note: scale is the C++ `double` value used by CHEDDAR's Encoder + // (Δ = 2^36 = 68719476736). 
+ // CHECK: emitc.verbatim "{}->Encode({}, {}, 5, 68719476736.00);" + %pt = cheddar.encode %enc, %msg {level = 5 : i64, scale = 68719476736.0 : f64} + : (!cheddar.encoder, tensor<4xf64>) -> !cheddar.plaintext + // CHECK: emitc.verbatim "{}->Encrypt({}, {});" + %ct = cheddar.encrypt %ui, %pt + : (!cheddar.user_interface, !cheddar.plaintext) -> !cheddar.ciphertext + return %ct : !cheddar.ciphertext +} + +// CHECK: func.func @decode_chain +func.func @decode_chain(%enc: !cheddar.encoder, %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> tensor<4xf64> { + // CHECK: emitc.verbatim "{}->Decrypt({}, {});" + %pt = cheddar.decrypt %ui, %ct + : (!cheddar.user_interface, !cheddar.ciphertext) -> !cheddar.plaintext + // CHECK: emitc.verbatim "{}->Decode({}, {});" + %msg = cheddar.decode %enc, %pt + : (!cheddar.encoder, !cheddar.plaintext) -> tensor<4xf64> + return %msg : tensor<4xf64> +} + +// Intermediate shape: the conversion patterns first give each cheddar op an +// `emitc.variable` (the C++ out-param destination), a verbatim that writes +// into it, and an `emitc.load` that reads it back as a value for downstream +// consumers. The pass then removes that load for move-only types, so no +// `emitc.load` appears between the ops checked here. + +// Chaining: each cheddar op declares its own `emitc.variable` (the C++ +// out-param destination) and writes into it via verbatim. Downstream +// consumers reference that variable by name -- the conversion elides the +// `emitc.load` that would otherwise materialise `Ciphertext tmp = +// addv;` (copy-init of a move-only type). 
+// CHECK: func.func @arith_chain +func.func @arith_chain(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: %[[ADDV:.*]] = "emitc.variable" + // CHECK-NEXT: emitc.verbatim "{}->Add({}, {}, {});" args %arg0, %[[ADDV]], %arg1, %arg2 + %r = cheddar.add %ctx, %a, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + // CHECK-NEXT: %[[MULTV:.*]] = "emitc.variable" + // CHECK-NEXT: emitc.verbatim "{}->Mult({}, {}, {});" args %arg0, %[[MULTV]], %[[ADDV]], %arg2 + %s = cheddar.mult %ctx, %r, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %s : !cheddar.ciphertext +} + +// CHECK: func.func @ct_pt_ct_const +func.func @ct_pt_ct_const(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + // CHEDDAR's Context overloads Add/Sub/Mult on the second-operand type, so + // *_plain / *_const ops dispatch to the base name and rely on C++ to pick + // the right overload. 
+ // CHECK: emitc.verbatim "{}->Add({}, {}, {});" + %r1 = cheddar.add_plain %ctx, %ct, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->Sub({}, {}, {});" + %r2 = cheddar.sub_plain %ctx, %r1, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->Mult({}, {}, {});" + %r3 = cheddar.mult_plain %ctx, %r2, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->Add({}, {}, {});" + %r4 = cheddar.add_const %ctx, %r3, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->Mult({}, {}, {});" + %r5 = cheddar.mult_const %ctx, %r4, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %r5 : !cheddar.ciphertext +} + +// CHECK: func.func @unary +func.func @unary(%ctx: !cheddar.context, %ct: !cheddar.ciphertext) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->Neg({}, {});" + %n = cheddar.neg %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->Rescale({}, {});" + %r = cheddar.rescale %ctx, %n : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->LevelDown({}, {}, 2);" + %l = cheddar.level_down %ctx, %r {targetLevel = 2 : i64} + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %l : !cheddar.ciphertext +} + +// CHECK: func.func @relin +func.func @relin(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %k: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->Relinearize({}, {}, {});" + %r1 = cheddar.relinearize %ctx, %ct, %k + : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + // CHECK: emitc.verbatim "{}->RelinearizeRescale({}, {}, {});" + %r2 = cheddar.relinearize_rescale %ctx, 
%r1, %k + : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %r2 : !cheddar.ciphertext +} + +// HMult: explicit `rescale` flag is rendered as the trailing literal `true` +// or `false`. + +// CHECK: func.func @hmult_with_rescale +func.func @hmult_with_rescale(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext, %k: !cheddar.eval_key) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HMult({}, {}, {}, {}, true);" + %r = cheddar.hmult %ctx, %a, %b, %k {rescale = true} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @hmult_no_rescale +func.func @hmult_no_rescale(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext, %k: !cheddar.eval_key) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HMult({}, {}, {}, {}, false);" + %r = cheddar.hmult %ctx, %a, %b, %k {rescale = false} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// HRot static-distance: the distance is baked into the format string twice, +// once for the key lookup and once for the rotation argument. + +// CHECK: func.func @hrot_static +func.func @hrot_static(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HRot({}, {}, {}->GetRotationKey(5), 5);" + %r = cheddar.hrot %ctx, %ct {static_distance = 5 : i64} + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// HRot dynamic-distance: the SSA distance value appears twice in the +// `args` list (once for the key, once for the rotation argument). 
+ +// CHECK: func.func @hrot_dynamic +func.func @hrot_dynamic(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext, %d: index) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HRot({}, {}, {}->GetRotationKey({}), {});" + // CHECK-SAME: %arg3, %arg3 + %r = cheddar.hrot %ctx, %ct, %d + : (!cheddar.context, !cheddar.ciphertext, index) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @hrot_add +func.func @hrot_add(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %a: !cheddar.ciphertext, %b: !cheddar.ciphertext) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HRotAdd({}, {}, {}, {}->GetRotationKey(7), 7);" + %r = cheddar.hrot_add %ctx, %a, %b {distance = 7 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @hconj +func.func @hconj(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HConj({}, {}, {}->GetConjugationKey());" + %r = cheddar.hconj %ctx, %ct + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @hconj_add +func.func @hconj_add(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %a: !cheddar.ciphertext, %b: !cheddar.ciphertext) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->HConjAdd({}, {}, {}, {}->GetConjugationKey());" + %r = cheddar.hconj_add %ctx, %a, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// MadUnsafe is in-place: no new variable is declared for the result; the +// SSA result aliases the accumulator input. Subsequent uses of `%r` resolve +// to the same C++ variable as `%acc`. 
+ +// CHECK: func.func @mad +func.func @mad(%ctx: !cheddar.context, %acc: !cheddar.ciphertext, + %in: !cheddar.ciphertext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->MadUnsafe({}, {}, {});" args %arg0, %arg1, %arg2, %arg3 + // CHECK-NOT: emitc.variable + %r = cheddar.mad_unsafe %ctx, %acc, %in, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @boot +func.func @boot(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: emitc.verbatim "{}->Boot({}, {}, {});" + %r = cheddar.boot %ctx, %ct, %evk + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// Array attrs are inlined into the verbatim format string with doubled +// braces so that `emitc.verbatim` interprets them as literal `{` / `}` and +// renders the initializer list correctly. We check the substring rather +// than the full format string (FileCheck's `{{...}}` is a regex marker). 
+ +// CHECK: func.func @linear_transform +func.func @linear_transform(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map, %d: tensor<2x4xf64>) + -> !cheddar.ciphertext { + // CHECK: emitc.verbatim + // CHECK-SAME: ->LinearTransform + // CHECK-SAME: 0, 1 + // CHECK-SAME: 5, 0 + %r = cheddar.linear_transform %ctx, %ct, %evk, %d + {diagonal_indices = array, level = 5 : i64, logBabyStepGiantStepRatio = 0 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map, tensor<2x4xf64>) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// CHECK: func.func @eval_poly +func.func @eval_poly(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: emitc.verbatim + // CHECK-SAME: ->EvalPoly + // CHECK-SAME: 1.0, 2.0, 3.0 + // CHECK-SAME: 4 + %r = cheddar.eval_poly %ctx, %ct, %evk + {coefficients = [1.0 : f64, 2.0 : f64, 3.0 : f64], level = 4 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} diff --git a/tests/Conversions/CheddarToEmitC/compile/BUILD b/tests/Conversions/CheddarToEmitC/compile/BUILD new file mode 100644 index 0000000000..13e3c4a8b2 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/compile/BUILD @@ -0,0 +1,69 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@heir//tools:heir-opt.bzl", "heir_opt") +load("@heir//tools:heir-translate.bzl", "heir_translate") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +# Compile-only guard that `cheddar-to-emitc` emits C++ which honours CHEDDAR's +# move/const contract. Unlike the GPU end-to-end tests under +# tests/Examples/cheddar, this compiles the emitted code against a header-only +# stub (cheddar_stub.h) with no CUDA, so it runs in normal CI. 
The cheddar +# dialect and pass are registered unconditionally in heir-opt/-translate, so +# this needs neither a GPU nor --//:enable_cheddar. + +heir_opt( + name = "kernels_emitc", + src = "kernels.mlir", + generated_filename = "kernels_emitc.mlir", + pass_flags = [ + "--cheddar-to-emitc", + "--reconcile-unrealized-casts", + ], +) + +heir_translate( + name = "kernels_cpp_raw", + src = ":kernels_emitc.mlir", + generated_filename = "kernels_raw.cc", + pass_flags = ["--mlir-to-cpp"], +) + +# Prepend the includes/usings the emitted body relies on. +genrule( + name = "kernels_lib_src", + srcs = [":kernels_raw.cc"], + outs = ["kernels_lib.cc"], + cmd = """cat > $@ <<'PRELUDE_EOF' +// AUTO-GENERATED: do not edit. See tests/Conversions/CheddarToEmitC/compile/BUILD. +#include +#include "tests/Conversions/CheddarToEmitC/compile/cheddar_stub.h" +using namespace cheddar; +using word = uint64_t; +PRELUDE_EOF +cat $(location :kernels_raw.cc) >> $@ +""", +) + +cc_library( + name = "cheddar_stub", + hdrs = ["cheddar_stub.h"], +) + +cc_library( + name = "kernels_compiled", + srcs = [":kernels_lib_src"], + deps = [":cheddar_stub"], +) + +# Compiling the generated C++ against the stub *is* the assertion: if the +# emitter produced C++ that violates CHEDDAR's move/const contract, this fails +# to build. build_test compiles without linking, so the stub only needs +# declarations (no method bodies), and no GPU/CUDA is involved. +build_test( + name = "compile_test", + targets = [":kernels_compiled"], +) diff --git a/tests/Conversions/CheddarToEmitC/compile/cheddar_stub.h b/tests/Conversions/CheddarToEmitC/compile/cheddar_stub.h new file mode 100644 index 0000000000..59dd26dfc4 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/compile/cheddar_stub.h @@ -0,0 +1,145 @@ +// Header-only stub of the CHEDDAR C++ API, used to *compile* (not run) the +// C++ that `cheddar-to-emitc` + `heir-translate --mlir-to-cpp` produce, with +// no GPU/CUDA toolchain. 
This is a CI-runnable guard that the emitted code +// honours CHEDDAR's move/const contract -- the kind of bug that FileCheck +// (which only inspects emitted text) cannot catch. +// +// The move/const semantics below mirror the real library (verified against +// CHEDDAR's include/core headers); only these properties matter here, so the +// method bodies are empty and the data layout is omitted: +// +// * Ciphertext/Plaintext/Constant/EvaluationKey -- move-only (copy deleted) +// *with* move-assignment, default-constructible. (core/Container.h) +// * EvkMap -- move-only, copy deleted, *no* move-assignment. (core/EvkMap.h) +// * Context::MadUnsafe(Ct& res, ...) mutates `res` in place, so `res` is a +// non-const reference. (core/Context.h:377) +// * UserInterface::Get*Key() / GetEvkMap() return `const&`. (UserInterface.h) +// +// Kept deliberately narrow: the "setup/getter" surface (create_context, +// create_user_interface, get_encoder, encode/decode) is *not* modelled here +// because that part of the emitter has independent, pre-existing mismatches +// against the real API and needs its own design pass. + +#ifndef TESTS_CONVERSIONS_CHEDDARTOEMITC_COMPILE_CHEDDAR_STUB_H_ +#define TESTS_CONVERSIONS_CHEDDARTOEMITC_COMPILE_CHEDDAR_STUB_H_ + +#include +#include + +namespace cheddar { + +// Move-only payload types with full move support (default + move-ctor + +// move-assign; copy deleted). 
+template +struct Ciphertext { + Ciphertext() = default; + Ciphertext(Ciphertext&&) = default; + Ciphertext& operator=(Ciphertext&&) = default; + Ciphertext(const Ciphertext&) = delete; + Ciphertext& operator=(const Ciphertext&) = delete; +}; +template +struct Plaintext { + Plaintext() = default; + Plaintext(Plaintext&&) = default; + Plaintext& operator=(Plaintext&&) = default; + Plaintext(const Plaintext&) = delete; + Plaintext& operator=(const Plaintext&) = delete; +}; +template +struct Constant { + Constant() = default; + Constant(Constant&&) = default; + Constant& operator=(Constant&&) = default; + Constant(const Constant&) = delete; + Constant& operator=(const Constant&) = delete; +}; +template +struct EvaluationKey { + EvaluationKey() = default; + EvaluationKey(EvaluationKey&&) = default; + EvaluationKey& operator=(EvaluationKey&&) = default; + EvaluationKey(const EvaluationKey&) = delete; + EvaluationKey& operator=(const EvaluationKey&) = delete; +}; + +// Move-only, and -- unlike the payload types -- has *no* move-assignment and +// is not default-constructible (the real EvkMap inherits std::unordered_map +// and declares only a move ctor). This is what makes the value+assign getter +// shape uncompilable, so the stub preserves it. 
+template +struct EvkMap { + EvkMap(EvkMap&&) = default; + EvkMap(const EvkMap&) = delete; + EvkMap& operator=(const EvkMap&) = delete; + + const EvaluationKey& GetRotationKey(int) const; + const EvaluationKey& GetConjugationKey() const; + const EvaluationKey& GetMultiplicationKey() const; +}; + +template +class UserInterface { + public: + using Ct = Ciphertext; + using Pt = Plaintext; + using Evk = EvaluationKey; + + void Encrypt(Ct& res, const Pt& a) const; + void Decrypt(Pt& res, const Ct& a) const; + + const Evk& GetRotationKey(int rot_idx) const; + const Evk& GetConjugationKey() const; + const Evk& GetMultiplicationKey() const; + const EvkMap& GetEvkMap() const; +}; + +template +class Context { + public: + using Ct = Ciphertext; + using Pt = Plaintext; + using Const = Constant; + using Evk = EvaluationKey; + + // Overloaded ct/pt/const arithmetic -- the emitter dispatches the dialect's + // *_plain / *_const ops to the base name and relies on C++ overloading. + void Add(Ct& res, const Ct& a, const Ct& b) const; + void Add(Ct& res, const Ct& a, const Pt& b) const; + void Add(Ct& res, const Ct& a, const Const& b) const; + void Sub(Ct& res, const Ct& a, const Ct& b) const; + void Sub(Ct& res, const Ct& a, const Pt& b) const; + void Mult(Ct& res, const Ct& a, const Ct& b) const; + void Mult(Ct& res, const Ct& a, const Pt& b) const; + void Mult(Ct& res, const Ct& a, const Const& b) const; + + void Neg(Ct& res, const Ct& a) const; + void Rescale(Ct& res, const Ct& a) const; + void Relinearize(Ct& res, const Ct& a, const Evk& key) const; + void RelinearizeRescale(Ct& res, const Ct& a, const Evk& key) const; + void LevelDown(Ct& res, const Ct& a, int target_level) const; + + void HMult(Ct& res, const Ct& a, const Ct& b, const Evk& mult_key, + bool rescale) const; + void HRot(Ct& res, const Ct& a, const Evk& rot_key, int rot_dist) const; + void HRotAdd(Ct& res, const Ct& a, const Ct& b, const Evk& rot_key, + int rot_dist) const; + void HConj(Ct& res, const Ct& a, const 
Evk& conj_key) const; + void HConjAdd(Ct& res, const Ct& a, const Ct& b, const Evk& conj_key) const; + + // In-place multiply-accumulate: `res` is mutated, so it is a *non-const* + // reference; the accumulator must be a mutable lvalue, never `const Ct&`. + void MadUnsafe(Ct& res, const Ct& a, const Const& b) const; + + void Boot(Ct& res, const Ct& a, const EvkMap& evk_map) const; + void LinearTransform(Ct& res, const Ct& a, const EvkMap& evk_map, + const std::vector& diagonals, + std::initializer_list diagonal_indices, int level, + int log_bsgs_ratio) const; + void EvalPoly(Ct& res, const Ct& a, const EvkMap& evk_map, + std::initializer_list coefficients, int level) const; +}; + +} // namespace cheddar + +#endif // TESTS_CONVERSIONS_CHEDDARTOEMITC_COMPILE_CHEDDAR_STUB_H_ diff --git a/tests/Conversions/CheddarToEmitC/compile/kernels.mlir b/tests/Conversions/CheddarToEmitC/compile/kernels.mlir new file mode 100644 index 0000000000..a39b60dccb --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/compile/kernels.mlir @@ -0,0 +1,158 @@ +// Input for the cheddar-to-emitc *compile* test (see BUILD): every function +// here is lowered to C++ and compiled against cheddar_stub.h. The point is to +// exercise the emitter's move/const handling on the op surface that real +// kernels use, with ctx / user_interface / keys / evk_map taken as function +// arguments (the shape a CKKS-to-Cheddar lowering produces). +// +// The setup/getter ops (create_context, create_user_interface, get_encoder, +// get_evk_map, get_mult_key, encode/decode) are intentionally absent: that +// part of the emitter has independent, pre-existing API mismatches and is +// tracked separately. + +// Add / Sub / Mult chained on ciphertexts. 
+func.func @arith(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext) -> !cheddar.ciphertext { + %0 = cheddar.add %ctx, %a, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %1 = cheddar.sub %ctx, %0, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %2 = cheddar.mult %ctx, %1, %a + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %2 : !cheddar.ciphertext +} + +// ct+pt and ct+const overloaded dispatch. +func.func @ct_pt_const(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + %0 = cheddar.add_plain %ctx, %ct, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + %1 = cheddar.sub_plain %ctx, %0, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + %2 = cheddar.mult_plain %ctx, %1, %pt + : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + %3 = cheddar.add_const %ctx, %2, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + %4 = cheddar.mult_const %ctx, %3, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %4 : !cheddar.ciphertext +} + +// Unary ops. +func.func @unary(%ctx: !cheddar.context, %ct: !cheddar.ciphertext) + -> !cheddar.ciphertext { + %0 = cheddar.neg %ctx, %ct + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + %1 = cheddar.rescale %ctx, %0 + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + %2 = cheddar.level_down %ctx, %1 {targetLevel = 2 : i64} + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %2 : !cheddar.ciphertext +} + +// Relinearize / RelinearizeRescale with an evaluation-key argument. 
+func.func @relin(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %k: !cheddar.eval_key) -> !cheddar.ciphertext { + %0 = cheddar.relinearize %ctx, %ct, %k + : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + %1 = cheddar.relinearize_rescale %ctx, %0, %k + : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %1 : !cheddar.ciphertext +} + +// HMult with an evaluation-key argument. +func.func @hmult(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext, %k: !cheddar.eval_key) + -> !cheddar.ciphertext { + %0 = cheddar.hmult %ctx, %a, %b, %k {rescale = true} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) + -> !cheddar.ciphertext + return %0 : !cheddar.ciphertext +} + +// Rotation / conjugation: the key is looked up inline via the UserInterface +// argument, so these functions must carry a user_interface arg. +func.func @rotations(%ctx: !cheddar.context, %ui: !cheddar.user_interface, + %a: !cheddar.ciphertext, %b: !cheddar.ciphertext) + -> !cheddar.ciphertext { + %0 = cheddar.hrot %ctx, %a {static_distance = 5 : i64} + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + %1 = cheddar.hrot_add %ctx, %0, %b {distance = 7 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %2 = cheddar.hconj %ctx, %1 + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + %3 = cheddar.hconj_add %ctx, %2, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %3 : !cheddar.ciphertext +} + +// mad_unsafe with a *local* accumulator (the result of add): the accumulator +// is a plain local lvalue, so MadUnsafe(acc, ...) binds fine. This path +// already compiled; it's here as the control case. 
+func.func @mad_local(%ctx: !cheddar.context, %a: !cheddar.ciphertext, + %b: !cheddar.ciphertext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + %acc = cheddar.add %ctx, %a, %b + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %r = cheddar.mad_unsafe %ctx, %acc, %a, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// mad_unsafe with the accumulator coming straight from a *function argument* +// (previously failed to compile). The accumulator is mutated in place by +// MadUnsafe and then returned, so it must be lifted to a mutable `Ct&` -- not +// the `const Ct&` that the by-value-arg tightening would otherwise produce. +func.func @mad_arg(%ctx: !cheddar.context, %acc: !cheddar.ciphertext, + %in: !cheddar.ciphertext, %c: !cheddar.constant) + -> !cheddar.ciphertext { + %r = cheddar.mad_unsafe %ctx, %acc, %in, %c + : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) + -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} + +// Identity: returning a move-only argument unchanged must lift the arg to an +// in-place `Ct&` out-param, not copy. +func.func @identity(%ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + return %ct : !cheddar.ciphertext +} + +// Bootstrapping-family ops taking an EvkMap argument (const EvkMap& at the C++ +// boundary). 
+func.func @boot(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + %0 = cheddar.boot %ctx, %ct, %evk + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %0 : !cheddar.ciphertext +} + +func.func @linear_transform(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map, %d: tensor<2x4xf64>) + -> !cheddar.ciphertext { + %0 = cheddar.linear_transform %ctx, %ct, %evk, %d + {diagonal_indices = array, level = 5 : i64, logBabyStepGiantStepRatio = 0 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map, tensor<2x4xf64>) + -> !cheddar.ciphertext + return %0 : !cheddar.ciphertext +} + +func.func @eval_poly(%ctx: !cheddar.context, %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + %0 = cheddar.eval_poly %ctx, %ct, %evk + {coefficients = [1.0 : f64, 2.0 : f64, 3.0 : f64], level = 4 : i64} + : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %0 : !cheddar.ciphertext +} + +// Encrypt / Decrypt out-param calls on the UserInterface. +func.func @encrypt_decrypt(%ui: !cheddar.user_interface, %pt: !cheddar.plaintext, + %ct: !cheddar.ciphertext) + -> (!cheddar.ciphertext, !cheddar.plaintext) { + %0 = cheddar.encrypt %ui, %pt + : (!cheddar.user_interface, !cheddar.plaintext) -> !cheddar.ciphertext + %1 = cheddar.decrypt %ui, %ct + : (!cheddar.user_interface, !cheddar.ciphertext) -> !cheddar.plaintext + return %0, %1 : !cheddar.ciphertext, !cheddar.plaintext +} diff --git a/tests/Conversions/CheddarToEmitC/no_ui.mlir b/tests/Conversions/CheddarToEmitC/no_ui.mlir new file mode 100644 index 0000000000..9392e5d1f6 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/no_ui.mlir @@ -0,0 +1,16 @@ +// RUN: heir-opt --cheddar-to-emitc --verify-diagnostics %s + +// HRot/HRotAdd/HConj/HConjAdd discover the UserInterface from the enclosing +// function's argument list at lowering time. 
Functions that lack such an +// argument cannot be legalized. + +func.func @hrot_without_ui(%ctx: !cheddar.context, %ct: !cheddar.ciphertext) + -> !cheddar.ciphertext { + // The conversion framework emits "failed to legalize"; whether the + // pattern's own diagnostic ("enclosing function is missing UserInterface + // arg") is also surfaced depends on MLIR's internal dispatch. + // expected-error@+1 {{'cheddar.hrot'}} + %r = cheddar.hrot %ctx, %ct {static_distance = 1 : i64} + : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %r : !cheddar.ciphertext +} diff --git a/tests/Conversions/CheddarToEmitC/unsupported_getters.mlir b/tests/Conversions/CheddarToEmitC/unsupported_getters.mlir new file mode 100644 index 0000000000..c429f4d261 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/unsupported_getters.mlir @@ -0,0 +1,39 @@ +// RUN: heir-opt --cheddar-to-emitc --split-input-file --verify-diagnostics %s + +// The getter-style setup ops return a const reference to a move-only / +// non-assignable CHEDDAR value (EvkMap, EvaluationKey, Encoder, UserInterface), +// which can't be materialised into a local without a copy. The lowering +// rejects them rather than emit uncompilable C++. Real kernels pass these as +// function arguments or look them up inline (like HRot's rotation-key lookup). 
+ +func.func @get_evk_map(%ui: !cheddar.user_interface) -> !cheddar.evk_map { + // expected-error @below {{lowering of 'cheddar.get_evk_map' is not supported}} + %m = cheddar.get_evk_map %ui : (!cheddar.user_interface) -> !cheddar.evk_map + return %m : !cheddar.evk_map +} + +// ----- + +func.func @get_mult_key(%ui: !cheddar.user_interface) -> !cheddar.eval_key { + // expected-error @below {{lowering of 'cheddar.get_mult_key' is not supported}} + %k = cheddar.get_mult_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + return %k : !cheddar.eval_key +} + +// ----- + +func.func @get_encoder(%ctx: !cheddar.context) -> !cheddar.encoder { + // expected-error @below {{lowering of 'cheddar.get_encoder' is not supported}} + %e = cheddar.get_encoder %ctx : (!cheddar.context) -> !cheddar.encoder + return %e : !cheddar.encoder +} + +// ----- + +func.func @create_user_interface(%ctx: !cheddar.context) + -> !cheddar.user_interface { + // expected-error @below {{lowering of 'cheddar.create_user_interface' is not supported}} + %ui = cheddar.create_user_interface %ctx + : (!cheddar.context) -> !cheddar.user_interface + return %ui : !cheddar.user_interface +} diff --git a/tests/Conversions/CheddarToEmitC/unsupported_move_only_memref.mlir b/tests/Conversions/CheddarToEmitC/unsupported_move_only_memref.mlir new file mode 100644 index 0000000000..c356a63924 --- /dev/null +++ b/tests/Conversions/CheddarToEmitC/unsupported_move_only_memref.mlir @@ -0,0 +1,24 @@ +// RUN: heir-opt --cheddar-to-emitc --split-input-file --verify-diagnostics %s + +// A dynamic-shape memref of a move-only cheddar type can't be represented as a +// fixed-size emitc.array, so the type converter refuses it and the conversion +// fails -- rather than falling through to a memref that stock +// MemRefToEmitC would lower with copies of move-only payloads. 
+// expected-error @below {{failed to legalize operation 'func.func'}} +func.func @dynamic(%m: memref<?x!cheddar.ciphertext>, %i: index) + -> !cheddar.ciphertext { + %0 = memref.load %m[%i] : memref<?x!cheddar.ciphertext> + return %0 : !cheddar.ciphertext +} + +// ----- + +// A rank>1 memref of a move-only cheddar type would need a multi-dimensional +// emitc.array, which the destination-passing boundary lift (1-D std::array) +// can't represent, so the conversion fails rather than emit invalid C++. +// expected-error @below {{failed to legalize operation 'func.func'}} +func.func @rank2(%m: memref<2x3x!cheddar.ciphertext>, %i: index, %j: index) + -> !cheddar.ciphertext { + %0 = memref.load %m[%i, %j] : memref<2x3x!cheddar.ciphertext> + return %0 : !cheddar.ciphertext +} diff --git a/tools/BUILD b/tools/BUILD index 4e4f28f86b..14215e405d 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -44,6 +44,7 @@ cc_binary( "@heir//lib/Analysis/NoiseAnalysis", # buildcleaner: keep "@heir//lib/Analysis/NoiseAnalysis/BFV:NoiseAnalysis", # buildcleaner: keep "@heir//lib/Analysis/NoiseAnalysis/BGV:NoiseAnalysis", # buildcleaner: keep + "@heir//lib/Conversions/CheddarToEmitC", "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Dialect/Arith/Conversions/ArithToCGGI", "@heir//lib/Dialect/Arith/Conversions/ArithToCGGIQuart", @@ -185,6 +186,7 @@ cc_binary( "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:AffineTransforms", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToEmitC", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:ArithValueBoundsOpInterfaceImpl", @@ -209,6 +211,7 @@ cc_binary( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToEmitC", "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:MlirOptLib", @@ -216,6 +219,7 @@ cc_binary( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", 
"@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SCFToEmitC", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", @@ -248,6 +252,8 @@ cc_binary( "@heir//lib/Target/TfheRustHL:TfheRustHLEmitter", "@heir//lib/Target/Verilog:VerilogEmitter", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllTranslations", + "@llvm-project//mlir:TargetCpp", "@llvm-project//mlir:TranslateLib", ], ) diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index d35cd4115e..dc8d720718 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -3,6 +3,7 @@ #include #include +#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h" #include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h" #include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h" #include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" @@ -125,6 +126,8 @@ #include "lib/Transforms/UnusedMemRef/UnusedMemRef.h" #include "lib/Transforms/ValidateNoise/ValidateNoise.h" #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h" // from @llvm-project +#include "mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project @@ -133,9 +136,12 @@ #include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project #include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project +#include "mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" // from @llvm-project +#include 
"mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" // from @llvm-project #include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h" // from @llvm-project #include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -278,6 +284,21 @@ int main(int argc, char** argv) { registerPass( []() -> std::unique_ptr { return createConvertToLLVMPass(); }); + // SCF/MemRef/Arith -> EmitC, used by the cheddar pipeline to lower + // bufferized loop kernels through EmitC to C++. + registerPass([]() -> std::unique_ptr { return createSCFToEmitC(); }); + registerPass( + []() -> std::unique_ptr { return createConvertMemRefToEmitC(); }); + registerPass( + []() -> std::unique_ptr { return createConvertArithToEmitC(); }); + mlir::registerConvertSCFToEmitCInterface(registry); + mlir::registerConvertMemRefToEmitCInterface(registry); + mlir::registerConvertArithToEmitCInterface(registry); + // Attaches MemRefElementTypeInterface to emitc::OpaqueType so the + // cheddar-to-emitc type converter can form memref> + // as the converted form of memref. 
+ mlir::heir::registerCheddarToEmitCExternalModels(registry); + // Bufferization and external models bufferization::registerBufferizationPasses(); mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); @@ -302,6 +323,7 @@ int main(int argc, char** argv) { registerEmitCInterfacePass(); cggi::registerCGGIPasses(); debug::registerDebugPasses(); + registerCheddarToEmitCPasses(); ckks::registerCKKSPasses(); lattigo::registerLattigoPasses(); lwe::registerLWEPasses(); diff --git a/tools/heir-translate.cpp b/tools/heir-translate.cpp index c4e73daae8..3d45613dad 100644 --- a/tools/heir-translate.cpp +++ b/tools/heir-translate.cpp @@ -12,9 +12,14 @@ #include "lib/Target/TfheRustHL/TfheRustHLEmitter.h" #include "lib/Target/Verilog/VerilogEmitter.h" #include "llvm/include/llvm/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/InitAllTranslations.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h" // from @llvm-project int main(int argc, char** argv) { + // MLIR-to-C++ via the EmitC dialect: used as the final stage of the CHEDDAR + // backend (`scheme-to-cheddar | --cheddar-to-emitc | heir-translate + // --mlir-to-cpp`), and available standalone for any EmitC IR. + mlir::registerToCppTranslation(); // Verilog output mlir::heir::registerToVerilogTranslation(); mlir::heir::registerMetadataEmitter();