From 1af5ddeae3a2f00602f7d01762289a93f39b089b Mon Sep 17 00:00:00 2001
From: Marc Desgroseilliers
Date: Wed, 13 May 2026 12:02:21 +0200
Subject: [PATCH] add support for pooling_ncw_sum

---
 .../LinalgCanonicalizations.cpp               | 99 ++++++++++++++++++-
 lib/Utils/Layout/Convolution.cpp              |  1 -
 lib/Utils/Layout/Utils.cpp                    | 19 +++-
 lib/Utils/Layout/UtilsTest.cpp                | 39 ++++++++
 tests/Examples/lattigo/ckks/pooling1d/BUILD   | 23 +++++
 .../lattigo/ckks/pooling1d/pooling1d.mlir     | 20 +++++
 .../lattigo/ckks/pooling1d/pooling1d_test.go  | 73 +++++++++++++++
 .../collapse_shape_multiple_unit_dims.mlir    | 18 ++++
 .../average_pooling_1d.mlir                   | 32 +++++++
 9 files changed, 321 insertions(+), 3 deletions(-)
 create mode 100644 tests/Examples/lattigo/ckks/pooling1d/BUILD
 create mode 100644 tests/Examples/lattigo/ckks/pooling1d/pooling1d.mlir
 create mode 100644 tests/Examples/lattigo/ckks/pooling1d/pooling1d_test.go
 create mode 100644 tests/Transforms/convert_to_ciphertext_semantics/collapse_shape_multiple_unit_dims.mlir
 create mode 100644 tests/Transforms/linalg_canonicalizations/average_pooling_1d.mlir

diff --git a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp
index 5adcb1f628..7b4d2aa386 100644
--- a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp
+++ b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp
@@ -639,6 +639,102 @@ struct RewriteTransposedMatvec
   }
 };

+// Rewrites linalg.pooling_ncw_sum (optionally followed by a division by a
+// constant splat, i.e. an average pool) into linalg.conv_1d_ncw_fcw with a
+// constant block-diagonal kernel.
+struct RewriteAvgPoolAsConv1D
+    : public OpRewritePattern<mlir::linalg::PoolingNcwSumOp> {
+ public:
+  RewriteAvgPoolAsConv1D(MLIRContext* context)
+      : OpRewritePattern<mlir::linalg::PoolingNcwSumOp>(context) {}
+
+  using OpRewritePattern<mlir::linalg::PoolingNcwSumOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::linalg::PoolingNcwSumOp poolOp,
+                                PatternRewriter& rewriter) const override {
+    // The rewrite reads result types and builds a dense constant kernel, so it
+    // needs tensor semantics and statically known channel/window sizes.
+    if (!poolOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+    auto inputTy = cast<RankedTensorType>(poolOp.getInputs()[0].getType());
+    auto filterTy = cast<RankedTensorType>(poolOp.getInputs()[1].getType());
+    auto outputTy = cast<RankedTensorType>(poolOp.getResultTypes()[0]);
+
+    int64_t c = inputTy.getDimSize(1);
+    int64_t w = filterTy.getDimSize(0);
+    if (ShapedType::isDynamic(c) || ShapedType::isDynamic(w)) {
+      return failure();
+    }
+
+    auto eltTy = filterTy.getElementType();
+    auto kernelShape = SmallVector<int64_t>{c, c, w};
+    auto kernelTy = RankedTensorType::get(kernelShape, eltTy);
+
+    // Create kernel value attributes of ones and zeros for the filter.
+    Attribute zeroAttr = rewriter.getZeroAttr(eltTy);
+    Attribute oneAttr = rewriter.getOneAttr(eltTy);
+
+    // If there is a constant division following the sum pool, update the
+    // kernel of ones to be 1 / divValue. This is a common enough pattern since
+    // it represents an average pool.
+    Attribute avgAttr = oneAttr;
+    Value avgPoolOutput;
+    if (poolOp->hasOneUse()) {
+      OpOperand& use = *poolOp->getUses().begin();
+      auto divOp = dyn_cast<arith::DivFOp>(use.getOwner());
+      // Division is not commutative: only `pool / cst` can be folded into the
+      // kernel. `cst / pool` (pool result as operand 1) must be left alone.
+      if (divOp && use.getOperandNumber() == 0) {
+        if (auto constantAttr =
+                dyn_cast<Attribute>(getAsOpFoldResult(divOp->getOperand(1)))) {
+          if (auto splatAttr = dyn_cast<SplatElementsAttr>(constantAttr)) {
+            auto divValue = splatAttr.getSplatValue<APFloat>();
+            // Folding a zero divisor would bake inf into the kernel.
+            if (!divValue.isZero()) {
+              APFloat one = APFloat::getOne(divValue.getSemantics());
+              avgAttr = rewriter.getFloatAttr(eltTy, one / divValue);
+              avgPoolOutput = divOp.getResult();
+            }
+          }
+        }
+      }
+    }
+
+    // Build average pooling kernel as a special type of convolution. The kernel
+    // computes a window average, so it is a fixed constant
+    // (1 / divValue) where f == c and zeros where f != c (so each
+    // channel is averaged independently) and strides equal to the pooling
+    // sizes. See
+    // https://machinelearningmastery.com/pooling-layers-for-convolutional-neural-networks/
+    SmallVector<Attribute> values(c * c * w, zeroAttr);
+    // Only the diagonal blocks (f == c) are non-zero; in row-major FCW order
+    // block (ch, ch) starts at ch * (c * w) + ch * w.
+    for (int64_t ch = 0; ch < c; ++ch) {
+      for (int64_t wIdx = 0; wIdx < w; ++wIdx) {
+        values[ch * (c * w) + ch * w + wIdx] = avgAttr;
+      }
+    }
+
+    TypedAttr kernelVals = DenseElementsAttr::get(kernelTy, values);
+    auto kernel =
+        arith::ConstantOp::create(rewriter, poolOp.getLoc(), kernelVals);
+    Value conv = linalg::Conv1DNcwFcwOp::create(
+                     rewriter, poolOp.getLoc(), outputTy,
+                     ValueRange{poolOp.getInputs()[0], kernel},
+                     ValueRange{poolOp.getOutputs()[0]}, poolOp.getStrides(),
+                     poolOp.getDilations())
+                     .getResult(0);
+
+    if (avgPoolOutput) {
+      rewriter.replaceAllUsesWith(avgPoolOutput, conv);
+    } else {
+      rewriter.replaceOp(poolOp, conv);
+    }
+    return success();
+  }
+};
+
 struct RewriteAvgPoolAsConv2D
     : public OpRewritePattern {
  public:
@@ -740,7 +836,8 @@ struct LinalgCanonicalizations
                  FoldConstantBroadcast, FoldBroadcastExtractSlice,
                  LinalgMapToElementwise, LinalgGenericToElementwise,
                  BroadcastToExpandShape, RewriteTransposedVecmat,
-                 RewriteTransposedMatvec, RewriteAvgPoolAsConv2D>(context);
+                 RewriteTransposedMatvec, RewriteAvgPoolAsConv1D,
+                 RewriteAvgPoolAsConv2D>(context);

     // Run pattern matching and conversion
     // TODO (#1221): Investigate whether folding (default: on) can be skipped
diff --git a/lib/Utils/Layout/Convolution.cpp b/lib/Utils/Layout/Convolution.cpp
index 9c35a4381f..3ba71ba312 100644
--- a/lib/Utils/Layout/Convolution.cpp
+++ b/lib/Utils/Layout/Convolution.cpp
@@ -553,7 +553,6 @@ presburger::IntegerRelation get2dConvRowInterchangeRelation(int64_t c,
   //    h' = hi * g + (ci % g**2) // g
   //    w' = wi * g + (ci % g)
   // 3.
Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
-  // FIXME why are these interchanged???
   int64_t hOut = h * g;
   int64_t wOut = w * g;
   int64_t cOut = c / (g * g);
diff --git a/lib/Utils/Layout/Utils.cpp b/lib/Utils/Layout/Utils.cpp
index 75d6c39a1a..7cc23a5566 100644
--- a/lib/Utils/Layout/Utils.cpp
+++ b/lib/Utils/Layout/Utils.cpp
@@ -576,7 +576,9 @@ presburger::IntegerRelation collapseDimensions(
     if (associationGroup.size() == 1) {
       continue;
     }
-    for (int64_t reassocDim : associationGroup) {
+    // Iterate from the largest index down so that eliminating a variable does
+    // not shift the indices of the dimensions still to be visited.
+    for (int64_t reassocDim : llvm::reverse(associationGroup)) {
       if (sourceType.getShape()[reassocDim] == 1) {
         // Drop this unit dimension
         clonedRelation->setAndEliminate(reassocDim, 0);
@@ -596,6 +598,17 @@ presburger::IntegerRelation expandDimensions(
   // dimension we're adding is in the correct index of the integer relations
   // domain variable list.
   std::unique_ptr clonedRelation = relation.clone();
+
+  // Empty reassociation (rank-0 source): every result dim is a new unit dim.
+  if (reassociation.empty()) {
+    for (int64_t i = 0; i < resultType.getRank(); ++i) {
+      auto newDimIndex = clonedRelation->insertVar(VarKind::Domain, i);
+      clonedRelation->addBound(BoundType::LB, newDimIndex, 0);
+      clonedRelation->addBound(BoundType::UB, newDimIndex, 0);
+    }
+    return *clonedRelation;
+  }
+
   int oldDim = 0;
   DenseMap oldDimsToNewDims;
   for (const ReassociationIndices& associationGroup : reassociation) {
@@ -618,6 +631,10 @@ presburger::IntegerRelation expandDimensions(
       }
     }
   }
+  assert(static_cast<int64_t>(clonedRelation->getNumDomainVars()) ==
+             resultType.getRank() &&
+         "expandDimensions: result relation domain rank must match the result "
+         "tensor rank");
   return *clonedRelation;
 }

diff --git a/lib/Utils/Layout/UtilsTest.cpp b/lib/Utils/Layout/UtilsTest.cpp
index afca528ce0..271fa8b8ea 100644
--- a/lib/Utils/Layout/UtilsTest.cpp
+++ b/lib/Utils/Layout/UtilsTest.cpp
@@ -451,6 +451,45 @@ TEST(UtilsTest, TestGetCollapsedRelationUnitDims) {
   EXPECT_EQ(actual, expected);
 }

+TEST(UtilsTest, TestExpandDimensionsFromRankZero) {
+  // tensor.expand_shape from tensor<f32> to tensor<1x1xf32> uses
+  // an empty reassociation array.
+  MLIRContext context;
+  auto rel = getIntegerRelationFromIslStr(
+                 "{ [] -> [ct, slot] : ct = 0 and 0 <= slot <= 1023 }")
+                 .value();
+  RankedTensorType resultType =
+      RankedTensorType::get({1, 1}, IndexType::get(&context));
+  SmallVector<ReassociationIndices> reassociation = {};
+  IntegerRelation expanded = expandDimensions(rel, resultType, reassociation);
+
+  EXPECT_EQ(expanded.getNumDomainVars(), 2u);
+  EXPECT_TRUE(expanded.containsPointNoLocal({0, 0, 0, 0}));
+}
+
+TEST(UtilsTest, TestCollapseDimensionsMultipleUnitDimsInGroup) {
+  MLIRContext context;
+  auto rel =
+      getIntegerRelationFromIslStr(
+          "{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and i2 = 0 and ct = 0 and "
+          "(-i1 + slot) mod 4 = 0 and 0 <= i1 <= 2 and 0 <= slot <= 1023 }")
+          .value();
+  RankedTensorType sourceType =
+      RankedTensorType::get({1, 3, 1}, IndexType::get(&context));
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2}};
+  IntegerRelation collapsed =
+      collapseDimensions(rel, sourceType, reassociation);
+
+  EXPECT_EQ(collapsed.getNumDomainVars(), 1u);
+  EXPECT_EQ(collapsed.getNumRangeVars(), 2u);
+  auto expected =
+      getIntegerRelationFromIslStr(
+          "{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 4 = 0 and "
+          "0 <= i0 <= 2 and 0 <= slot <= 1023 }")
+          .value();
+  EXPECT_TRUE(collapsed.isEqual(expected));
+}
+
 TEST(UtilsTest, TestGetSliceInsertionRelation) {
   MLIRContext context;
   // Insert a 3x4 slice into a 2x1x3x4 matrix at (1, 0, 0, 0).
diff --git a/tests/Examples/lattigo/ckks/pooling1d/BUILD b/tests/Examples/lattigo/ckks/pooling1d/BUILD
new file mode 100644
index 0000000000..4c20ee2d91
--- /dev/null
+++ b/tests/Examples/lattigo/ckks/pooling1d/BUILD
@@ -0,0 +1,23 @@
+load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib")
+load("@rules_go//go:def.bzl", "go_test")
+
+package(default_applicable_licenses = ["@heir//:license"])
+
+heir_lattigo_lib(
+    name = "pooling1d",
+    go_library_name = "pooling1d",
+    heir_opt_flags = [
+        "--annotate-module=backend=lattigo scheme=ckks",
+        "--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60 split-preprocessing=1",
+        "--scheme-to-lattigo",
+    ],
+    mlir_src = "pooling1d.mlir",
+    split_preprocessing = True,
+)
+
+go_test(
+    name = "pooling1d_test",
+    srcs = ["pooling1d_test.go"],
+    embed = [":pooling1d"],
+    deps = [":pooling1d_utils"],
+)
diff --git a/tests/Examples/lattigo/ckks/pooling1d/pooling1d.mlir b/tests/Examples/lattigo/ckks/pooling1d/pooling1d.mlir
new file mode 100644
index 0000000000..01351427d5
--- /dev/null
+++ b/tests/Examples/lattigo/ckks/pooling1d/pooling1d.mlir
@@ -0,0 +1,20 @@
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+
+module {
+  func.func @pooling1d(%arg0: tensor<1x4x28xf32> {secret.secret}) -> tensor<1x4x14xf32> {
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %cst_1 = arith.constant 2.000000e+00 : f32
+    %3 = tensor.empty() : tensor<1x4x14xf32>
+    %4 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x4x14xf32>) -> tensor<1x4x14xf32>
+    %5 = arith.constant dense<1.0> : tensor<2xf32>
+    %6 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%arg0, %5 : tensor<1x4x28xf32>, tensor<2xf32>) outs(%4 : tensor<1x4x14xf32>) -> tensor<1x4x14xf32>
+    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6 : tensor<1x4x14xf32>) outs(%3 : tensor<1x4x14xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %32 = arith.divf %in, %cst_1 : f32
+      linalg.yield %32 : f32
+    } -> tensor<1x4x14xf32>
+    return %7 : tensor<1x4x14xf32>
+  }
+}
diff --git a/tests/Examples/lattigo/ckks/pooling1d/pooling1d_test.go b/tests/Examples/lattigo/ckks/pooling1d/pooling1d_test.go
new file mode 100644
index 0000000000..57f99bf86a
--- /dev/null
+++ b/tests/Examples/lattigo/ckks/pooling1d/pooling1d_test.go
@@ -0,0 +1,73 @@
+package pooling1d
+
+import (
+	"math"
+	"testing"
+	"time"
+
+	"tests/Examples/lattigo/ckks/pooling1d/pooling1d_utils"
+)
+
+func TestPooling(t *testing.T) {
+	evaluator, params, ecd, enc, dec := pooling1d__configure()
+
+	// Input: 1x4x28 = 112 elements; non-uniform so stride/channel/layout bugs change the result
+	arg0 := make([]float32, 112)
+	for i := range arg0 {
+		arg0[i] = float32(i%11) * 0.1
+	}
+
+	// Reference filter: 4x4x2 = 32 elements; 0.5 on the f == c diagonal, 0 elsewhere
+	arg1 := make([]float32, 32)
+	for f := 0; f < 4; f++ {
+		for c := 0; c < 4; c++ {
+			for wi := 0; wi < 2; wi++ {
+				idx := f*4*2 + c*2 + wi
+				if f == c {
+					arg1[idx] = 0.5
+				} else {
+					arg1[idx] = 0
+				}
+			}
+		}
+	}
+
+	// Expected output: 1x4x14 = 56 elements
+	expected := make([]float32, 56)
+	// Compute expected average pooling kernel-size=2 with stride 2
+	// N=1, C=4, W=28
+	// Output: N=1, F=4, Wo=14
+	// pooling[n, f, wo] = sum_{c, wi} input[n, c, wo*2+wi] * filter[f, c, wi]
+	for f := 0; f < 4; f++ {
+		for wo := 0; wo < 14; wo++ {
+			var sum float32
+			for c := 0; c < 4; c++ {
+				for wi := 0; wi < 2; wi++ {
+					inIdx := c*28 + wo*2 + wi
+					filterIdx := f*4*2 + c*2 + wi
+					sum += arg0[inIdx] * arg1[filterIdx]
+				}
+			}
+			expected[f*14+wo] = sum
+		}
+	}
+
+	ct0 := pooling1d__encrypt__arg0(evaluator, params, ecd, enc, arg0)
+
+	startPre := time.Now()
+	filterPlains := pooling1d_utils.Pooling1d__preprocessing(params, ecd)
+	t.Logf("Preprocessing took %s", time.Since(startPre))
+
+	start := time.Now()
+	resultCt := pooling1d__preprocessed(evaluator, params, ecd, ct0, filterPlains)
+	t.Logf("Pooling1d (preprocessed) took %s", time.Since(start))
+
+	result := pooling1d__decrypt__result0(evaluator, params, ecd, dec, resultCt)
+
+	errorThreshold := float64(0.01)
+	for i := range expected {
+		if math.Abs(float64(result[i]-expected[i])) > errorThreshold {
+			t.Errorf("Decryption error at index %d: %.4f != %.4f", i, result[i], expected[i])
+		}
+	}
+}
diff --git a/tests/Transforms/convert_to_ciphertext_semantics/collapse_shape_multiple_unit_dims.mlir b/tests/Transforms/convert_to_ciphertext_semantics/collapse_shape_multiple_unit_dims.mlir
new file mode 100644
index 0000000000..d797b9ddad
--- /dev/null
+++ b/tests/Transforms/convert_to_ciphertext_semantics/collapse_shape_multiple_unit_dims.mlir
@@ -0,0 +1,18 @@
+// RUN: heir-opt %s --convert-to-ciphertext-semantics | FileCheck %s
+
+#layout_in = #tensor_ext.layout<"{ [i0, i1, i2] -> [ct, slot] : i0 = 0 and i2 = 0 and ct = 0 and (-i1 + slot) mod 4 = 0 and 0 <= i1 <= 2 and 0 <= slot <= 1023 }">
+#layout_out = #tensor_ext.layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 4 = 0 and 0 <= i0 <= 2 and 0 <= slot <= 1023 }">
+module {
+  // CHECK: func.func @collapse_multiple_unit_dims
+  func.func @collapse_multiple_unit_dims(%arg0: !secret.secret<tensor<1x3x1xf32>> {tensor_ext.layout = #layout_in}) -> (!secret.secret<tensor<3xf32>> {tensor_ext.layout = #layout_out}) {
+    // CHECK: secret.generic
+    // CHECK-NEXT: ^body(%[[input0:.*]]: tensor<1x1024xf32>)
+    // CHECK: secret.yield %[[input0]]
+    %0 = secret.generic(%arg0: !secret.secret<tensor<1x3x1xf32>> {tensor_ext.layout = #layout_in}) {
+    ^body(%input0: tensor<1x3x1xf32>):
+      %collapsed = tensor.collapse_shape %input0 [[0, 1, 2]] {tensor_ext.layout = #layout_out} : tensor<1x3x1xf32> into tensor<3xf32>
+      secret.yield %collapsed : tensor<3xf32>
+    } -> (!secret.secret<tensor<3xf32>> {tensor_ext.layout = #layout_out})
+    return %0 : !secret.secret<tensor<3xf32>>
+  }
+}
diff --git a/tests/Transforms/linalg_canonicalizations/average_pooling_1d.mlir b/tests/Transforms/linalg_canonicalizations/average_pooling_1d.mlir
new file mode 100644
index 0000000000..291dd5a75e
--- /dev/null
+++ b/tests/Transforms/linalg_canonicalizations/average_pooling_1d.mlir
@@ -0,0 +1,32 @@
+// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+module {
+  // CHECK: func.func @main
+  // CHECK-SAME: (%[[arg0:.*]]: tensor<1x6x28xf32>)
+  // CHECK: %[[out:.*]] = arith.constant dense<3.0
+  // CHECK: %[[divided_cst:.*]] = arith.constant
+  // CHECK-SAME: 5.000000e-01, 5.000000e-01
+  // CHECK-SAME: tensor<6x6x2xf32>
+  // CHECK: linalg.conv_1d_ncw_fcw
+  // CHECK-SAME: strides = dense<2> : vector<1xi64>
+  // CHECK-SAME: ins(%[[arg0]], %[[divided_cst]] : tensor<1x6x28xf32>, tensor<6x6x2xf32>) outs(%[[out]] : tensor<1x6x14xf32>)
+  // CHECK: return
+  func.func @main(%arg0: tensor<1x6x28xf32>) -> tensor<1x6x14xf32> {
+    %cst_0 = arith.constant 3.000000e+00 : f32
+    %cst_1 = arith.constant 2.000000e+00 : f32
+    %3 = tensor.empty() : tensor<1x6x14xf32>
+    %4 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x6x14xf32>) -> tensor<1x6x14xf32>
+    // The filter's values (6.0 here) don't affect the output. The fill value above is 3.0 (not 0) to avoid other optimizations interfering.
+    %5 = arith.constant dense<6.0> : tensor<2xf32>
+    %6 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%arg0, %5 : tensor<1x6x28xf32>, tensor<2xf32>) outs(%4 : tensor<1x6x14xf32>) -> tensor<1x6x14xf32>
+    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6 : tensor<1x6x14xf32>) outs(%3 : tensor<1x6x14xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %32 = arith.divf %in, %cst_1 : f32
+      linalg.yield %32 : f32
+    } -> tensor<1x6x14xf32>
+    return %7 : tensor<1x6x14xf32>
+  }
+}