From c6ac31ab584f3c9236a4fec119234f89fc56a7ec Mon Sep 17 00:00:00 2001 From: Marc Desgroseilliers Date: Fri, 29 May 2026 14:35:35 +0200 Subject: [PATCH] Handle dense resources --- lib/Target/Lattigo/LattigoEmitter.cpp | 111 +++++++++++------- .../Lattigo/memref_global_dense_resource.mlir | 44 +++++++ 2 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 tests/Emitter/Lattigo/memref_global_dense_resource.mlir diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index e20757e3dc..b871036a0f 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -38,6 +38,7 @@ #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project #include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/include/mlir/IR/AsmState.h" // from @llvm-project #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -46,15 +47,16 @@ #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/include/mlir/IR/IntegerSet.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectResourceBlobManager.h" // from @llvm-project +#include "mlir/include/mlir/IR/IntegerSet.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project namespace mlir { @@ -1112,10 +1114,13 @@ LogicalResult LattigoEmitter::printOperation(memref::GlobalOp op) { return op.emitError("memref.global must have an initial value"); } - if (auto denseAttr = dyn_cast(initAttr)) { + Type eltType = type.getElementType(); + + // Emits an inline DenseElementsAttr as a Go slice literal, or a slices.Repeat + // call for splats. + auto emitDense = [&](DenseElementsAttr denseAttr) -> LogicalResult { if (denseAttr.isSplat()) { imports.insert(std::string(kSlicesImport)); - auto eltType = type.getElementType(); std::string valStr; if (eltType.isF32() || eltType.isF64()) { FloatAttr splatAttr = denseAttr.getSplatValue(); @@ -1134,42 +1139,64 @@ LogicalResult LattigoEmitter::printOperation(memref::GlobalOp op) { for (int64_t dim : type.getShape()) { size *= dim; } - os << "slices.Repeat([]" << eltTypeStr.value() << "{" << valStr << "}, " << size << ")\n"; - } else { - // Non-splat, print as list - os << "[]" << eltTypeStr.value() << "{"; - bool first = true; - auto eltType = type.getElementType(); - if (eltType.isF32() || eltType.isF64()) { - for (FloatAttr val : denseAttr.getValues()) { - if (!first) os << ", "; - os << val.getValueAsDouble(); - first = false; - } - } else if (eltType.isInteger(1)) { - for (auto attr : denseAttr.getValues()) { - if (!first) os << ", "; - os << (cast(attr).getInt() != 0 ? "true" : "false"); - first = false; - } - } else if (llvm::isa(eltType)) { - for (auto attr : denseAttr.getValues()) { - if (!first) os << ", "; - os << cast(attr).getInt(); - first = false; - } - } else { - return op.emitError("Unsupported element type for dense attribute"); + return success(); + } + + // Non-splat, print as list + os << "[]" << eltTypeStr.value() << "{"; + bool first = true; + if (eltType.isF32() || eltType.isF64()) { + for (FloatAttr val : denseAttr.getValues()) { + if (!first) os << ", "; + os << val.getValueAsDouble(); + first = false; + } + } else if (eltType.isInteger(1)) { + for (auto attr : denseAttr.getValues()) { + if (!first) os << ", "; + os << (cast(attr).getInt() != 0 ? "true" : "false"); + first = false; + } + } else if (llvm::isa(eltType)) { + for (auto attr : denseAttr.getValues()) { + if (!first) os << ", "; + os << cast(attr).getInt(); + first = false; } - os << "}\n"; + } else { + return op.emitError("Unsupported element type for dense attribute"); } - } else { - return op.emitError( - "Only DenseElementsAttr is supported for memref.global"); + os << "}\n"; + return success(); + }; + + if (auto denseAttr = dyn_cast(initAttr)) { + return emitDense(denseAttr); } - return success(); + + // handle resource-backed data + if (auto resAttr = dyn_cast(initAttr)) { + AsmResourceBlob* blob = resAttr.getRawHandle().getBlob(); + if (!blob) { + return op.emitError( + "dense_resource for memref.global has no associated data"); + } + + ArrayRef raw = blob->getData(); + auto tensorTy = RankedTensorType::get(type.getShape(), eltType); + + if (!DenseElementsAttr::isValidRawBuffer(tensorTy, raw)) { + return op.emitError( + "dense_resource blob is not a valid raw buffer for its element type"); + } + return emitDense(DenseElementsAttr::getFromRawBuffer(tensorTy, raw)); + } + + return op.emitError( + "Only DenseElementsAttr and DenseResourceElementsAttr are supported " + "for memref.global"); } LogicalResult LattigoEmitter::printOperation(memref::GetGlobalOp op) { diff --git a/tests/Emitter/Lattigo/memref_global_dense_resource.mlir b/tests/Emitter/Lattigo/memref_global_dense_resource.mlir new file mode 100644 index 0000000000..356442a13c --- /dev/null +++ b/tests/Emitter/Lattigo/memref_global_dense_resource.mlir @@ -0,0 +1,44 @@ +// RUN: heir-translate %s --emit-lattigo | FileCheck %s + +// Resource-backed globals (dense_resource<...>) store their element bytes +// out-of-line in the dialect_resources blob. The emitter must read the blob +// and print the elements inline, just like an inline dense<...> attribute. + +module attributes {backend.lattigo, scheme.ckks} { + +// CHECK: var __constant_1x3xf32 = []float32{-1.697694e-01, -3.044266e-01, -5.597579e-01} +memref.global "private" constant @__constant_1x3xf32 : memref<1x3xf32> = dense_resource {alignment = 64 : i64} + +// CHECK: var __constant_3xi8 = []int8{1, -2, 127} +memref.global "private" constant @__constant_3xi8 : memref<3xi8> = dense_resource + +// CHECK: var __constant_2xi16 = []int16{256, -1} +memref.global "private" constant @__constant_2xi16 : memref<2xi16> = dense_resource + +// CHECK: var __constant_2xi64 = []int64{1, -1} +memref.global "private" constant @__constant_2xi64 : memref<2xi64> = dense_resource + +// i1 elements are stored one byte each and emitted as Go bool literals. +// CHECK: var __constant_3xi1 = []bool{true, false, true} +memref.global "private" constant @__constant_3xi1 : memref<3xi1> = dense_resource + +// CHECK: func test_global() ([]float32) { +func.func @test_global() -> memref<1x3xf32> { + // CHECK: v{{.*}} := __constant_1x3xf32 + %global = memref.get_global @__constant_1x3xf32 : memref<1x3xf32> + return %global : memref<1x3xf32> +} + +} + +{-# + dialect_resources: { + builtin: { + torch_tensor_1_3_torch.float32: "0x0400000009D82DBECCDD9BBE4B4C0FBF", + blob_i8: "0x0100000001FE7F", + blob_i16: "0x020000000001FFFF", + blob_i64: "0x080000000100000000000000FFFFFFFFFFFFFFFF", + blob_i1: "0x01000000010001" + } + } +#-}