Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ __pycache__/

# lockfile is updated by automation
MODULE.bazel.lock

.jetskicli
13 changes: 13 additions & 0 deletions lib/Dialect/JaxiteWord/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ cc_library(
name = "Dialect",
srcs = ["JaxiteWordDialect.cpp"],
hdrs = [
"JaxiteWordAttributes.h",
"JaxiteWordDialect.h",
"JaxiteWordOps.h",
"JaxiteWordTypes.h",
],
deps = [
":attributes_inc_gen",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
Expand All @@ -32,6 +34,7 @@ cc_library(
td_library(
name = "td_files",
srcs = [
"JaxiteWordAttributes.td",
"JaxiteWordDialect.td",
"JaxiteWordOps.td",
"JaxiteWordTypes.td",
Expand Down Expand Up @@ -63,6 +66,16 @@ add_heir_dialect_library(
],
)

add_heir_dialect_library(
name = "attributes_inc_gen",
dialect = "JaxiteWord",
kind = "attribute",
td_file = "JaxiteWordAttributes.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "ops_inc_gen",
dialect = "JaxiteWord",
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h.inc"

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_H_
32 changes: 32 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_

include "JaxiteWordDialect.td"

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

class JaxiteWord_Attribute<string attrName, string attrMnemonic>
: AttrDef<JaxiteWord_Dialect, attrName> {
let mnemonic = attrMnemonic;
let assemblyFormat = "`<` struct(params) `>`";
}

def JaxiteWord_CkksParameters : JaxiteWord_Attribute<"CkksParameters", "ckks_parameters"> {
let summary = "Jaxite CKKS parameters";
let description = [{
Parameters for Jaxite CKKS backend.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JianmingTONG Could you provide a quick description of the parameters that would be suitable to include in the description field here? Particularly, r, c, dnum, composite_degree, batch.

}];

let parameters = (ins
"DenseI64ArrayAttr":$q_towers,
"DenseI64ArrayAttr":$p_towers,
"int":$r,
"int":$c,
"int":$dnum,
"int":$composite_degree,
"int":$batch
);
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDATTRIBUTES_TD_
7 changes: 7 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h"
Expand All @@ -8,6 +9,8 @@
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
#define GET_OP_CLASSES
Expand All @@ -18,6 +21,10 @@ namespace heir {
namespace jaxiteword {

void JaxiteWordDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def JaxiteWord_Dialect : Dialect {
let cppNamespace = "::mlir::heir::jaxiteword";

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
39 changes: 39 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":JaxiteCkksParameterSelection",
":pass_inc_gen",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "JaxiteCkksParameterSelection",
srcs = ["JaxiteCkksParameterSelection.cpp"],
hdrs = ["JaxiteCkksParameterSelection.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/CKKS/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@heir//lib/Parameters:RLWEParams",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

add_heir_transforms(
header_filename = "Passes.h.inc",
pass_name = "JaxiteWord",
td_file = "Passes.td",
)
85 changes: 85 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h"

#include "lib/Dialect/CKKS/IR/CKKSAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordAttributes.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h"
#include "lib/Parameters/RLWEParams.h"
#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project
#include "llvm/include/llvm/Support/MathExtras.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_DEF_JAXITECKKSPARAMETERSELECTION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

struct JaxiteCkksParameterSelection
: impl::JaxiteCkksParameterSelectionBase<JaxiteCkksParameterSelection> {
using JaxiteCkksParameterSelectionBase::JaxiteCkksParameterSelectionBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();

auto schemeParamAttr = module->getAttrOfType<ckks::SchemeParamAttr>(
ckks::CKKSDialect::kSchemeParamAttrName);
if (!schemeParamAttr) {
module->emitOpError() << "Missing ckks.schemeParam attribute";
signalPassFailure();
return;
}

int logN = schemeParamAttr.getLogN();
int ringDim = 1 << logN;

auto Q = schemeParamAttr.getQ().asArrayRef();
auto P = schemeParamAttr.getP().asArrayRef();

int totalBitsQ = 0;
for (auto q : Q) {
totalBitsQ += llvm::APInt(64, q).getActiveBits();
}

int totalBitsP = 0;
for (auto p : P) {
totalBitsP += llvm::APInt(64, p).getActiveBits();
}

std::vector<int64_t> existingPrimes;
std::vector<int64_t> qTowers;
std::vector<int64_t> pTowers;

int bitsGeneratedQ = 0;
while (bitsGeneratedQ < totalBitsQ) {
int64_t prime = findPrime(30, ringDim, existingPrimes);
qTowers.push_back(prime);
existingPrimes.push_back(prime);
bitsGeneratedQ += 30;
}

int bitsGeneratedP = 0;
while (bitsGeneratedP < totalBitsP) {
int64_t prime = findPrime(30, ringDim, existingPrimes);
pTowers.push_back(prime);
existingPrimes.push_back(prime);
bitsGeneratedP += 30;
}

auto qTowersAttr = DenseI64ArrayAttr::get(context, qTowers);
auto pTowersAttr = DenseI64ArrayAttr::get(context, pTowers);

int dnum = computeDnum(Q.size() - 1);

// FIXME: Replace dummy values for r, c, composite_degree, and batch.
auto ckksParamsAttr = CkksParametersAttr::get(
context, qTowersAttr, pTowersAttr, 4, 5, dnum, 7, 8);

module->setAttr("jaxiteword.ckks_params", ckksParamsAttr);
}
};

} // namespace jaxiteword
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_DECL_JAXITECKKSPARAMETERSELECTION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

} // namespace jaxiteword
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_JAXITECKKSPARAMETERSELECTION_H_
18 changes: 18 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"
#include "lib/Dialect/JaxiteWord/Transforms/JaxiteCkksParameterSelection.h"

namespace mlir {
namespace heir {
namespace jaxiteword {

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/JaxiteWord/Transforms/Passes.h.inc"

} // namespace jaxiteword
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_H_
16 changes: 16 additions & 0 deletions lib/Dialect/JaxiteWord/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_
#define LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def JaxiteCkksParameterSelection : Pass<"jaxite-ckks-parameter-selection", "mlir::ModuleOp"> {
let summary = "Selects parameters for Jaxite CKKS backend";
let description = [{
This pass selects parameters for the Jaxite CKKS backend and annotates them on the module.

(* example filepath=tests/Dialect/JaxiteWord/Transforms/doctest.mlir *)
}];
let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect", "mlir::heir::ckks::CKKSDialect"];
}

#endif // LIB_DIALECT_JAXITEWORD_TRANSFORMS_PASSES_TD_
15 changes: 15 additions & 0 deletions tests/Dialect/JaxiteWord/IR/attr_test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: heir-opt %s | FileCheck %s

// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<q_towers = [1, 2], p_towers = [3], r = 4, c = 5, dnum = 6, composite_degree = 7, batch = 8>}
module attributes {
jaxiteword.ckks_params = #jaxiteword.ckks_parameters<
q_towers = [1, 2],
p_towers = [3],
r = 4,
c = 5,
dnum = 6,
composite_degree = 7,
batch = 8
>
} {
}
10 changes: 10 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
19 changes: 19 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/doctest.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s

// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>}
!ct = !jaxiteword.ciphertext<2, 3, 4>
!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417>

module attributes {
ckks.schemeParam = #ckks.scheme_param<
logN = 13,
Q = [36028797018652673],
P = [1152921504606994433],
logDefaultScale = 45
>
} {
func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct {
%out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct
return %out : !ct
}
}
32 changes: 32 additions & 0 deletions tests/Dialect/JaxiteWord/Transforms/large_test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: heir-opt --jaxite-ckks-parameter-selection %s | FileCheck %s

!ct = !jaxiteword.ciphertext<2, 3, 4>
!ml = !jaxiteword.modulus_list<65536, 1152921504606844513, 1152921504606844417>

// CHECK: module attributes {jaxiteword.ckks_params = #jaxiteword.ckks_parameters<{{.*}}>}
module attributes {
ckks.schemeParam = #ckks.scheme_param<
logN = 13,
Q = [
7896856388305998031, 8335717806483771817, 7621929371556188363, 8941345776919444657,
7943813361973406531, 7742501181933711653, 7673257225347932497, 7210067971330841557,
8234891178228564671, 7847526270039855001, 8245181310374330081, 8960862465870304837,
8718902402328186751, 9031509869954283143, 7789630786405883791, 8945030373143909771,
7258099451375055763, 8999881575504424663, 9020740517063589967, 7906610589161779643,
7256670403940451583, 7215881909751066997, 7261482118667644289, 6918930965025587023,
7552875336759771971, 7264322706790679029, 7035727842643806041, 8663275797836175071,
7348375621176293489, 8101412547026401381
],
P = [
8046990677865391223, 8262056840302532089, 7520591891579404973, 8469636204033924593,
7515061052621148421, 8671733300942445233, 9061065578563297193, 8446495666365292607,
8329800933433096669, 7565030516258039723
],
logDefaultScale = 45
>
} {
func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct {
%out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct
return %out : !ct
}
}
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_binary(
"@heir//lib/Dialect/Debug/Transforms",
"@heir//lib/Dialect/Jaxite/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/Transforms",
"@heir//lib/Dialect/KeyMgmt/IR:Dialect",
"@heir//lib/Dialect/LWE/Conversions/LWEToLattigo",
"@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe",
Expand Down
Loading
Loading