Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions lib/Dialect/Cheddar/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Cheddar dialect implementation

load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library")
load("@llvm-project//mlir:tblgen.bzl", "td_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = [
"CheddarDialect.cpp",
],
hdrs = [
"CheddarDialect.h",
"CheddarOps.h",
"CheddarTypes.h",
],
deps = [
":CheddarOps",
":CheddarTypes",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
],
)

cc_library(
name = "CheddarTypes",
srcs = [
"CheddarTypes.cpp",
],
hdrs = [
"CheddarDialect.h",
"CheddarTypes.h",
],
deps = [
":dialect_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "CheddarOps",
srcs = [
"CheddarOps.cpp",
],
hdrs = [
"CheddarDialect.h",
"CheddarOps.h",
"CheddarTypes.h",
],
deps = [
":CheddarTypes",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Utils",
"@heir//lib/Utils:RotationUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
],
)

td_library(
name = "td_files",
srcs = [
"CheddarDialect.td",
"CheddarOps.td",
"CheddarTypes.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Dialect:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)

add_heir_dialect_library(
name = "dialect_inc_gen",
dialect = "Cheddar",
kind = "dialect",
td_file = "CheddarDialect.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "types_inc_gen",
dialect = "Cheddar",
kind = "type",
td_file = "CheddarTypes.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "ops_inc_gen",
dialect = "Cheddar",
kind = "op",
td_file = "CheddarOps.td",
deps = [
":td_files",
"@heir//lib/Dialect:td_files",
],
)
39 changes: 39 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "lib/Dialect/Cheddar/IR/CheddarDialect.h"

#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define CheddarOps

#include "lib/Dialect/Cheddar/IR/CheddarOps.h"
#include "lib/Dialect/Cheddar/IR/CheddarTypes.h"

// Generated definitions
#include "lib/Dialect/Cheddar/IR/CheddarDialect.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc"

#define GET_OP_CLASSES
#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc"

namespace mlir {
namespace heir {
namespace cheddar {

void CheddarDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc"
>();

addOperations<
#define GET_OP_LIST
#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc"
>();
}

} // namespace cheddar
} // namespace heir
} // namespace mlir
10 changes: 10 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_
#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "lib/Dialect/Cheddar/IR/CheddarDialect.h.inc"

#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_
24 changes: 24 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_
#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_

include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"

def Cheddar_Dialect : Dialect {
let name = "cheddar";
let description = [{
The `cheddar` dialect is an exit dialect for generating C++ code against the
CHEDDAR GPU FHE library API.

CHEDDAR is a CKKS-only GPU-accelerated FHE library. It supports both 32-bit
and 64-bit word types, with 32-bit being the primary fast path on GPUs.

See [the Cheddar GitHub repository](https://github.com/scale-snu/cheddar-fhe)
}];

let cppNamespace = "::mlir::heir::cheddar";

let useDefaultTypePrinterParser = 1;
}

#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_
46 changes: 46 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "lib/Dialect/Cheddar/IR/CheddarOps.h"

#include "lib/Dialect/Cheddar/IR/CheddarTypes.h"
#include "lib/Utils/RotationUtils.h"
#include "lib/Utils/Utils.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace cheddar {

::llvm::SmallVector<::mlir::OpFoldResult> HRotOp::getRotationIndices() {
if (getStaticDistance()) return {getStaticDistanceAttr()};
return {getDynamicDistance()};
}

LogicalResult HRotOp::verify() {
return containsExactlyOneOrEmitError(getOperation(), getDynamicDistance(),
getStaticDistance());
}

::llvm::SmallVector<::mlir::OpFoldResult> HRotAddOp::getRotationIndices() {
return {getDistanceAttr()};
}

::llvm::SmallVector<::mlir::OpFoldResult>
LinearTransformOp::getRotationIndices() {
auto diagonalsType = cast<RankedTensorType>(getDiagonals().getType());
int64_t slots = diagonalsType.getShape()[1];
int64_t logBSGS = getLogBabyStepGiantStepRatio().getInt();
auto rotations = lintransRotationIndices(
getDiagonalIndicesAttr().asArrayRef(), slots, logBSGS);
SmallVector<OpFoldResult> result;
result.reserve(rotations.size());
auto* mlirCtx = (*this)->getContext();
for (int64_t rot : rotations) {
result.push_back(IntegerAttr::get(IndexType::get(mlirCtx), rot));
}
return result;
}

} // namespace cheddar
} // namespace heir
} // namespace mlir
15 changes: 15 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_
#define LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_

// IWYU pragma: begin_keep
#include "lib/Dialect/Cheddar/IR/CheddarDialect.h"
#include "lib/Dialect/Cheddar/IR/CheddarTypes.h"
#include "lib/Dialect/HEIRInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
// IWYU pragma: end_keep

#define GET_OP_CLASSES
#include "lib/Dialect/Cheddar/IR/CheddarOps.h.inc"

#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_
Loading
Loading