Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,91 @@ struct RewriteTransposedMatvec
}
};

/// Rewrites a `linalg.pooling_ncw_sum` as a `linalg.conv_1d_ncw_fcw` whose
/// kernel is a constant of shape (C, C, W), where C is the input channel count
/// and W is the pooling window size.
///
/// The kernel is zero everywhere except on the channel diagonal (f == c), so
/// every channel is reduced independently of the others. If the only user of
/// the pooling result is an `arith.divf` that divides it by a splat constant
/// (the usual lowering of an average pool), the division is folded into the
/// kernel by using 1 / divisor instead of 1 for the diagonal entries, and the
/// convolution replaces the division.
///
/// The pattern does not apply (returns failure) when the op is not on ranked
/// tensors or when the channel or window size is dynamic.
struct RewriteAvgPoolAsConv1D
    : public OpRewritePattern<mlir::linalg::PoolingNcwSumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::linalg::PoolingNcwSumOp poolOp,
                                PatternRewriter& rewriter) const override {
    // The rewrite builds a tensor constant and a tensor-typed convolution, so
    // it is only meaningful for the tensor form of the op.
    if (poolOp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(poolOp, "expected one tensor result");
    }
    auto inputTy = dyn_cast<RankedTensorType>(poolOp.getInputs()[0].getType());
    auto filterTy =
        dyn_cast<RankedTensorType>(poolOp.getInputs()[1].getType());
    auto outputTy = dyn_cast<RankedTensorType>(poolOp.getResultTypes()[0]);
    if (!inputTy || !filterTy || !outputTy) {
      return rewriter.notifyMatchFailure(poolOp,
                                         "expected ranked tensor operands");
    }

    // The kernel is materialized as a dense constant, so its shape must be
    // fully static.
    if (inputTy.isDynamicDim(1) || filterTy.isDynamicDim(0)) {
      return rewriter.notifyMatchFailure(
          poolOp, "channel and window sizes must be static");
    }
    const int64_t c = inputTy.getDimSize(1);
    const int64_t w = filterTy.getDimSize(0);
    Type eltTy = filterTy.getElementType();

    // By default the diagonal holds ones, which reproduces the sum pool.
    Attribute avgAttr = rewriter.getOneAttr(eltTy);

    // If there is a constant division following the sum pool, update the
    // kernel of ones to be 1 / divValue. This is a common enough pattern since
    // it represents an average pool.
    arith::DivFOp foldedDivOp;
    if (poolOp->hasOneUse()) {
      OpOperand& use = *poolOp->getUses().begin();
      auto divOp = dyn_cast<arith::DivFOp>(use.getOwner());
      // Division is not commutative: only `pool / constant` may be folded.
      // `constant / pool` must be left alone.
      if (divOp && use.getOperandNumber() == 0) {
        if (auto constantAttr =
                dyn_cast<Attribute>(getAsOpFoldResult(divOp.getRhs()))) {
          auto splatAttr = dyn_cast<SplatElementsAttr>(constantAttr);
          // The reciprocal is stored with the kernel's element type, so the
          // divisor must have that same (floating point) element type.
          if (splatAttr && splatAttr.getElementType() == eltTy) {
            APFloat divValue = splatAttr.getSplatValue<APFloat>();
            // Multiplying by 1 / divValue only matches the division for a
            // finite, non-zero divisor.
            if (divValue.isFiniteNonZero()) {
              APFloat one = APFloat::getOne(divValue.getSemantics());
              avgAttr = rewriter.getFloatAttr(eltTy, one / divValue);
              foldedDivOp = divOp;
            }
          }
        }
      }
    }

    // Build average pooling kernel as a special type of convolution. The kernel
    // computes a window average, so it is a fixed constant
    // (1 / divValue) where f == c and zeros where f != c (so each
    // channel is averaged independently) and strides equal to the pooling
    // sizes. See
    // https://machinelearningmastery.com/pooling-layers-for-convolutional-neural-networks/
    SmallVector<int64_t> kernelShape{c, c, w};
    auto kernelTy = RankedTensorType::get(kernelShape, eltTy);
    SmallVector<Attribute> values(c * c * w, rewriter.getZeroAttr(eltTy));
    for (int64_t ch = 0; ch < c; ++ch) {
      // Row-major offset of kernel[ch][ch][0] in the flattened (C, C, W) data.
      const int64_t windowStart = ch * (c * w) + ch * w;
      for (int64_t wIdx = 0; wIdx < w; ++wIdx) {
        values[windowStart + wIdx] = avgAttr;
      }
    }

    TypedAttr kernelVals = DenseElementsAttr::get(kernelTy, values);
    auto kernel =
        arith::ConstantOp::create(rewriter, poolOp.getLoc(), kernelVals);
    Value conv = linalg::Conv1DNcwFcwOp::create(
                     rewriter, poolOp.getLoc(), outputTy,
                     ValueRange{poolOp.getInputs()[0], kernel},
                     ValueRange{poolOp.getOutputs()[0]}, poolOp.getStrides(),
                     poolOp.getDilations())
                     .getResult(0);

    if (foldedDivOp) {
      // The convolution already includes the division, so it stands in for
      // the division's result. The pooling op's only user was that division,
      // so the pooling op is dead once the division is gone.
      rewriter.replaceOp(foldedDivOp, conv);
      rewriter.eraseOp(poolOp);
    } else {
      rewriter.replaceOp(poolOp, conv);
    }
    return success();
  }
};

struct RewriteAvgPoolAsConv2D
: public OpRewritePattern<mlir::linalg::PoolingNchwSumOp> {
public:
Expand Down Expand Up @@ -740,7 +825,8 @@ struct LinalgCanonicalizations
FoldConstantBroadcast, FoldBroadcastExtractSlice,
LinalgMapToElementwise, LinalgGenericToElementwise,
BroadcastToExpandShape, RewriteTransposedVecmat,
RewriteTransposedMatvec, RewriteAvgPoolAsConv2D>(context);
RewriteTransposedMatvec, RewriteAvgPoolAsConv1D,
RewriteAvgPoolAsConv2D>(context);

// Run pattern matching and conversion
// TODO (#1221): Investigate whether folding (default: on) can be skipped
Expand Down
1 change: 0 additions & 1 deletion lib/Utils/Layout/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ presburger::IntegerRelation get2dConvRowInterchangeRelation(int64_t c,
// h' = hi * g + (ci % g**2) // g
// w' = wi * g + (ci % g)
// 3. Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
// FIXME why are these interchanged???
int64_t hOut = h * g;
int64_t wOut = w * g;
int64_t cOut = c / (g * g);
Expand Down
19 changes: 18 additions & 1 deletion lib/Utils/Layout/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,9 @@ presburger::IntegerRelation collapseDimensions(
if (associationGroup.size() == 1) {
continue;
}
for (int64_t reassocDim : associationGroup) {
// Iterate starting from the largest index so that earlier deletions do not
// impact later indices.
for (int64_t reassocDim : llvm::reverse(associationGroup)) {
if (sourceType.getShape()[reassocDim] == 1) {
// Drop this unit dimension
clonedRelation->setAndEliminate(reassocDim, 0);
Expand All @@ -596,6 +598,17 @@ presburger::IntegerRelation expandDimensions(
// dimension we're adding is in the correct index of the integer relations
// domain variable list.
std::unique_ptr<IntegerRelation> clonedRelation = relation.clone();

// Handle the case where reassociation is empty
if (reassociation.empty()) {
for (int64_t i = 0; i < resultType.getRank(); ++i) {
auto newDimIndex = clonedRelation->insertVar(VarKind::Domain, i);
clonedRelation->addBound(BoundType::LB, newDimIndex, 0);
clonedRelation->addBound(BoundType::UB, newDimIndex, 0);
}
return *clonedRelation;
}

int oldDim = 0;
DenseMap<AffineExpr, AffineExpr> oldDimsToNewDims;
for (const ReassociationIndices& associationGroup : reassociation) {
Expand All @@ -618,6 +631,10 @@ presburger::IntegerRelation expandDimensions(
}
}
}
assert(static_cast<int64_t>(clonedRelation->getNumDomainVars()) ==
resultType.getRank() &&
"expandDimensions: result relation domain rank must match the result "
"tensor rank");
return *clonedRelation;
}

Expand Down
39 changes: 39 additions & 0 deletions lib/Utils/Layout/UtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,45 @@ TEST(UtilsTest, TestGetCollapsedRelationUnitDims) {
EXPECT_EQ(actual, expected);
}

TEST(UtilsTest, TestExpandDimensionsFromRankZero) {
  // Expanding a rank-0 tensor into a 1x1 tensor is expressed with an empty
  // reassociation list, so the relation starts with no domain variables.
  MLIRContext context;
  SmallVector<ReassociationIndices> noReassociation;
  RankedTensorType expandedType =
      RankedTensorType::get({1, 1}, IndexType::get(&context));
  auto rankZeroRelation =
      getIntegerRelationFromIslStr(
          "{ [] -> [ct, slot] : ct = 0 and 0 <= slot <= 1023 }")
          .value();

  IntegerRelation result =
      expandDimensions(rankZeroRelation, expandedType, noReassociation);

  // One domain variable per result dimension, and the origin of the expanded
  // tensor still maps to ciphertext 0, slot 0.
  EXPECT_EQ(result.getNumDomainVars(), 2u);
  EXPECT_TRUE(result.containsPointNoLocal({0, 0, 0, 0}));
}

TEST(UtilsTest, TestCollapseDimensionsMultipleUnitDimsInGroup) {
  // Collapses a 1x3x1 source with a single reassociation group {0, 1, 2}
  // that contains two unit dimensions (indices 0 and 2). Both must be dropped,
  // leaving only the size-3 dimension in the domain.
  MLIRContext context;
  auto rel =
      getIntegerRelationFromIslStr(
          "{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and i2 = 0 and ct = 0 and "
          "(-i1 + slot) mod 4 = 0 and 0 <= i1 <= 2 and 0 <= slot <= 1023 }")
          .value();
  RankedTensorType sourceType =
      RankedTensorType::get({1, 3, 1}, IndexType::get(&context));
  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2}};
  IntegerRelation collapsed =
      collapseDimensions(rel, sourceType, reassociation);

  // Unsigned literals match the unsigned variable counts and avoid a
  // signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(collapsed.getNumDomainVars(), 1u);
  EXPECT_EQ(collapsed.getNumRangeVars(), 2u);
  auto expected =
      getIntegerRelationFromIslStr(
          "{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 4 = 0 and "
          "0 <= i0 <= 2 and 0 <= slot <= 1023 }")
          .value();
  EXPECT_TRUE(collapsed.isEqual(expected));
}

TEST(UtilsTest, TestGetSliceInsertionRelation) {
MLIRContext context;
// Insert a 3x4 slice into a 2x1x3x4 matrix at (1, 0, 0, 0).
Expand Down
23 changes: 23 additions & 0 deletions tests/Examples/lattigo/ckks/pooling1d/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Build targets for the pooling1d Lattigo CKKS example: a library target built
# from pooling1d.mlir and a Go test that embeds it.
load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib")
load("@rules_go//go:def.bzl", "go_test")

package(default_applicable_licenses = ["@heir//:license"])

# Declares the `pooling1d` target from pooling1d.mlir using the heir-opt flags
# listed below, with split preprocessing enabled both in the pass options
# (split-preprocessing=1) and on the macro (split_preprocessing = True).
heir_lattigo_lib(
name = "pooling1d",
go_library_name = "pooling1d",
heir_opt_flags = [
"--annotate-module=backend=lattigo scheme=ckks",
"--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60 split-preprocessing=1",
"--scheme-to-lattigo",
],
mlir_src = "pooling1d.mlir",
split_preprocessing = True,
)

# Go test for the example above.
# NOTE(review): `:pooling1d_utils` is not declared in this file; presumably it
# is produced by heir_lattigo_lib when split_preprocessing = True. Confirm in
# test.bzl.
go_test(
name = "pooling1d_test",
srcs = ["pooling1d_test.go"],
embed = [":pooling1d"],
deps = [":pooling1d_utils"],
)
20 changes: 20 additions & 0 deletions tests/Examples/lattigo/ckks/pooling1d/pooling1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Input program for the pooling1d example: a 1-D sum pool with a 2-element
// window and stride 2 over a secret 1x4x28 tensor, followed by an elementwise
// division by 2.0 (together, an average pool producing a 1x4x14 tensor).
// NOTE(review): #map1 and #map2 are declared but not referenced in this file.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d1)>

module {
func.func @pooling1d(%arg0: tensor<1x4x28xf32> {secret.secret}) -> tensor<1x4x14xf32> {
// %cst_0 is the fill value of the pooling accumulator; %cst_1 is the divisor.
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 2.000000e+00 : f32
%3 = tensor.empty() : tensor<1x4x14xf32>
%4 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x4x14xf32>) -> tensor<1x4x14xf32>
// Window operand of the pooling op (2 elements).
%5 = arith.constant dense<1.0> : tensor<2xf32>
%6 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%arg0, %5 : tensor<1x4x28xf32>, tensor<2xf32>) outs(%4 : tensor<1x4x14xf32>) -> tensor<1x4x14xf32>
// Elementwise division of the pooled sums by %cst_1.
%7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6 : tensor<1x4x14xf32>) outs(%3 : tensor<1x4x14xf32>) {
^bb0(%in: f32, %out: f32):
%32 = arith.divf %in, %cst_1 : f32
linalg.yield %32 : f32
} -> tensor<1x4x14xf32>
return %7 : tensor<1x4x14xf32>
}
}
73 changes: 73 additions & 0 deletions tests/Examples/lattigo/ckks/pooling1d/pooling1d_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package pooling1d

import (
"math"
"testing"
"time"

"tests/Examples/lattigo/ckks/pooling1d/pooling1d_utils"
)

// TestPooling encrypts an all-ones 1x4x28 input, runs the generated
// pooling1d pipeline (preprocessing + preprocessed evaluation), decrypts the
// result and compares it against a cleartext reference.
//
// The reference is computed as a 1-D convolution with a 4x4x2 filter that is
// 0.5 where the output channel equals the input channel and 0 elsewhere, i.e.
// a per-channel average over a window of 2 with stride 2.
func TestPooling(t *testing.T) {
	const (
		channels = 4
		inWidth  = 28
		kernel   = 2
		stride   = 2
		outWidth = (inWidth-kernel)/stride + 1 // 14
	)

	evaluator, params, ecd, enc, dec := pooling1d__configure()

	// Input: 1x4x28 = 112 elements
	arg0 := make([]float32, channels*inWidth)
	for i := range arg0 {
		arg0[i] = float32(1.0)
	}

	// Filter: 4x4x2 = 32 elements, 0.5 on the channel diagonal for average
	// pooling. make() zero-initializes, so only the diagonal is written.
	arg1 := make([]float32, channels*channels*kernel)
	for ch := 0; ch < channels; ch++ {
		for wi := 0; wi < kernel; wi++ {
			arg1[ch*channels*kernel+ch*kernel+wi] = 0.5
		}
	}

	// Expected output: 1x4x14 = 56 elements
	// pooling[n, f, wo] = sum_{c, wi} input[n, c, wo*stride+wi] * filter[f, c, wi]
	expected := make([]float32, channels*outWidth)
	for f := 0; f < channels; f++ {
		for wo := 0; wo < outWidth; wo++ {
			var sum float32
			for c := 0; c < channels; c++ {
				for wi := 0; wi < kernel; wi++ {
					inIdx := c*inWidth + wo*stride + wi
					filterIdx := f*channels*kernel + c*kernel + wi
					sum += arg0[inIdx] * arg1[filterIdx]
				}
			}
			expected[f*outWidth+wo] = sum
		}
	}

	ct0 := pooling1d__encrypt__arg0(evaluator, params, ecd, enc, arg0)

	startPre := time.Now()
	filterPlains := pooling1d_utils.Pooling1d__preprocessing(params, ecd)
	t.Logf("Preprocessing took %s", time.Since(startPre))

	start := time.Now()
	resultCt := pooling1d__preprocessed(evaluator, params, ecd, ct0, filterPlains)
	t.Logf("Pooling1d (preprocessed) took %s", time.Since(start))

	result := pooling1d__decrypt__result0(evaluator, params, ecd, dec, resultCt)

	// Fail cleanly instead of panicking with an index-out-of-range if the
	// decrypted result is shorter than the reference.
	if len(result) < len(expected) {
		t.Fatalf("Decrypted result has %d elements, want at least %d", len(result), len(expected))
	}

	errorThreshold := float64(0.01)
	for i := range expected {
		if math.Abs(float64(result[i]-expected[i])) > errorThreshold {
			t.Errorf("Decryption error at index %d: %.4f != %.4f", i, result[i], expected[i])
		}
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: heir-opt %s --convert-to-ciphertext-semantics | FileCheck %s

// Collapses tensor<1x3x1xf32> into tensor<3xf32> with a single reassociation
// group [0, 1, 2] that contains two unit dimensions. The expectations below
// require the generic body to yield its tensor<1x1024xf32> block argument
// directly.
#layout_in = #tensor_ext.layout<"{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and i2 = 0 and ct = 0 and (-i1 + slot) mod 4 = 0 and 0 <= i1 <= 2 and 0 <= slot <= 1023 }">
#layout_out = #tensor_ext.layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 4 = 0 and 0 <= i0 <= 2 and 0 <= slot <= 1023 }">
module {
// CHECK: func.func @collapse_multiple_unit_dims
func.func @collapse_multiple_unit_dims(%arg0: !secret.secret<tensor<1x3x1xf32>> {tensor_ext.layout = #layout_in}) -> (!secret.secret<tensor<3xf32>> {tensor_ext.layout = #layout_out}) {
// CHECK: secret.generic
// CHECK-NEXT: ^body(%[[input0:.*]]: tensor<1x1024xf32>)
// CHECK: secret.yield %[[input0]]
%0 = secret.generic(%arg0: !secret.secret<tensor<1x3x1xf32>> {tensor_ext.layout = #layout_in}) {
^body(%input0: tensor<1x3x1xf32>):
%collapsed = tensor.collapse_shape %input0 [[0, 1, 2]] {tensor_ext.layout = #layout_out} : tensor<1x3x1xf32> into tensor<3xf32>
secret.yield %collapsed : tensor<3xf32>
} -> (!secret.secret<tensor<3xf32>> {tensor_ext.layout = #layout_out})
return %0 : !secret.secret<tensor<3xf32>>
}
}
32 changes: 32 additions & 0 deletions tests/Transforms/linalg_canonicalizations/average_pooling_1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s

// A 1-D sum pool (window 2, stride 2) followed by a division by 2.0. The
// expectations below require a linalg.conv_1d_ncw_fcw with a constant 6x6x2
// kernel containing 0.5 entries, accumulating into a dense 3.0 constant.
// NOTE(review): #map1 and #map2 are declared but not referenced in this file.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d1)>
module {
// CHECK: func.func @main
// CHECK-SAME: (%[[arg0:.*]]: tensor<1x6x28xf32>)
// CHECK: %[[out:.*]] = arith.constant dense<3.0
// CHECK: %[[divided_cst:.*]] = arith.constant
// CHECK-SAME: 5.000000e-01, 5.000000e-01
// CHECK-SAME: tensor<6x6x2xf32>
// CHECK: linalg.conv_1d_ncw_fcw
// CHECK-SAME: strides = dense<2> : vector<1xi64>
// CHECK-SAME: ins(%[[arg0]], %[[divided_cst]] : tensor<1x6x28xf32>, tensor<6x6x2xf32>) outs(%[[out]] : tensor<1x6x14xf32>)
// CHECK: return
func.func @main(%arg0: tensor<1x6x28xf32>) -> tensor<1x6x14xf32> {
// %cst_0 (3.0) is the fill value of the accumulator; %cst_1 (2.0) is the divisor.
%cst_0 = arith.constant 3.000000e+00 : f32
%cst_1 = arith.constant 2.000000e+00 : f32
%3 = tensor.empty() : tensor<1x6x14xf32>
%4 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x6x14xf32>) -> tensor<1x6x14xf32>
// The filter's constant value is not expected to affect the output.
// NOTE(review): the original note here read "Choosing 3 avoids other
// optimizations interfering", but the filter below is dense<6.0>; 3.0 is the
// fill value (%cst_0) above. Confirm which constant the note refers to.
%5 = arith.constant dense<6.0> : tensor<2xf32>
%6 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%arg0, %5 : tensor<1x6x28xf32>, tensor<2xf32>) outs(%4 : tensor<1x6x14xf32>) -> tensor<1x6x14xf32>
%7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6 : tensor<1x6x14xf32>) outs(%3 : tensor<1x6x14xf32>) {
^bb0(%in: f32, %out: f32):
%32 = arith.divf %in, %cst_1 : f32
linalg.yield %32 : f32
} -> tensor<1x6x14xf32>
return %7 : tensor<1x6x14xf32>
}
}
Loading