diff --git a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp index e6cb53acf3..088d7afd06 100644 --- a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp @@ -669,9 +669,13 @@ struct ConvertOrionChebyshevOp LLVM_DEBUG(llvm::dbgs() << "Using default scale: " << defaultScale.getInt() << "\n"); + auto domainAttr = + rewriter.getDenseF64ArrayAttr({op.getDomainStart().getValueAsDouble(), + op.getDomainEnd().getValueAsDouble()}); auto chebyshevOp = lattigo::CKKSChebyshevOp::create( rewriter, op.getLoc(), adaptor.getInput().getType(), polyEvaluator, - adaptor.getInput(), adaptor.getCoefficients(), defaultScale); + adaptor.getInput(), adaptor.getCoefficients(), defaultScale, + domainAttr); rewriter.replaceOp(op, chebyshevOp.getResult()); return success(); diff --git a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td index 015cbb8547..0015d31388 100644 --- a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td @@ -426,7 +426,8 @@ def Lattigo_CKKSChebyshevOp : Lattigo_CKKSOp<"chebyshev", [AllTypesMatch<["ciphe Lattigo_CKKSPolynomialEvaluator:$evaluator, Lattigo_RLWECiphertext:$ciphertext, ArrayAttr:$coefficients, - Builtin_IntegerAttr:$targetScale + Builtin_IntegerAttr:$targetScale, + OptionalAttr]>>:$domain ); let results = (outs Lattigo_RLWECiphertext:$output); } diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index e20757e3dc..8c93455105 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -2113,10 +2113,19 @@ LogicalResult LattigoEmitter::printOperation(CKKSChebyshevOp op) { os.unindent(); os << "}\n"; std::string bignumPoly = getName(op.getOutput()) + "_bignumPoly"; + std::string intervalArg = "nil"; // indicates "default" to lattigo + if (DenseF64ArrayAttr domainAttr = op.getDomainAttr()) { + ArrayRef domain = domainAttr.asArrayRef(); + std::string intervalName = getName(op.getOutput()) + "_interval"; + std::ostringstream startStream, endStream; + startStream << std::scientific << domain[0]; + endStream << std::scientific << domain[1]; + os << intervalName << " := [2]float64{" << startStream.str() << ", " + << endStream.str() << "}\n"; + intervalArg = intervalName; + } os << bignumPoly << " := bignum.NewPolynomial(bignum.Chebyshev, " - << polyCoeffs - << ", " - "nil)\n"; + << polyCoeffs << ", " << intervalArg << ")\n"; std::string resultName = getName(op.getOutput()); os << resultName << ", " << errName << " := "; os << getName(op.getEvaluator()) << ".Evaluate("; diff --git a/tests/Emitter/Lattigo/chebyshev.mlir b/tests/Emitter/Lattigo/chebyshev.mlir new file mode 100644 index 0000000000..a7f0f6d6a8 --- /dev/null +++ b/tests/Emitter/Lattigo/chebyshev.mlir @@ -0,0 +1,29 @@ +// RUN: heir-translate %s --emit-lattigo | FileCheck %s + +!evaluator = !lattigo.ckks.evaluator +!params = !lattigo.ckks.parameter +!eval = !lattigo.ckks.polynomial_evaluator +!ct = !lattigo.rlwe.ciphertext + +module attributes {scheme.ckks} { + // CHECK: func chebyshev_custom_domain + // CHECK: [[out:ct[0-9]+]]_polyCoeffs := []*big.Float{ + // CHECK: [[out]]_interval := [2]float64{-2.000000e+00, 2.000000e+00} + // CHECK: [[out]]_bignumPoly := bignum.NewPolynomial(bignum.Chebyshev, [[out]]_polyCoeffs, [[out]]_interval) + // CHECK: [[out]], {{.*}} := {{.*}}.Evaluate( + func.func @chebyshev_custom_domain(%params: !params, %evaluator: !evaluator, %ct: !ct) -> !ct { + %eval = lattigo.ckks.new_polynomial_evaluator %params, %evaluator : (!params, !evaluator) -> !eval + %0 = lattigo.ckks.chebyshev %eval, %ct {coefficients = [1.0, 0.5], targetScale = 1073741824, domain = array} : (!eval, !ct) -> !ct + return %0 : !ct + } + + // An unset domain passes nil, without emitting an interval. + // CHECK: func chebyshev_unset_domain + // CHECK-NOT: _interval := + // CHECK: [[out:ct[0-9]+]]_bignumPoly := bignum.NewPolynomial(bignum.Chebyshev, [[out]]_polyCoeffs, nil) + func.func @chebyshev_unset_domain(%params: !params, %evaluator: !evaluator, %ct: !ct) -> !ct { + %eval = lattigo.ckks.new_polynomial_evaluator %params, %evaluator : (!params, !evaluator) -> !eval + %0 = lattigo.ckks.chebyshev %eval, %ct {coefficients = [1.0, 0.5], targetScale = 1073741824} : (!eval, !ct) -> !ct + return %0 : !ct + } +}