diff --git a/lib/Dialect/Cheddar/IR/BUILD b/lib/Dialect/Cheddar/IR/BUILD new file mode 100644 index 0000000000..5b8f86d1f7 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/BUILD @@ -0,0 +1,122 @@ +# Cheddar dialect implementation + +load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library") +load("@llvm-project//mlir:tblgen.bzl", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = [ + "CheddarDialect.cpp", + ], + hdrs = [ + "CheddarDialect.h", + "CheddarOps.h", + "CheddarTypes.h", + ], + deps = [ + ":CheddarOps", + ":CheddarTypes", + ":dialect_inc_gen", + ":ops_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) + +cc_library( + name = "CheddarTypes", + srcs = [ + "CheddarTypes.cpp", + ], + hdrs = [ + "CheddarDialect.h", + "CheddarTypes.h", + ], + deps = [ + ":dialect_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "CheddarOps", + srcs = [ + "CheddarOps.cpp", + ], + hdrs = [ + "CheddarDialect.h", + "CheddarOps.h", + "CheddarTypes.h", + ], + deps = [ + ":CheddarTypes", + ":dialect_inc_gen", + ":ops_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@heir//lib/Utils", + "@heir//lib/Utils:RotationUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) + +td_library( + name = "td_files", + srcs = [ + "CheddarDialect.td", + "CheddarOps.td", + "CheddarTypes.td", + ], + # include from the heir-root to enable fully-qualified include-paths + includes = ["../../../.."], + deps = [ + "@heir//lib/Dialect:td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +add_heir_dialect_library( + name = "dialect_inc_gen", + dialect = "Cheddar", + kind = "dialect", + td_file = "CheddarDialect.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "types_inc_gen", + dialect = "Cheddar", + kind = "type", + td_file = "CheddarTypes.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "ops_inc_gen", + dialect = "Cheddar", + kind = "op", + td_file = "CheddarOps.td", + deps = [ + ":td_files", + "@heir//lib/Dialect:td_files", + ], +) diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.cpp b/lib/Dialect/Cheddar/IR/CheddarDialect.cpp new file mode 100644 index 0000000000..c78f2cc186 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.cpp @@ -0,0 +1,39 @@ +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" + +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +// NOLINTNEXTLINE(misc-include-cleaner): Required to define CheddarOps + +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" + +// Generated definitions +#include "lib/Dialect/Cheddar/IR/CheddarDialect.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc" + +namespace mlir { +namespace heir { +namespace cheddar { + +void CheddarDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc" + >(); +} + +} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.h b/lib/Dialect/Cheddar/IR/CheddarDialect.h new file mode 100644 index 0000000000..555fe2ad55 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.h @@ -0,0 +1,10 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ + +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project + +// Generated headers (block clang-format from messing up order) +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.td b/lib/Dialect/Cheddar/IR/CheddarDialect.td new file mode 100644 index 0000000000..1871a612da --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.td @@ -0,0 +1,24 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +def Cheddar_Dialect : Dialect { + let name = "cheddar"; + let description = [{ + The `cheddar` dialect is an exit dialect for generating C++ code against the + CHEDDAR GPU FHE library API. + + CHEDDAR is a CKKS-only GPU-accelerated FHE library. It supports both 32-bit + and 64-bit word types, with 32-bit being the primary fast path on GPUs. + + See [the Cheddar GitHub repository](https://github.com/scale-snu/cheddar-fhe) + }]; + + let cppNamespace = "::mlir::heir::cheddar"; + + let useDefaultTypePrinterParser = 1; +} + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.cpp b/lib/Dialect/Cheddar/IR/CheddarOps.cpp new file mode 100644 index 0000000000..42f475b3da --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.cpp @@ -0,0 +1,46 @@ +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" + +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Utils/RotationUtils.h" +#include "lib/Utils/Utils.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace cheddar { + +::llvm::SmallVector<::mlir::OpFoldResult> HRotOp::getRotationIndices() { + if (getStaticDistance()) return {getStaticDistanceAttr()}; + return {getDynamicDistance()}; +} + +LogicalResult HRotOp::verify() { + return containsExactlyOneOrEmitError(getOperation(), getDynamicDistance(), + getStaticDistance()); +} + +::llvm::SmallVector<::mlir::OpFoldResult> HRotAddOp::getRotationIndices() { + return {getDistanceAttr()}; +} + +::llvm::SmallVector<::mlir::OpFoldResult> +LinearTransformOp::getRotationIndices() { + auto diagonalsType = cast(getDiagonals().getType()); + int64_t slots = diagonalsType.getShape()[1]; + int64_t logBSGS = getLogBabyStepGiantStepRatio().getInt(); + auto rotations = lintransRotationIndices( + getDiagonalIndicesAttr().asArrayRef(), slots, logBSGS); + SmallVector result; + result.reserve(rotations.size()); + auto* mlirCtx = (*this)->getContext(); + for (int64_t rot : rotations) { + result.push_back(IntegerAttr::get(IndexType::get(mlirCtx), rot)); + } + return result; +} + +} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.h b/lib/Dialect/Cheddar/IR/CheddarOps.h new file mode 100644 index 0000000000..4367406e22 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.h @@ -0,0 +1,15 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ + +// IWYU pragma: begin_keep +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Dialect/HEIRInterfaces.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +// IWYU pragma: end_keep + +#define GET_OP_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarOps.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.td b/lib/Dialect/Cheddar/IR/CheddarOps.td new file mode 100644 index 0000000000..54c7466f6e --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.td @@ -0,0 +1,475 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ + +include "CheddarDialect.td" +include "CheddarTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "lib/Dialect/HEIRInterfaces.td" + +class Cheddar_Op traits = []> : + Op { + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Setup operations +//===----------------------------------------------------------------------===// + +def Cheddar_CreateContextOp : Cheddar_Op<"create_context"> { + let summary = "Create a CHEDDAR context from parameters"; + let description = [{ + Creates a CHEDDAR Context from a Parameter object. + The context is the main server-side computation engine. + }]; + let arguments = (ins Cheddar_Parameter:$params); + let results = (outs Cheddar_Context:$ctx); +} + +def Cheddar_CreateUserInterfaceOp : Cheddar_Op<"create_user_interface"> { + let summary = "Create a CHEDDAR UserInterface for key gen and encrypt/decrypt"; + let description = [{ + Creates a UserInterface from a context. The UserInterface handles + key generation, encryption, and decryption. Note: for test purposes only. + }]; + let arguments = (ins Cheddar_Context:$ctx); + let results = (outs Cheddar_UserInterface:$ui); +} + +def Cheddar_GetEncoderOp : Cheddar_Op<"get_encoder"> { + let summary = "Get the encoder from a CHEDDAR context"; + let description = [{ + Returns a reference to the context's encoder (context->encoder_). + }]; + let arguments = (ins Cheddar_Context:$ctx); + let results = (outs Cheddar_Encoder:$encoder); +} + +def Cheddar_GetEvkMapOp : Cheddar_Op<"get_evk_map"> { + let summary = "Get the evaluation key map from a UserInterface"; + let description = [{ + Returns the EvkMap from the UserInterface (ui.GetEvkMap()). + }]; + let arguments = (ins Cheddar_UserInterface:$ui); + let results = (outs Cheddar_EvkMap:$evkMap); +} + +def Cheddar_GetMultKeyOp : Cheddar_Op<"get_mult_key"> { + let summary = "Get the multiplication evaluation key"; + let description = [{ + Returns the multiplication key from the UserInterface + (ui.GetMultiplicationKey()). + }]; + let arguments = (ins Cheddar_UserInterface:$ui); + let results = (outs Cheddar_EvalKey:$key); +} + +def Cheddar_PrepareRotKeyOp : Cheddar_Op<"prepare_rot_key"> { + let summary = "Generate a rotation key for a given distance"; + let description = [{ + Calls ui.PrepareRotationKey(distance, max_level) to generate a rotation key. + Must be called before using rotation with that distance. + The max_level parameter specifies the maximum ciphertext level at which + this rotation key will be used. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Builtin_IntegerAttr:$distance, + Builtin_IntegerAttr:$maxLevel + ); + let results = (outs); +} + +//===----------------------------------------------------------------------===// +// Encode / Encrypt / Decrypt operations +//===----------------------------------------------------------------------===// + +def Cheddar_EncodeOp : Cheddar_Op<"encode"> { + let summary = "Encode a message vector into a CHEDDAR plaintext"; + let description = [{ + Calls encoder.Encode(pt, level, scale, message). The message is a vector + of complex numbers (or reals). `scale` is the C++ `double` value used by + CHEDDAR's Encoder. + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + RankedTensorOf<[AnyFloat, AnyComplex]>:$message, + Builtin_IntegerAttr:$level, + F64Attr:$scale + ); + let results = (outs Cheddar_Plaintext:$plaintext); +} + +def Cheddar_EncodeConstantOp : Cheddar_Op<"encode_constant"> { + let summary = "Encode a scalar double into a CHEDDAR constant"; + let description = [{ + Calls encoder.EncodeConstant(constant, level, scale, number). + The result is in RNS form for efficient ciphertext-scalar ops. + `scale` is the C++ `double` value used by CHEDDAR's Encoder. + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + AnyFloat:$value, + Builtin_IntegerAttr:$level, + F64Attr:$scale + ); + let results = (outs Cheddar_Constant:$constant); +} + +def Cheddar_DecodeOp : Cheddar_Op<"decode"> { + let summary = "Decode a CHEDDAR plaintext back to a message vector"; + let description = [{ + Calls encoder.Decode(message, pt). Returns a vector of complex numbers. + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + Cheddar_Plaintext:$plaintext + ); + let results = (outs RankedTensorOf<[AnyFloat, AnyComplex]>:$message); +} + +def Cheddar_EncryptOp : Cheddar_Op<"encrypt"> { + let summary = "Encrypt a plaintext into a ciphertext"; + let description = [{ + Calls ui.Encrypt(ct, pt). Test-only operation. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Cheddar_Plaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$ciphertext); +} + +def Cheddar_DecryptOp : Cheddar_Op<"decrypt"> { + let summary = "Decrypt a ciphertext into a plaintext"; + let description = [{ + Calls ui.Decrypt(pt, ct). Test-only operation. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Cheddar_Ciphertext:$ciphertext + ); + let results = (outs Cheddar_Plaintext:$plaintext); +} + +//===----------------------------------------------------------------------===// +// Arithmetic operations +//===----------------------------------------------------------------------===// + +class Cheddar_BinaryCtCtOp traits = []> + : Cheddar_Op { + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$lhs, + Cheddar_Ciphertext:$rhs + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_AddOp : Cheddar_BinaryCtCtOp<"add"> { + let summary = "Add two ciphertexts"; + let description = [{ + Calls context->Add(res, a, b). + }]; +} + +def Cheddar_SubOp : Cheddar_BinaryCtCtOp<"sub"> { + let summary = "Subtract two ciphertexts"; + let description = [{ + Calls context->Sub(res, a, b). + }]; +} + +def Cheddar_MultOp : Cheddar_BinaryCtCtOp<"mult", [IncreasesMulDepthOpInterface]> { + let summary = "Multiply two ciphertexts (tensor product, no relin/rescale)"; + let description = [{ + Calls context->Mult(res, a, b). Produces a degree-3 ciphertext. + Does NOT include relinearization or rescaling. + }]; +} + +def Cheddar_AddPlainOp : Cheddar_Op<"add_plain"> { + let summary = "Add a plaintext to a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Plaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_SubPlainOp : Cheddar_Op<"sub_plain"> { + let summary = "Subtract a plaintext from a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Plaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MultPlainOp : Cheddar_Op<"mult_plain"> { + let summary = "Multiply a ciphertext by a plaintext (no rescale)"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Plaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_AddConstOp : Cheddar_Op<"add_const"> { + let summary = "Add a constant to a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MultConstOp : Cheddar_Op<"mult_const"> { + let summary = "Multiply a ciphertext by a constant (no rescale)"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_NegOp : Cheddar_Op<"neg"> { + let summary = "Negate a ciphertext"; + let description = [{ + Calls context->Neg(res, a). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_RescaleOp : Cheddar_Op<"rescale"> { + let summary = "Rescale a ciphertext (drop one level)"; + let description = [{ + Calls context->Rescale(res, a). Reduces the ciphertext level by 1. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_LevelDownOp : Cheddar_Op<"level_down"> { + let summary = "Reduce ciphertext to a target level"; + let description = [{ + Calls context->LevelDown(res, a, target_level). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Builtin_IntegerAttr:$targetLevel + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Key-switching operations +//===----------------------------------------------------------------------===// + +def Cheddar_RelinearizeOp : Cheddar_Op<"relinearize"> { + let summary = "Relinearize a ciphertext (without rescale)"; + let description = [{ + Calls context->Relinearize(res, a, mult_key). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$multKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_RelinearizeRescaleOp : Cheddar_Op<"relinearize_rescale"> { + let summary = "Fused relinearize + rescale"; + let description = [{ + Calls context->RelinearizeRescale(res, a, mult_key). Faster than + separate Relinearize + Rescale. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$multKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Compound (fused) operations -- high-performance GPU kernels +//===----------------------------------------------------------------------===// + +def Cheddar_HMultOp : Cheddar_Op<"hmult", [IncreasesMulDepthOpInterface]> { + let summary = "Fused multiply + relinearize (+ optional rescale)"; + let description = [{ + Calls context->HMult(res, a, b, mult_key, rescale). + Single fused GPU kernel launch. The `rescale` attribute controls whether + rescaling is included (default: true). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$lhs, + Cheddar_Ciphertext:$rhs, + Cheddar_EvalKey:$multKey, + DefaultValuedAttr:$rescale + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HRotOp : Cheddar_Op<"hrot", [ + DeclareOpInterfaceMethods +]> { + let summary = "Fused key-switch + rotation"; + let description = [{ + Calls context->HRot(res, a, rot_key, distance). The rotation key is looked + up from the enclosing function's UserInterface using `distance`. + Single fused GPU kernel. Supports both static and dynamic distances. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Optional:$dynamic_distance, + OptionalAttr:$static_distance + ); + let results = (outs Cheddar_Ciphertext:$output); + let hasVerifier = 1; +} + +def Cheddar_HRotAddOp : Cheddar_Op<"hrot_add", [ + DeclareOpInterfaceMethods +]> { + let summary = "Fused rotation + addition"; + let description = [{ + Computes res = rotate(a, distance) + b in a single fused GPU kernel. + Calls context->HRotAdd(res, a, b, rot_key, distance). The rotation key + is looked up from the enclosing function's UserInterface using `distance`. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_Ciphertext:$addend, + Builtin_IntegerAttr:$distance + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HConjOp : Cheddar_Op<"hconj"> { + let summary = "Fused key-switch + conjugation"; + let description = [{ + Calls context->HConj(res, a, conj_key). The conjugation key is looked + up from the enclosing function's UserInterface. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HConjAddOp : Cheddar_Op<"hconj_add"> { + let summary = "Fused conjugation + addition"; + let description = [{ + Computes res = conj(a) + b in a single fused GPU kernel. + Calls context->HConjAdd(res, a, b, conj_key). The conjugation key is + looked up from the enclosing function's UserInterface. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_Ciphertext:$addend + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MadUnsafeOp : Cheddar_Op<"mad_unsafe"> { + let summary = "Fused multiply-accumulate with constant (no rescale)"; + let description = [{ + Computes res += a * constant (in-place accumulation). + Calls context->MadUnsafe(res, a, constant). + + Note: this op would typically be called `mac`, + the current name reflects the spelling in Cheddar + (`MadUnsafe`). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$accumulator, + Cheddar_Ciphertext:$input, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Extension operations (bootstrapping, linear transforms, poly eval) +//===----------------------------------------------------------------------===// + +def Cheddar_BootOp : Cheddar_Op<"boot", [ResetsMulDepthOpInterface]> { + let summary = "Bootstrap a ciphertext"; + let description = [{ + Calls boot_ctx->Boot(res, input, evk_map). + Refreshes the ciphertext noise budget. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_LinearTransformOp : Cheddar_Op<"linear_transform", [ + DeclareOpInterfaceMethods +]> { + let summary = "Apply a linear transform on a ciphertext"; + let description = [{ + Applies a matrix-vector product using CHEDDAR's LinearTransform extension + with BSGS optimization and hoisting. + + The `diagonals` input is a 2D tensor where each row is a non-zero diagonal. + The `diagonal_indices` attribute specifies which diagonal each row represents. + The `level` attribute specifies the modulus level for the operation. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap, + 2DTensorOf<[AnyFloat]>:$diagonals, + DenseI32ArrayAttr:$diagonal_indices, + Builtin_IntegerAttr:$level, + Builtin_IntegerAttr:$logBabyStepGiantStepRatio + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_EvalPolyOp : Cheddar_Op<"eval_poly"> { + let summary = "Evaluate a polynomial on a ciphertext"; + let description = [{ + Evaluates a polynomial (e.g., Chebyshev approximation) on an encrypted + input using CHEDDAR's EvalPoly extension. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap, + ArrayAttr:$coefficients, + Builtin_IntegerAttr:$level + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.cpp b/lib/Dialect/Cheddar/IR/CheddarTypes.cpp new file mode 100644 index 0000000000..707851f21a --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.cpp @@ -0,0 +1,7 @@ +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" + +namespace mlir { +namespace heir { +namespace cheddar {} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.h b/lib/Dialect/Cheddar/IR/CheddarTypes.h new file mode 100644 index 0000000000..633882c52e --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.h @@ -0,0 +1,13 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ + +// IWYU pragma: begin_keep +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/HEIRInterfaces.h" +#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +// IWYU pragma: end_keep + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.td b/lib/Dialect/Cheddar/IR/CheddarTypes.td new file mode 100644 index 0000000000..5f8b0b795f --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.td @@ -0,0 +1,107 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ + +include "CheddarDialect.td" + +include "lib/Dialect/HEIRInterfaces.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpAsmInterface.td" + +// A base class for all types in this dialect +class Cheddar_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + let genMnemonicAlias = 1; + + string asmName = ?; + let extraClassDeclaration = [{ + // OpAsmTypeInterface method + void getAsmName(::mlir::OpAsmSetNameFn setNameFn) const { + setNameFn("}] # asmName # [{"); + } + }]; +} + +// Context types + +def Cheddar_Context : Cheddar_Type<"Context", "context"> { + let description = [{ + This type represents a CHEDDAR Context (or BootContext), which is the + main server-side computation engine. Created via Context::Create(param) + or BootContext::Create(param, boot_param). + }]; + let asmName = "ctx"; +} + +def Cheddar_Parameter : Cheddar_Type<"Parameter", "parameter"> { + let description = [{ + This type represents a CHEDDAR Parameter object, constructed from a JSON + file or programmatically. + }]; + let asmName = "param"; +} + +def Cheddar_Encoder : Cheddar_Type<"Encoder", "encoder"> { + let description = [{ + This type represents the CHEDDAR Encoder, accessed via context->encoder_. + Used for encoding/decoding plaintext values. + }]; + let asmName = "encoder"; +} + +def Cheddar_UserInterface : Cheddar_Type<"UserInterface", "user_interface"> { + let description = [{ + This type represents the CHEDDAR UserInterface, used for key generation, + encryption, and decryption. Note: this is for test purposes only and is + not security-hardened. + }]; + let asmName = "ui"; +} + +// Data types + +def Cheddar_Ciphertext : Cheddar_Type<"Ciphertext", "ciphertext", [SecretTypeInterface]> { + let description = [{ + This type represents a CHEDDAR Ciphertext. Move-only, lives on GPU. + }]; + let asmName = "ct"; +} + +def Cheddar_Plaintext : Cheddar_Type<"Plaintext", "plaintext"> { + let description = [{ + This type represents a CHEDDAR Plaintext. Contains an NTT-applied + encoded message. + }]; + let asmName = "pt"; +} + +def Cheddar_Constant : Cheddar_Type<"Constant", "constant"> { + let description = [{ + This type represents a CHEDDAR Constant. A scalar in RNS form, + used for efficient ciphertext-scalar operations. + }]; + let asmName = "const"; +} + +// Key types + +def Cheddar_EvalKey : Cheddar_Type<"EvalKey", "eval_key"> { + let description = [{ + This type represents a single CHEDDAR EvaluationKey. + }]; + let asmName = "evk"; +} + +def Cheddar_EvkMap : Cheddar_Type<"EvkMap", "evk_map"> { + let description = [{ + This type represents a CHEDDAR EvkMap, which bundles all evaluation + keys (multiplication, rotation, conjugation, etc.) into a single map. + }]; + let asmName = "evk_map"; +} + +// Type aliases for op constraints +def Cheddar_CiphertextOrPlaintext : AnyTypeOf<[Cheddar_Ciphertext, Cheddar_Plaintext]>; + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ diff --git a/tests/Dialect/Cheddar/IR/BUILD b/tests/Dialect/Cheddar/IR/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/Cheddar/IR/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/Cheddar/IR/roundtrip.mlir b/tests/Dialect/Cheddar/IR/roundtrip.mlir new file mode 100644 index 0000000000..ed0ec57a46 --- /dev/null +++ b/tests/Dialect/Cheddar/IR/roundtrip.mlir @@ -0,0 +1,359 @@ +// RUN: heir-opt %s | FileCheck %s + +// Test that the cheddar dialect can be parsed and printed. + +// --- Setup operations --- + +// CHECK: @test_create_context +func.func @test_create_context(%params: !cheddar.parameter) -> !cheddar.context { + // CHECK: cheddar.create_context + %ctx = cheddar.create_context %params : (!cheddar.parameter) -> !cheddar.context + return %ctx : !cheddar.context +} + +// CHECK: @test_create_user_interface +func.func @test_create_user_interface(%ctx: !cheddar.context) -> !cheddar.user_interface { + // CHECK: cheddar.create_user_interface + %ui = cheddar.create_user_interface %ctx : (!cheddar.context) -> !cheddar.user_interface + return %ui : !cheddar.user_interface +} + +// CHECK: @test_get_encoder +func.func @test_get_encoder(%ctx: !cheddar.context) -> !cheddar.encoder { + // CHECK: cheddar.get_encoder + %enc = cheddar.get_encoder %ctx : (!cheddar.context) -> !cheddar.encoder + return %enc : !cheddar.encoder +} + +// CHECK: @test_get_evk_map +func.func @test_get_evk_map(%ui: !cheddar.user_interface) -> !cheddar.evk_map { + // CHECK: cheddar.get_evk_map + %evk = cheddar.get_evk_map %ui : (!cheddar.user_interface) -> !cheddar.evk_map + return %evk : !cheddar.evk_map +} + +// CHECK: @test_get_mult_key +func.func @test_get_mult_key(%ui: !cheddar.user_interface) -> !cheddar.eval_key { + // CHECK: cheddar.get_mult_key + %key = cheddar.get_mult_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + return %key : !cheddar.eval_key +} + +// CHECK: @test_prepare_rot_key +func.func @test_prepare_rot_key(%ui: !cheddar.user_interface) { + // CHECK: cheddar.prepare_rot_key + // CHECK-SAME: distance = 3 + // CHECK-SAME: maxLevel = 10 + cheddar.prepare_rot_key %ui {distance = 3 : i64, maxLevel = 10 : i64} : (!cheddar.user_interface) -> () + return +} + +// --- Encode / Encrypt / Decrypt --- + +// CHECK: @test_encode +func.func @test_encode( + %enc: !cheddar.encoder, + %msg: tensor<4xf64>) -> !cheddar.plaintext { + // CHECK: cheddar.encode + // CHECK-SAME: level = 5 + // CHECK-SAME: scale = 0x42C0000000000000 + %pt = cheddar.encode %enc, %msg {level = 5 : i64, scale = 35184372088832.0 : f64} : (!cheddar.encoder, tensor<4xf64>) -> !cheddar.plaintext + return %pt : !cheddar.plaintext +} + +// CHECK: @test_encode_constant +func.func @test_encode_constant( + %enc: !cheddar.encoder, + %val: f64) -> !cheddar.constant { + // CHECK: cheddar.encode_constant + // CHECK-SAME: level = 3 + // CHECK-SAME: scale = 0x42C0000000000000 + %c = cheddar.encode_constant %enc, %val {level = 3 : i64, scale = 35184372088832.0 : f64} : (!cheddar.encoder, f64) -> !cheddar.constant + return %c : !cheddar.constant +} + +// CHECK: @test_decode +func.func @test_decode( + %enc: !cheddar.encoder, + %pt: !cheddar.plaintext) -> tensor<4xf64> { + // CHECK: cheddar.decode + %msg = cheddar.decode %enc, %pt : (!cheddar.encoder, !cheddar.plaintext) -> tensor<4xf64> + return %msg : tensor<4xf64> +} + +// CHECK: @test_encrypt +func.func @test_encrypt( + %ui: !cheddar.user_interface, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.encrypt + %ct = cheddar.encrypt %ui, %pt : (!cheddar.user_interface, !cheddar.plaintext) -> !cheddar.ciphertext + return %ct : !cheddar.ciphertext +} + +// CHECK: @test_decrypt +func.func @test_decrypt( + %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> !cheddar.plaintext { + // CHECK: cheddar.decrypt + %pt = cheddar.decrypt %ui, %ct : (!cheddar.user_interface, !cheddar.ciphertext) -> !cheddar.plaintext + return %pt : !cheddar.plaintext +} + +// --- Binary ct-ct operations --- + +// CHECK: @test_add +func.func @test_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.add + %result = cheddar.add %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_sub +func.func @test_sub( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.sub + %result = cheddar.sub %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult +func.func @test_mult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.mult + %result = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Ct-pt / ct-const operations --- + +// CHECK: @test_add_plain +func.func @test_add_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.add_plain + %result = cheddar.add_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_sub_plain +func.func @test_sub_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.sub_plain + %result = cheddar.sub_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult_plain +func.func @test_mult_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.mult_plain + %result = cheddar.mult_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_add_const +func.func @test_add_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.add_const + %result = cheddar.add_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult_const +func.func @test_mult_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.mult_const + %result = cheddar.mult_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Unary operations --- + +// CHECK: @test_neg +func.func @test_neg( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.neg + %result = cheddar.neg %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_rescale +func.func @test_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.rescale + %result = cheddar.rescale %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_level_down +func.func @test_level_down( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.level_down + // CHECK-SAME: targetLevel = 3 + %result = cheddar.level_down %ctx, %ct {targetLevel = 3 : i64} : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Key-switching operations --- + +// CHECK: @test_relinearize +func.func @test_relinearize( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.relinearize + %result = cheddar.relinearize %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_relinearize_rescale +func.func @test_relinearize_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.relinearize_rescale + %result = cheddar.relinearize_rescale %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Fused compound operations --- + +// CHECK: @test_hmult +func.func @test_hmult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hmult + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = true} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hmult_no_rescale +func.func @test_hmult_no_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hmult + // CHECK-SAME: rescale = false + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = false} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hrot_static +func.func @test_hrot_static( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.hrot + // CHECK-SAME: static_distance = 5 + %result = cheddar.hrot %ctx, %ct {static_distance = 5 : i64} : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hrot_dynamic +func.func @test_hrot_dynamic( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %dist: index) -> !cheddar.ciphertext { + // CHECK: cheddar.hrot + %result = cheddar.hrot %ctx, %ct, %dist : (!cheddar.context, !cheddar.ciphertext, index) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hrot_add +func.func @test_hrot_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.hrot_add + // CHECK-SAME: distance = 3 + %result = cheddar.hrot_add %ctx, %ct0, %ct1 {distance = 3 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hconj +func.func @test_hconj( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.hconj + %result = cheddar.hconj %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hconj_add +func.func @test_hconj_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.hconj_add + %result = cheddar.hconj_add %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mad_unsafe +func.func @test_mad_unsafe( + %ctx: !cheddar.context, + %acc: !cheddar.ciphertext, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.mad_unsafe + %result = cheddar.mad_unsafe %ctx, %acc, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Extension operations --- + +// CHECK: @test_boot +func.func @test_boot( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: cheddar.boot + %result = cheddar.boot %ctx, %ct, %evk : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_linear_transform +func.func @test_linear_transform( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map, + %diags: tensor<2x4xf64>) -> !cheddar.ciphertext { + // CHECK: cheddar.linear_transform + // CHECK-SAME: diagonal_indices = array + // CHECK-SAME: level = 5 + // CHECK-SAME: logBabyStepGiantStepRatio = 0 + %result = cheddar.linear_transform %ctx, %ct, %evk, %diags {diagonal_indices = array, level = 5 : i64, logBabyStepGiantStepRatio = 0 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map, tensor<2x4xf64>) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_eval_poly +func.func @test_eval_poly( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: cheddar.eval_poly + // CHECK-SAME: coefficients = [1.000000e+00, 2.000000e+00, 3.000000e+00] + %result = cheddar.eval_poly %ctx, %ct, %evk {coefficients = [1.0 : f64, 2.0 : f64, 3.0 : f64], level = 5 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} diff --git a/tools/BUILD b/tools/BUILD index ac6ee6edce..01d8f3da1a 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -58,6 +58,7 @@ cc_binary( "@heir//lib/Dialect/CGGI/Transforms", "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/CKKS/Transforms", + "@heir//lib/Dialect/Cheddar/IR:Dialect", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Debug/Transforms", @@ -257,6 +258,7 @@ cc_binary( "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/Cheddar/IR:Dialect", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Jaxite/IR:Dialect", diff --git a/tools/heir-lsp.cpp b/tools/heir-lsp.cpp index 883b0b33c4..622e22c85e 100644 --- a/tools/heir-lsp.cpp +++ b/tools/heir-lsp.cpp @@ -1,6 +1,7 @@ #include "lib/Dialect/BGV/IR/BGVDialect.h" #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Debug/IR/DebugDialect.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" @@ -43,6 +44,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index a52619ef4a..0d045495a1 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -16,6 +16,7 @@ #include "lib/Dialect/CGGI/Transforms/Passes.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" #include "lib/Dialect/CKKS/Transforms/Passes.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Debug/IR/DebugDialect.h" #include "lib/Dialect/Debug/Transforms/Passes.h" @@ -191,6 +192,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert();