From fd512335d77a02054baef7dda6dbf1fde031495a Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:02:22 -0700 Subject: [PATCH 1/8] add templates.py output for JaxiteCkksParameterSelection --- lib/Dialect/JaxiteWord/Transforms/BUILD | 37 +++++++++++++++++++ .../JaxiteCkksParameterSelection.cpp | 32 ++++++++++++++++ .../Transforms/JaxiteCkksParameterSelection.h | 17 +++++++++ lib/Dialect/JaxiteWord/Transforms/Passes.h | 18 +++++++++ lib/Dialect/JaxiteWord/Transforms/Passes.td | 14 +++++++ 5 files changed, 118 insertions(+) create mode 100644 lib/Dialect/JaxiteWord/Transforms/BUILD create mode 100644 lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp create mode 100644 lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h create mode 100644 lib/Dialect/JaxiteWord/Transforms/Passes.h create mode 100644 lib/Dialect/JaxiteWord/Transforms/Passes.td diff --git a/lib/Dialect/JaxiteWord/Transforms/BUILD b/lib/Dialect/JaxiteWord/Transforms/BUILD new file mode 100644 index 0000000000..c94d9c40b3 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/BUILD @@ -0,0 +1,37 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Transforms", + hdrs = ["Passes.h"], + deps = [ + ":JaxiteCkksParameterSelection", + ":pass_inc_gen", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "JaxiteCkksParameterSelection", + srcs = ["JaxiteCkksParameterSelection.cpp"], + hdrs = ["JaxiteCkksParameterSelection.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +add_heir_transforms( + header_filename = "Passes.h.inc", + pass_name = "JaxiteWord", + td_file = "Passes.td", +) diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp new file mode 100644 index 0000000000..67f7157deb --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp @@ -0,0 +1,32 @@ +#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxite_word { + +#define GEN_PASS_DEF_JAXITECKKSPARAMETERSELECTION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +struct JaxiteCkksParameterSelection + : impl::JaxiteCkksParameterSelectionBase { + using JaxiteCkksParameterSelectionBase::JaxiteCkksParameterSelectionBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // FIXME: implement pass + patterns.add<>(context); + + // TODO (#1221): Investigate whether folding (default: on) can be skipped + // here. + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace jaxite_word +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h new file mode 100644 index 0000000000..1c643fff9d --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxite_word { + +#define GEN_PASS_DECL_JAXITECKKSPARAMETERSELECTION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +} // namespace jaxite_word +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_ diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.h b/lib/Dialect/JaxiteWord/Transforms/Passes.h new file mode 100644 index 0000000000..1aae90a773 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.h @@ -0,0 +1,18 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" + +namespace mlir { +namespace heir { +namespace jaxite_word { + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" + +} // namespace jaxite_word +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td new file mode 100644 index 0000000000..52e43726f2 --- /dev/null +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -0,0 +1,14 @@ +#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ +#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection"> { + // FIXME: add add summary/description + let summary = ""; + let description = [{ + }]; + let dependentDialects = ["mlir::heir::jaxite_word::JaxiteWordDialect"]; +} + +#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ From 977ea13a27e0f701215ccafba2a7a9cef2d78649 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:18:47 -0700 Subject: [PATCH 2/8] update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 58f129a20a..31b4c95304 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,5 @@ __pycache__/ # lockfile is updated by automation MODULE.bazel.lock + +.jetskicli From 983bc41b965c829ecc7e01e30f0eb7f205886268 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:20:17 -0700 Subject: [PATCH 3/8] add paramters attribute --- lib/Dialect/JaxiteWord/IR/BUILD | 13 ++++++++ .../JaxiteWord/IR/JaxiteWordAttributes.h | 9 ++++++ .../JaxiteWord/IR/JaxiteWordAttributes.td | 32 +++++++++++++++++++ .../JaxiteWord/IR/JaxiteWordDialect.cpp | 7 ++++ .../JaxiteWord/IR/JaxiteWordDialect.td | 1 + tests/Dialect/JaxiteWord/IR/attr_test.mlir | 15 +++++++++ 6 files changed, 77 insertions(+) create mode 100644 lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h create mode 100644 lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td create mode 100644 tests/Dialect/JaxiteWord/IR/attr_test.mlir diff --git a/lib/Dialect/JaxiteWord/IR/BUILD b/lib/Dialect/JaxiteWord/IR/BUILD index 5daaeb3996..3043cfba34 100644 --- a/lib/Dialect/JaxiteWord/IR/BUILD +++ b/lib/Dialect/JaxiteWord/IR/BUILD @@ -13,11 +13,13 @@ cc_library( name = "Dialect", srcs = ["JaxiteWordDialect.cpp"], hdrs = [ + "JaxiteWordAttributes.h", "JaxiteWordDialect.h", "JaxiteWordOps.h", "JaxiteWordTypes.h", ], deps = [ + ":attributes_inc_gen", ":dialect_inc_gen", ":ops_inc_gen", ":types_inc_gen", @@ -32,6 +34,7 @@ cc_library( td_library( name = "td_files", srcs = [ + "JaxiteWordAttributes.td", "JaxiteWordDialect.td", "JaxiteWordOps.td", "JaxiteWordTypes.td", @@ -63,6 +66,16 @@ add_heir_dialect_library( ], ) +add_heir_dialect_library( + name = "attributes_inc_gen", + dialect = "JaxiteWord", + kind = "attribute", + td_file = "JaxiteWordAttributes.td", + deps = [ + ":td_files", + ], +) + add_heir_dialect_library( name = "ops_inc_gen", dialect = "JaxiteWord", diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h new file mode 100644 index 0000000000..fdf0223041 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h @@ -0,0 +1,9 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" + +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h.inc" + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td new file mode 100644 index 0000000000..fed739b97f --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td @@ -0,0 +1,32 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ + +include "JaxiteWordDialect.td" + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +class JaxiteWord_Attribute + : AttrDef { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def JaxiteWord_CkksParameters : JaxiteWord_Attribute<"CkksParameters", "ckks_parameters"> { + let summary = "Jaxite CKKS parameters"; + let description = [{ + Parameters for Jaxite CKKS backend. + }]; + + let parameters = (ins + "DenseI64ArrayAttr":$q_towers, + "DenseI64ArrayAttr":$p_towers, + "int":$r, + "int":$c, + "int":$dnum, + "int":$composite_degree, + "int":$batch + ); +} + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp index 0da3ab19a0..c0ada2e151 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp @@ -1,5 +1,6 @@ #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" @@ -8,6 +9,8 @@ #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" #define GET_OP_CLASSES @@ -18,6 +21,10 @@ namespace heir { namespace jaxiteword { void JaxiteWordDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc" + >(); addTypes< #define GET_TYPEDEF_LIST #include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td index 5837a68a27..c5719de9ca 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td @@ -17,6 +17,7 @@ def JaxiteWord_Dialect : Dialect { let cppNamespace = "::mlir::heir::jaxiteword"; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } #endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ diff --git a/tests/Dialect/JaxiteWord/IR/attr_test.mlir b/tests/Dialect/JaxiteWord/IR/attr_test.mlir new file mode 100644 index 0000000000..ce6840f24b --- /dev/null +++ b/tests/Dialect/JaxiteWord/IR/attr_test.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt %s | FileCheck %s + +// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters} +module attributes { + jaxiteword.ckks_params = #jaxiteword.ckks_parameters< + q_towers = [1, 2], + p_towers = [3], + r = 4, + c = 5, + dnum = 6, + composite_degree = 7, + batch = 8 + > +} { +} From 0ef6d3e894527e5006769fcc3d2d08a99344e65f Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:26:09 -0700 Subject: [PATCH 4/8] add dummy implementation of the pass --- .../JaxiteCkksParameterSelection.cpp | 23 +++++++++++-------- .../Transforms/JaxiteCkksParameterSelection.h | 4 ++-- lib/Dialect/JaxiteWord/Transforms/Passes.h | 4 ++-- lib/Dialect/JaxiteWord/Transforms/Passes.td | 10 ++++---- .../JaxiteWord/Transforms/doctest.mlir | 10 ++++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 ++ 7 files changed, 37 insertions(+), 17 deletions(-) create mode 100644 tests/Dialect/JaxiteWord/Transforms/doctest.mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp index 67f7157deb..3bdeaf9591 100644 --- a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp @@ -1,11 +1,12 @@ #include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project namespace mlir { namespace heir { -namespace jaxite_word { +namespace jaxiteword { #define GEN_PASS_DEF_JAXITECKKSPARAMETERSELECTION #include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" @@ -16,17 +17,21 @@ struct JaxiteCkksParameterSelection void runOnOperation() override { MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); + ModuleOp module = getOperation(); - // FIXME: implement pass - patterns.add<>(context); + SmallVector qTowers = {1, 2}; + SmallVector pTowers = {3}; - // TODO (#1221): Investigate whether folding (default: on) can be skipped - // here. - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + auto qTowersAttr = DenseI64ArrayAttr::get(context, qTowers); + auto pTowersAttr = DenseI64ArrayAttr::get(context, pTowers); + + auto ckksParamsAttr = CkksParametersAttr::get(context, qTowersAttr, + pTowersAttr, 4, 5, 6, 7, 8); + + module->setAttr("jaxiteword.ckks_params", ckksParamsAttr); } }; -} // namespace jaxite_word +} // namespace jaxiteword } // namespace heir } // namespace mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h index 1c643fff9d..471662d5bd 100644 --- a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h @@ -5,12 +5,12 @@ namespace mlir { namespace heir { -namespace jaxite_word { +namespace jaxiteword { #define GEN_PASS_DECL_JAXITECKKSPARAMETERSELECTION #include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" -} // namespace jaxite_word +} // namespace jaxiteword } // namespace heir } // namespace mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.h b/lib/Dialect/JaxiteWord/Transforms/Passes.h index 1aae90a773..067fc1d665 100644 --- a/lib/Dialect/JaxiteWord/Transforms/Passes.h +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.h @@ -6,12 +6,12 @@ namespace mlir { namespace heir { -namespace jaxite_word { +namespace jaxiteword { #define GEN_PASS_REGISTRATION #include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc" -} // namespace jaxite_word +} // namespace jaxiteword } // namespace heir } // namespace mlir diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td index 52e43726f2..4c1fd9e835 100644 --- a/lib/Dialect/JaxiteWord/Transforms/Passes.td +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -3,12 +3,14 @@ include "mlir/Pass/PassBase.td" -def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection"> { - // FIXME: add add summary/description - let summary = ""; +def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection", "mlir::ModuleOp"> { + let summary = "Selects parameters for Jaxite CKKS backend"; let description = [{ + This pass selects parameters for the Jaxite CKKS backend and annotates them on the module. + + (* example filepath=tests/Dialect/JaxiteWord/Transforms/doctest.mlir *) }]; - let dependentDialects = ["mlir::heir::jaxite_word::JaxiteWordDialect"]; + let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect"]; } #endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ diff --git a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir new file mode 100644 index 0000000000..742d57a636 --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir @@ -0,0 +1,10 @@ +// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s + +// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>} +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> + +func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct +} diff --git a/tools/BUILD b/tools/BUILD index 26c7ec3076..6d457dafb0 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -63,6 +63,7 @@ cc_binary( "@heir//lib/Dialect/Debug/Transforms", "@heir//lib/Dialect/Jaxite/IR:Dialect", "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/Transforms", "@heir//lib/Dialect/KeyMgmt/IR:Dialect", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 9327f04f44..c4f76b36ee 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -22,6 +22,7 @@ #include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/Transforms/Passes.h" #include "lib/Dialect/KeyMgmt/IR/KeyMgmtDialect.h" #include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" @@ -297,6 +298,7 @@ int main(int argc, char** argv) { cggi::registerCGGIPasses(); debug::registerDebugPasses(); ckks::registerCKKSPasses(); + jaxiteword::registerJaxiteWordPasses(); lattigo::registerLattigoPasses(); lwe::registerLWEPasses(); mgmt::registerMgmtPasses(); From 3662527cc0187baf10ffc160e293de33123a8b07 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:38:33 -0700 Subject: [PATCH 5/8] convert 64-bit prime moduli from ckks params to 30-bit analogues --- lib/Dialect/JaxiteWord/Transforms/BUILD | 2 + .../JaxiteCkksParameterSelection.cpp | 58 +++++++++++++++++-- lib/Dialect/JaxiteWord/Transforms/Passes.td | 2 +- .../JaxiteWord/Transforms/doctest.mlir | 15 ++++- 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/JaxiteWord/Transforms/BUILD b/lib/Dialect/JaxiteWord/Transforms/BUILD index c94d9c40b3..ae06c1fb19 100644 --- a/lib/Dialect/JaxiteWord/Transforms/BUILD +++ b/lib/Dialect/JaxiteWord/Transforms/BUILD @@ -23,7 +23,9 @@ cc_library( hdrs = ["JaxiteCkksParameterSelection.h"], deps = [ ":pass_inc_gen", + "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Parameters:RLWEParams", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp index 3bdeaf9591..4c70f17a5c 100644 --- a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp @@ -1,8 +1,12 @@ #include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h" +#include "lib/Dialect/CKKS/IR/CKKSAttributes.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "lib/Parameters/RLWEParams.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project +#include "llvm/include/llvm/Support/MathExtras.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project namespace mlir { namespace heir { @@ -19,14 +23,58 @@ struct JaxiteCkksParameterSelection MLIRContext *context = &getContext(); ModuleOp module = getOperation(); - SmallVector qTowers = {1, 2}; - SmallVector pTowers = {3}; + auto schemeParamAttr = module->getAttrOfType( + ckks::CKKSDialect::kSchemeParamAttrName); + if (!schemeParamAttr) { + module->emitOpError() << "Missing ckks.schemeParam attribute"; + signalPassFailure(); + return; + } + + int logN = schemeParamAttr.getLogN(); + int ringDim = 1 << logN; + + auto Q = schemeParamAttr.getQ().asArrayRef(); + auto P = schemeParamAttr.getP().asArrayRef(); + + int totalBitsQ = 0; + for (auto q : Q) { + totalBitsQ += llvm::APInt(64, q).getActiveBits(); + } + + int totalBitsP = 0; + for (auto p : P) { + totalBitsP += llvm::APInt(64, p).getActiveBits(); + } + + std::vector existingPrimes; + std::vector qTowers; + std::vector pTowers; + + int bitsGeneratedQ = 0; + while (bitsGeneratedQ < totalBitsQ) { + int64_t prime = findPrime(30, ringDim, existingPrimes); + qTowers.push_back(prime); + existingPrimes.push_back(prime); + bitsGeneratedQ += 30; + } + + int bitsGeneratedP = 0; + while (bitsGeneratedP < totalBitsP) { + int64_t prime = findPrime(30, ringDim, existingPrimes); + pTowers.push_back(prime); + existingPrimes.push_back(prime); + bitsGeneratedP += 30; + } auto qTowersAttr = DenseI64ArrayAttr::get(context, qTowers); auto pTowersAttr = DenseI64ArrayAttr::get(context, pTowers); - auto ckksParamsAttr = CkksParametersAttr::get(context, qTowersAttr, - pTowersAttr, 4, 5, 6, 7, 8); + int dnum = computeDnum(Q.size() - 1); + + // FIXME: Replace dummy values for r, c, composite_degree, and batch. + auto ckksParamsAttr = CkksParametersAttr::get( + context, qTowersAttr, pTowersAttr, 4, 5, dnum, 7, 8); module->setAttr("jaxiteword.ckks_params", ckksParamsAttr); } diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td index 4c1fd9e835..e9a8fb08d5 100644 --- a/lib/Dialect/JaxiteWord/Transforms/Passes.td +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -10,7 +10,7 @@ def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection", "mlir (* example filepath=tests/Dialect/JaxiteWord/Transforms/doctest.mlir *) }]; - let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect"]; + let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect", "mlir::heir::ckks::CKKSDialect"]; } #endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_ diff --git a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir index 742d57a636..cdd40430ef 100644 --- a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir @@ -4,7 +4,16 @@ !ct = !jaxiteword.ciphertext<2, 3, 4> !ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> -func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { - %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct - return %out : !ct +module attributes { + ckks.schemeParam = #ckks.scheme_param< + logN = 13, + Q = [36028797018652673], + P = [1152921504606994433], + logDefaultScale = 45 + > +} { + func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct + } } From 66ed3d192e8464140ae4dae2274863a50a2866ba Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 May 2026 15:40:56 -0700 Subject: [PATCH 6/8] add a larger test --- tests/Dialect/JaxiteWord/Transforms/BUILD | 10 ++++++ .../JaxiteWord/Transforms/large_test.mlir | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 tests/Dialect/JaxiteWord/Transforms/BUILD create mode 100644 tests/Dialect/JaxiteWord/Transforms/large_test.mlir diff --git a/tests/Dialect/JaxiteWord/Transforms/BUILD b/tests/Dialect/JaxiteWord/Transforms/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/JaxiteWord/Transforms/large_test.mlir b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir new file mode 100644 index 0000000000..7c175bbb9d --- /dev/null +++ b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s + +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> + +// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>} +module attributes { + ckks.schemeParam = #ckks.scheme_param< + logN = 13, + Q = [ + 7896856388305998031, 8335717806483771817, 7621929371556188363, 8941345776919444657, + 7943813361973406531, 7742501181933711653, 7673257225347932497, 7210067971330841557, + 8234891178228564671, 7847526270039855001, 8245181310374330081, 8960862465870304837, + 8718902402328186751, 9031509869954283143, 7789630786405883791, 8945030373143909771, + 7258099451375055763, 8999881575504424663, 9020740517063589967, 7906610589161779643, + 7256670403940451583, 7215881909751066997, 7261482118667644289, 6918930965025587023, + 7552875336759771971, 7264322706790679029, 7035727842643806041, 8663275797836175071, + 7348375621176293489, 8101412547026401381 + ], + P = [ + 8046990677865391223, 8262056840302532089, 7520591891579404973, 8469636204033924593, + 7515061052621148421, 8671733300942445233, 9061065578563297193, 8446495666365292607, + 8329800933433096669, 7565030516258039723 + ], + logDefaultScale = 45 + > +} { + func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct + } +} From d62c275b543184fc00d5520bf279492d8294f487 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 11 May 2026 16:00:10 -0700 Subject: [PATCH 7/8] update some default values --- .../JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp index 4c70f17a5c..761934718b 100644 --- a/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp +++ b/lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp @@ -72,9 +72,9 @@ struct JaxiteCkksParameterSelection int dnum = computeDnum(Q.size() - 1); - // FIXME: Replace dummy values for r, c, composite_degree, and batch. + // FIXME: Replace dummy value for composite_degree. auto ckksParamsAttr = CkksParametersAttr::get( - context, qTowersAttr, pTowersAttr, 4, 5, dnum, 7, 8); + context, qTowersAttr, pTowersAttr, 4, 4, dnum, 7, 1); module->setAttr("jaxiteword.ckks_params", ckksParamsAttr); } From 56f4e6a3b4bec9fb3d2f7fdf492d80bc2515e07d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 11 May 2026 16:00:44 -0700 Subject: [PATCH 8/8] make test more flexible --- tests/Dialect/JaxiteWord/Transforms/doctest.mlir | 2 +- tests/Dialect/JaxiteWord/Transforms/large_test.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir index cdd40430ef..ee7b2ba0f5 100644 --- a/tests/Dialect/JaxiteWord/Transforms/doctest.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/doctest.mlir @@ -1,6 +1,6 @@ // RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s -// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>} +// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}> !ct = !jaxiteword.ciphertext<2, 3, 4> !ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> diff --git a/tests/Dialect/JaxiteWord/Transforms/large_test.mlir b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir index 7c175bbb9d..827e36fa0e 100644 --- a/tests/Dialect/JaxiteWord/Transforms/large_test.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/large_test.mlir @@ -3,7 +3,7 @@ !ct = !jaxiteword.ciphertext<2, 3, 4> !ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417> -// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>} +// CHECK: jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}> module attributes { ckks.schemeParam = #ckks.scheme_param< logN = 13,