Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 69 additions & 42 deletions lib/Target/Lattigo/LattigoEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project
#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/include/mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand All @@ -46,15 +47,16 @@
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/IntegerSet.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectResourceBlobManager.h" // from @llvm-project
#include "mlir/include/mlir/IR/IntegerSet.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project

namespace mlir {
Expand Down Expand Up @@ -1112,10 +1114,13 @@ LogicalResult LattigoEmitter::printOperation(memref::GlobalOp op) {
return op.emitError("memref.global must have an initial value");
}

if (auto denseAttr = dyn_cast<DenseElementsAttr>(initAttr)) {
Type eltType = type.getElementType();

// Emits an inline DenseElementsAttr as a Go slice literal, or a slices.Repeat
// call for splats.
auto emitDense = [&](DenseElementsAttr denseAttr) -> LogicalResult {
if (denseAttr.isSplat()) {
imports.insert(std::string(kSlicesImport));
auto eltType = type.getElementType();
std::string valStr;
if (eltType.isF32() || eltType.isF64()) {
FloatAttr splatAttr = denseAttr.getSplatValue<FloatAttr>();
Expand All @@ -1134,42 +1139,64 @@ LogicalResult LattigoEmitter::printOperation(memref::GlobalOp op) {
for (int64_t dim : type.getShape()) {
size *= dim;
}

os << "slices.Repeat([]" << eltTypeStr.value() << "{" << valStr << "}, "
<< size << ")\n";
} else {
// Non-splat, print as list
os << "[]" << eltTypeStr.value() << "{";
bool first = true;
auto eltType = type.getElementType();
if (eltType.isF32() || eltType.isF64()) {
for (FloatAttr val : denseAttr.getValues<FloatAttr>()) {
if (!first) os << ", ";
os << val.getValueAsDouble();
first = false;
}
} else if (eltType.isInteger(1)) {
for (auto attr : denseAttr.getValues<Attribute>()) {
if (!first) os << ", ";
os << (cast<IntegerAttr>(attr).getInt() != 0 ? "true" : "false");
first = false;
}
} else if (llvm::isa<IntegerType>(eltType)) {
for (auto attr : denseAttr.getValues<Attribute>()) {
if (!first) os << ", ";
os << cast<IntegerAttr>(attr).getInt();
first = false;
}
} else {
return op.emitError("Unsupported element type for dense attribute");
return success();
}

// Non-splat, print as list
os << "[]" << eltTypeStr.value() << "{";
bool first = true;
if (eltType.isF32() || eltType.isF64()) {
for (FloatAttr val : denseAttr.getValues<FloatAttr>()) {
if (!first) os << ", ";
os << val.getValueAsDouble();
first = false;
}
} else if (eltType.isInteger(1)) {
for (auto attr : denseAttr.getValues<Attribute>()) {
if (!first) os << ", ";
os << (cast<IntegerAttr>(attr).getInt() != 0 ? "true" : "false");
first = false;
}
} else if (llvm::isa<IntegerType>(eltType)) {
for (auto attr : denseAttr.getValues<Attribute>()) {
if (!first) os << ", ";
os << cast<IntegerAttr>(attr).getInt();
first = false;
}
os << "}\n";
} else {
return op.emitError("Unsupported element type for dense attribute");
}
} else {
return op.emitError(
"Only DenseElementsAttr is supported for memref.global");
os << "}\n";
return success();
};

if (auto denseAttr = dyn_cast<DenseElementsAttr>(initAttr)) {
return emitDense(denseAttr);
}
return success();

// handle resource-backed data
if (auto resAttr = dyn_cast<DenseResourceElementsAttr>(initAttr)) {
AsmResourceBlob* blob = resAttr.getRawHandle().getBlob();
if (!blob) {
return op.emitError(
"dense_resource for memref.global has no associated data");
}

ArrayRef<char> raw = blob->getData();
auto tensorTy = RankedTensorType::get(type.getShape(), eltType);

if (!DenseElementsAttr::isValidRawBuffer(tensorTy, raw)) {
return op.emitError(
"dense_resource blob is not a valid raw buffer for its element type");
}
return emitDense(DenseElementsAttr::getFromRawBuffer(tensorTy, raw));
}

return op.emitError(
"Only DenseElementsAttr and DenseResourceElementsAttr are supported "
"for memref.global");
}

LogicalResult LattigoEmitter::printOperation(memref::GetGlobalOp op) {
Expand Down
44 changes: 44 additions & 0 deletions tests/Emitter/Lattigo/memref_global_dense_resource.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: heir-translate %s --emit-lattigo | FileCheck %s

// Resource-backed globals (dense_resource<...>) store their element bytes
// out-of-line in the dialect_resources blob. The emitter must read the blob
// and print the elements inline, just like an inline dense<...> attribute.

module attributes {backend.lattigo, scheme.ckks} {

// CHECK: var __constant_1x3xf32 = []float32{-1.697694e-01, -3.044266e-01, -5.597579e-01}
memref.global "private" constant @__constant_1x3xf32 : memref<1x3xf32> = dense_resource<torch_tensor_1_3_torch.float32> {alignment = 64 : i64}

// CHECK: var __constant_3xi8 = []int8{1, -2, 127}
memref.global "private" constant @__constant_3xi8 : memref<3xi8> = dense_resource<blob_i8>

// CHECK: var __constant_2xi16 = []int16{256, -1}
memref.global "private" constant @__constant_2xi16 : memref<2xi16> = dense_resource<blob_i16>

// CHECK: var __constant_2xi64 = []int64{1, -1}
memref.global "private" constant @__constant_2xi64 : memref<2xi64> = dense_resource<blob_i64>

// i1 elements are stored one byte each and emitted as Go bool literals.
// CHECK: var __constant_3xi1 = []bool{true, false, true}
memref.global "private" constant @__constant_3xi1 : memref<3xi1> = dense_resource<blob_i1>

// CHECK: func test_global() ([]float32) {
func.func @test_global() -> memref<1x3xf32> {
// CHECK: v{{.*}} := __constant_1x3xf32
%global = memref.get_global @__constant_1x3xf32 : memref<1x3xf32>
return %global : memref<1x3xf32>
}

}

{-#
dialect_resources: {
builtin: {
torch_tensor_1_3_torch.float32: "0x0400000009D82DBECCDD9BBE4B4C0FBF",
blob_i8: "0x0100000001FE7F",
blob_i16: "0x020000000001FFFF",
blob_i64: "0x080000000100000000000000FFFFFFFFFFFFFFFF",
blob_i1: "0x01000000010001"
}
}
#-}
Loading