diff --git a/.gitignore b/.gitignore index 58f129a20a..31b4c95304 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,5 @@ __pycache__/ # lockfile is updated by automation MODULE.bazel.lock + +.jetskicli diff --git a/lib/Dialect/JaxiteWord/IR/BUILD b/lib/Dialect/JaxiteWord/IR/BUILD index 5daaeb3996..3043cfba34 100644 --- a/lib/Dialect/JaxiteWord/IR/BUILD +++ b/lib/Dialect/JaxiteWord/IR/BUILD @@ -13,11 +13,13 @@ cc_library( name = "Dialect", srcs = ["JaxiteWordDialect.cpp"], hdrs = [ + "JaxiteWordAttributes.h", "JaxiteWordDialect.h", "JaxiteWordOps.h", "JaxiteWordTypes.h", ], deps = [ + ":attributes_inc_gen", ":dialect_inc_gen", ":ops_inc_gen", ":types_inc_gen", @@ -32,6 +34,7 @@ cc_library( td_library( name = "td_files", srcs = [ + "JaxiteWordAttributes.td", "JaxiteWordDialect.td", "JaxiteWordOps.td", "JaxiteWordTypes.td", @@ -63,6 +66,16 @@ add_heir_dialect_library( ], ) +add_heir_dialect_library( + name = "attributes_inc_gen", + dialect = "JaxiteWord", + kind = "attribute", + td_file = "JaxiteWordAttributes.td", + deps = [ + ":td_files", + ], +) + add_heir_dialect_library( name = "ops_inc_gen", dialect = "JaxiteWord", diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h new file mode 100644 index 0000000000..fdf0223041 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h @@ -0,0 +1,9 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" + +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h.inc" + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td new file mode 100644 index 0000000000..fed739b97f --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td @@ -0,0 +1,32 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ + +include "JaxiteWordDialect.td" + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +class JaxiteWord_Attribute + : AttrDef { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def JaxiteWord_CkksParameters : JaxiteWord_Attribute<"CkksParameters", "ckks_parameters"> { + let summary = "Jaxite CKKS parameters"; + let description = [{ + Parameters for Jaxite CKKS backend. + }]; + + let parameters = (ins + "DenseI64ArrayAttr":$q_towers, + "DenseI64ArrayAttr":$p_towers, + "int":$r, + "int":$c, + "int":$dnum, + "int":$composite_degree, + "int":$batch + ); +} + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp index 0da3ab19a0..c0ada2e151 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp @@ -1,5 +1,6 @@ #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" @@ -8,6 +9,8 @@ #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" #define GET_OP_CLASSES @@ -18,6 +21,10 @@ namespace heir { namespace jaxiteword { void JaxiteWordDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc" + >(); addTypes< #define GET_TYPEDEF_LIST #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td index 5837a68a27..c5719de9ca 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td @@ -17,6 +17,7 @@ def JaxiteWord_Dialect : Dialect { let cppNamespace = "::mlir::heir::jaxiteword"; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } #endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ diff --git a/lib/Dialect/JaxiteWord/Transforms/BUILD b/lib/Dialect/JaxiteWord/Transforms/BUILD new file mode 100644 index 0000000000..ae06c1fb19 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/BUILD @@ -0,0 +1,39 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Transforms", + hdrs = ["Passes.h"], + deps = [ + ":JaxiteCkksParameterSelection", + ":pass_inc_gen", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "JaxiteCkksParameterSelection", + srcs = ["JaxiteCkksParameterSelection.cpp"], + hdrs = ["JaxiteCkksParameterSelection.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Parameters:RLWEParams", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +add_heir_transforms( + header_filename = "Passes.h.inc", + pass_name = "JaxiteWord", + td_file = "Passes.td", +) diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp new file mode 100644 index 0000000000..761934718b --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp @@ -0,0 +1,85 @@ +#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" + +#include "lib/Dialect/CKKS/IR/CKKSAttributes.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "lib/Parameters/RLWEParams.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project +#include "llvm/include/llvm/Support/MathExtras.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxiteword { + +#define GEN_PASS_DEF_JAXITECKKSPARAMETERSELECTION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +struct JaxiteCkksParameterSelection + : impl::JaxiteCkksParameterSelectionBase { + using JaxiteCkksParameterSelectionBase::JaxiteCkksParameterSelectionBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto schemeParamAttr = module->getAttrOfType( + ckks::CKKSDialect::kSchemeParamAttrName); + if (!schemeParamAttr) { + module->emitOpError() << "Missing ckks.schemeParam attribute"; + signalPassFailure(); + return; + } + + int logN = schemeParamAttr.getLogN(); + int ringDim = 1 << logN; + + auto Q = schemeParamAttr.getQ().asArrayRef(); + auto P = schemeParamAttr.getP().asArrayRef(); + + int totalBitsQ = 0; + for (auto q : Q) { + totalBitsQ += llvm::APInt(64, q).getActiveBits(); + } + + int totalBitsP = 0; + for (auto p : P) { + totalBitsP += llvm::APInt(64, p).getActiveBits(); + } + + std::vector existingPrimes; + std::vector qTowers; + std::vector pTowers; + + int bitsGeneratedQ = 0; + while (bitsGeneratedQ < totalBitsQ) { + int64_t prime = findPrime(30, ringDim, existingPrimes); + qTowers.push_back(prime); + existingPrimes.push_back(prime); + bitsGeneratedQ += 30; + } + + int bitsGeneratedP = 0; + while (bitsGeneratedP < totalBitsP) { + int64_t prime = findPrime(30, ringDim, existingPrimes); + pTowers.push_back(prime); + existingPrimes.push_back(prime); + bitsGeneratedP += 30; + } + + auto qTowersAttr = DenseI64ArrayAttr::get(context, qTowers); + auto pTowersAttr = DenseI64ArrayAttr::get(context, pTowers); + + int dnum = computeDnum(Q.size() - 1); + + // FIXME: Replace dummy value for composite_degree. + auto ckksParamsAttr = CkksParametersAttr::get( + context, qTowersAttr, pTowersAttr, 4, 4, dnum, 7, 1); + + module->setAttr("jaxiteword.ckks_params", ckksParamsAttr); + } +}; + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h new file mode 100644 index 0000000000..471662d5bd --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxiteword { + +#define GEN_PASS_DECL_JAXITECKKSPARAMETERSELECTION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.h b/lib/Dialect/JaxiteWord/Transforms/Passes.h new file mode 100644 index 0000000000..067fc1d665 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.h @@ -0,0 +1,18 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" + +namespace mlir { +namespace heir { +namespace jaxiteword { + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td new file mode 100644 index 0000000000..e9a8fb08d5 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection", "mlir::ModuleOp"> { + let summary = "Selects parameters for Jaxite CKKS backend"; + let description = [{ + This pass selects parameters for the Jaxite CKKS backend and annotates them on the module. + + (* example filepath=tests/Dialect/JaxiteWord/Transforms/doctest.mlir *) + }]; + let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect", "mlir::heir::ckks::CKKSDialect"]; +} + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ diff --git a/tests/Dialect/JaxiteWord/IR/attr_test.mlir b/tests/Dialect/JaxiteWord/IR/attr_test.mlir new file mode 100644 index 0000000000..ce6840f24b --- /dev/null +++ b/tests/Dialect/JaxiteWord/IR/attr_test.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt %s | FileCheck %s + +// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters} +module attributes { + jaxiteword.ckks_params = #jaxiteword.ckks_parameters< + q_towers = [1, 2], + p_towers = [3], + r = 4, + c = 5, + dnum = 6, + composite_degree = 7, + batch = 8 + > +} { +} diff --git a/tests/Dialect/JaxiteWord/Transforms/BUILD b/tests/Dialect/JaxiteWord/Transforms/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir new file mode 100644 index 0000000000..ee7b2ba0f5 --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir @@ -0,0 +1,19 @@ +// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s + +// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}> +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> + +module attributes { + ckks.schemeParam = #ckks.scheme_param< + logN = 13, + Q = [36028797018652673], + P = [1152921504606994433], + logDefaultScale = 45 + > +} { + func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct + } +} diff --git a/tests/Dialect/JaxiteWord/Transforms/large_test.mlir b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir new file mode 100644 index 0000000000..827e36fa0e --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s + +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> + +// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}> +module attributes { + ckks.schemeParam = #ckks.scheme_param< + logN = 13, + Q = [ + 7896856388305998031, 8335717806483771817, 7621929371556188363, 8941345776919444657, + 7943813361973406531, 7742501181933711653, 7673257225347932497, 7210067971330841557, + 8234891178228564671, 7847526270039855001, 8245181310374330081, 8960862465870304837, + 8718902402328186751, 9031509869954283143, 7789630786405883791, 8945030373143909771, + 7258099451375055763, 8999881575504424663, 9020740517063589967, 7906610589161779643, + 7256670403940451583, 7215881909751066997, 7261482118667644289, 6918930965025587023, + 7552875336759771971, 7264322706790679029, 7035727842643806041, 8663275797836175071, + 7348375621176293489, 8101412547026401381 + ], + P = [ + 8046990677865391223, 8262056840302532089, 7520591891579404973, 8469636204033924593, + 7515061052621148421, 8671733300942445233, 9061065578563297193, 8446495666365292607, + 8329800933433096669, 7565030516258039723 + ], + logDefaultScale = 45 + > +} { + func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct + } +} diff --git a/tools/BUILD b/tools/BUILD index 26c7ec3076..6d457dafb0 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -63,6 +63,7 @@ cc_binary( "@heir//lib/Dialect/Debug/Transforms", "@heir//lib/Dialect/Jaxite/IR:Dialect", "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/Transforms", "@heir//lib/Dialect/KeyMgmt/IR:Dialect", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 9327f04f44..c4f76b36ee 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -22,6 +22,7 @@ #include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h" #include "lib/Dialect/KeyMgmt/IR/KeyMgmtDialect.h" #include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" @@ -297,6 +298,7 @@ int main(int argc, char** argv) { cggi::registerCGGIPasses(); debug::registerDebugPasses(); ckks::registerCKKSPasses(); + jaxiteword::registerJaxiteWordPasses(); lattigo::registerLattigoPasses(); lwe::registerLWEPasses(); mgmt::registerMgmtPasses();