From 8ef69684ba1c1d4ce34a7ed1c080115690fe5ca3 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 28 May 2026 13:07:07 -0700 Subject: [PATCH] fuse batchnorm1d into a preceding linear layer PiperOrigin-RevId: 922932512 --- .../ArithmeticPipelineRegistration.cpp | 2 + lib/Pipelines/BUILD | 1 + lib/Target/Lattigo/LattigoEmitter.cpp | 14 +- lib/Target/Lattigo/LattigoEmitter.h | 1 + lib/Transforms/LinalgCanonicalizations/BUILD | 1 + .../LinalgCanonicalizations.cpp | 192 ++++++++++++- lib/Transforms/LinalgFuseLinearOps/BUILD | 33 +++ .../LinalgFuseLinearOps.cpp | 261 ++++++++++++++++++ .../LinalgFuseLinearOps/LinalgFuseLinearOps.h | 20 ++ .../LinalgFuseLinearOps.td | 31 +++ tests/Examples/common/batchnorm1d.mlir | 66 +++++ tests/Examples/lattigo/ckks/batchnorm1d/BUILD | 22 ++ .../ckks/batchnorm1d/batchnorm1d_test.go | 25 ++ .../lattigo/ckks/matvec_batchnorm/BUILD | 23 ++ .../matvec_batchnorm/matvec_batchnorm.mlir | 32 +++ .../matvec_batchnorm/matvec_batchnorm_test.go | 25 ++ .../linalg_canonicalizations/batchnorm1d.mlir | 32 +++ .../materialize_broadcast.mlir | 75 +++++ tests/Transforms/linalg_fuse_linear_ops/BUILD | 10 + .../linalg_fuse_linear_ops/doctest.mlir | 21 ++ .../linalg_fuse_linear_ops.mlir | 101 +++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 23 files changed, 984 insertions(+), 7 deletions(-) create mode 100644 lib/Transforms/LinalgFuseLinearOps/BUILD create mode 100644 lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp create mode 100644 lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h create mode 100644 lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.td create mode 100644 tests/Examples/common/batchnorm1d.mlir create mode 100644 tests/Examples/lattigo/ckks/batchnorm1d/BUILD create mode 100644 tests/Examples/lattigo/ckks/batchnorm1d/batchnorm1d_test.go create mode 100644 tests/Examples/lattigo/ckks/matvec_batchnorm/BUILD create mode 100644 tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm.mlir create mode 100644 
tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm_test.go create mode 100644 tests/Transforms/linalg_canonicalizations/batchnorm1d.mlir create mode 100644 tests/Transforms/linalg_canonicalizations/materialize_broadcast.mlir create mode 100644 tests/Transforms/linalg_fuse_linear_ops/BUILD create mode 100644 tests/Transforms/linalg_fuse_linear_ops/doctest.mlir create mode 100644 tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 5d7d058bb4..1fefe5e274 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -49,6 +49,7 @@ #include "lib/Transforms/LayoutOptimization/LayoutOptimization.h" #include "lib/Transforms/LayoutPropagation/LayoutPropagation.h" #include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h" +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h" #include "lib/Transforms/OperationBalancer/OperationBalancer.h" #include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h" #include "lib/Transforms/PopulateScale/PopulateScale.h" @@ -541,6 +542,7 @@ void linalgPreprocessingBuilder(OpPassManager& manager) { manager.addPass(createInlineActivations()); manager.addPass(createActivationCanonicalizations()); manager.addPass(createLinalgCanonicalizations()); + manager.addPass(createLinalgFuseLinearOpsPass()); manager.addPass(createDropUnitDims()); manager.addPass(createFoldConstantTensors()); manager.addPass(createCanonicalizerPass()); diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index bcfa650208..6985c58566 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -147,6 +147,7 @@ cc_library( "@heir//lib/Transforms/LayoutOptimization", "@heir//lib/Transforms/LayoutPropagation", "@heir//lib/Transforms/LinalgCanonicalizations", + "@heir//lib/Transforms/LinalgFuseLinearOps", 
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration", "@heir//lib/Transforms/OperationBalancer", "@heir//lib/Transforms/OptimizeRelinearization", diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index e20757e3dc..16d5aad61f 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -106,11 +106,11 @@ LogicalResult LattigoEmitter::translate(Operation& op) { // Arith ops .Case( + arith::TruncFOp, arith::RemSIOp, arith::AddIOp, arith::AddFOp, + arith::AndIOp, arith::SubIOp, arith::SubFOp, arith::MaxSIOp, + arith::MinSIOp, arith::MulIOp, arith::MulFOp, arith::DivSIOp, + arith::DivFOp, arith::NegFOp, arith::OrIOp, arith::XOrIOp, + arith::CmpIOp, arith::CmpFOp, arith::SelectOp>( [&](auto op) { return printOperation(op); }) // SCF ops .Case( @@ -709,6 +709,10 @@ LogicalResult LattigoEmitter::printOperation(arith::ExtFOp op) { return typecast(op.getOperand(), op.getResult()); } +LogicalResult LattigoEmitter::printOperation(arith::TruncFOp op) { + return typecast(op.getOperand(), op.getResult()); +} + LogicalResult LattigoEmitter::printOperation(arith::FloorDivSIOp op) { imports.insert(std::string(kMathImport)); diff --git a/lib/Target/Lattigo/LattigoEmitter.h b/lib/Target/Lattigo/LattigoEmitter.h index 84b0d7c30b..e9de5b454b 100644 --- a/lib/Target/Lattigo/LattigoEmitter.h +++ b/lib/Target/Lattigo/LattigoEmitter.h @@ -115,6 +115,7 @@ class LattigoEmitter { LogicalResult printOperation(::mlir::arith::ConstantOp op); LogicalResult printOperation(::mlir::arith::DivSIOp op); LogicalResult printOperation(::mlir::arith::DivFOp op); + LogicalResult printOperation(::mlir::arith::TruncFOp op); LogicalResult printOperation(::mlir::arith::ExtFOp op); LogicalResult printOperation(::mlir::arith::ExtSIOp op); LogicalResult printOperation(::mlir::arith::ExtUIOp op); diff --git a/lib/Transforms/LinalgCanonicalizations/BUILD b/lib/Transforms/LinalgCanonicalizations/BUILD index f07c1547ae..dc0f81d84e 
100644 --- a/lib/Transforms/LinalgCanonicalizations/BUILD +++ b/lib/Transforms/LinalgCanonicalizations/BUILD @@ -18,6 +18,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", diff --git a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp index 5adcb1f628..d9dad798eb 100644 --- a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp +++ b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp @@ -5,18 +5,21 @@ #include #include "lib/Utils/TensorUtils.h" +#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallBitVector.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project +#include "mlir/include/mlir/IR/AffineExpr.h" 
// from @llvm-project #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -729,6 +732,190 @@ struct RewriteAvgPoolAsConv2D } }; +static SmallVector getBroadcastDimensions(AffineMap map, + int64_t numDims) { + llvm::SmallDenseSet usedDims; + for (auto expr : map.getResults()) { + if (auto dimExpr = dyn_cast(expr)) { + usedDims.insert(dimExpr.getPosition()); + } + } + SmallVector addedDims; + for (int i = 0; i < numDims; ++i) { + if (usedDims.find(i) == usedDims.end()) { + addedDims.push_back(i); + } + } + return addedDims; +} + +/// A rewrite pattern that materializes broadcasts for broadcasting operands in +/// linalg.generic ops with parallel iterators. +/// +/// This pattern matches linalg.generic ops where all iterator types are +/// parallel, and at least one operand has a broadcasting indexing map (i.e., +/// the map drops dimensions, mapping a larger iteration space to a smaller +/// operand space). It creates a linalg.broadcast op for each such operand to +/// materialize the broadcast, making the operand match the output shape. +/// This allows subsequent patterns (like LinalgGenericToElementwise) to convert +/// the op to elementwise operations. +struct MaterializeBroadcasts : public OpRewritePattern { + public: + MaterializeBroadcasts(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter& rewriter) const override { + // Only handle ops with multiple inputs to avoid infinite loops when + // materializing broadcasts. Single-input broadcast ops don't need to be + // converted to elementwise ops. 
+ if (genericOp.getNumDpsInputs() <= 1) return failure(); + + for (auto iteratorType : genericOp.getIteratorTypesArray()) { + if (iteratorType != utils::IteratorType::parallel) { + return failure(); + } + } + + auto indexingMaps = genericOp.getIndexingMapsArray(); + bool madeChanges = false; + SmallVector newInputs; + SmallVector newMaps; + + int64_t numDims = genericOp.getNumLoops(); + + for (int64_t i = 0; i < genericOp.getNumDpsInputs(); ++i) { + OpOperand* operand = genericOp.getDpsInputOperand(i); + AffineMap map = indexingMaps[i]; + Value value = operand->get(); + + if (map.isIdentity()) { + newInputs.push_back(value); + newMaps.push_back(map); + continue; + } + + if (map.getNumResults() < numDims) { + madeChanges = true; + auto materializedValue = materializeBroadcastForOperand( + rewriter, genericOp, value, map, numDims); + if (failed(materializedValue)) { + return failure(); + } + newInputs.push_back(*materializedValue); + newMaps.push_back(rewriter.getMultiDimIdentityMap(numDims)); + } else { + newInputs.push_back(value); + newMaps.push_back(map); + } + } + + if (!madeChanges) return failure(); + + for (int64_t i = 0; i < genericOp.getNumDpsInits(); ++i) { + newMaps.push_back(indexingMaps[genericOp.getNumDpsInputs() + i]); + } + + auto newGenericOp = linalg::GenericOp::create( + rewriter, genericOp.getLoc(), genericOp.getResultTypes(), newInputs, + genericOp.getDpsInits(), newMaps, genericOp.getIteratorTypesArray()); + + rewriter.inlineRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), + newGenericOp.getRegion().begin()); + rewriter.replaceOp(genericOp, newGenericOp.getResults()); + + return success(); + } + + private: + LogicalResult tryCollapseUnitDims(PatternRewriter& rewriter, Location loc, + Value& value, AffineMap& map, + int64_t targetRank) const { + auto inputType = cast(value.getType()); + if (inputType.getRank() <= targetRank) return success(); + + SmallVector dimsToDrop; + for (unsigned j = 0; j < map.getNumResults(); ++j) { + auto 
expr = map.getResult(j); + if (auto constExpr = dyn_cast(expr)) { + if (inputType.getShape()[j] == 1) { + dimsToDrop.push_back(j); + } + } + } + + int64_t newRank = inputType.getRank() - dimsToDrop.size(); + if (dimsToDrop.empty() || newRank > targetRank) { + return failure(); + } + + auto reassociation = + getReassociationForReshapeAtDim(inputType.getRank(), dimsToDrop); + + SmallVector targetShape; + for (int64_t k = 0; k < inputType.getRank(); ++k) { + if (!llvm::is_contained(dimsToDrop, k)) { + targetShape.push_back(inputType.getShape()[k]); + } + } + + auto collapsedType = + RankedTensorType::get(targetShape, inputType.getElementType()); + + auto collapseOp = rewriter.create( + loc, collapsedType, value, reassociation); + + value = collapseOp.getResult(); + map = map.dropResults(dimsToDrop); + return success(); + } + + FailureOr materializeBroadcastForOperand(PatternRewriter& rewriter, + linalg::GenericOp genericOp, + Value value, AffineMap map, + int64_t numDims) const { + SmallVector addedDims = getBroadcastDimensions(map, numDims); + int64_t targetRank = numDims - addedDims.size(); + + if (failed(tryCollapseUnitDims(rewriter, genericOp.getLoc(), value, map, + targetRank))) { + return failure(); + } + + auto refOutput = genericOp.getDpsInitOperand(0)->get(); + auto refOutputType = cast(refOutput.getType()); + + auto emptyOp = tensor::EmptyOp::create(rewriter, genericOp.getLoc(), + refOutputType.getShape(), + refOutputType.getElementType()); + + auto broadcastOp = linalg::BroadcastOp::create( + rewriter, genericOp.getLoc(), value, emptyOp.getResult(), addedDims); + + return broadcastOp.getResults()[0]; + } +}; + +struct DropCfAssertInLinalg : public OpRewritePattern { + public: + DropCfAssertInLinalg(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter& rewriter) const override { + bool madeChanges = false; + auto* body = genericOp.getBody(); + for (auto& op : 
llvm::make_early_inc_range(body->getOperations())) { + if (auto assertOp = dyn_cast(op)) { + rewriter.eraseOp(assertOp); + madeChanges = true; + } + } + if (madeChanges) return success(); + return failure(); + } +}; + struct LinalgCanonicalizations : public impl::LinalgCanonicalizationsBase { void runOnOperation() override { @@ -740,7 +927,8 @@ struct LinalgCanonicalizations FoldConstantBroadcast, FoldBroadcastExtractSlice, LinalgMapToElementwise, LinalgGenericToElementwise, BroadcastToExpandShape, RewriteTransposedVecmat, - RewriteTransposedMatvec, RewriteAvgPoolAsConv2D>(context); + RewriteTransposedMatvec, RewriteAvgPoolAsConv2D, + MaterializeBroadcasts, DropCfAssertInLinalg>(context); // Run pattern matching and conversion // TODO (#1221): Investigate whether folding (default: on) can be skipped diff --git a/lib/Transforms/LinalgFuseLinearOps/BUILD b/lib/Transforms/LinalgFuseLinearOps/BUILD new file mode 100644 index 0000000000..ab5b1836ca --- /dev/null +++ b/lib/Transforms/LinalgFuseLinearOps/BUILD @@ -0,0 +1,33 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "LinalgFuseLinearOps", + srcs = ["LinalgFuseLinearOps.cpp"], + hdrs = [ + "LinalgFuseLinearOps.h", + ], + deps = [ + ":pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DestinationStyleOpInterface", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +add_heir_transforms( + generated_target_name = "pass_inc_gen", + pass_name = "LinalgFuseLinearOps", +) diff --git a/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp 
b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp new file mode 100644 index 0000000000..a3c85f979b --- /dev/null +++ b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp @@ -0,0 +1,261 @@ +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h" + +#include +#include +#include +#include + +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project +#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/include/mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_LINALGFUSELINEAROPS +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h.inc" + +namespace { + +template +LogicalResult findLinearOpAndOperand(OpTy op, Operation*& linearOp, + Value& rawOperand) { + Value lhs = 
op.getLhs(); + Value rhs = op.getRhs(); + + auto defOp = lhs.getDefiningOp(); + Value other = rhs; + if (!defOp || !isa(defOp)) { + defOp = rhs.getDefiningOp(); + other = lhs; + } + if (!defOp || !isa(defOp)) return failure(); + + linearOp = defOp; + rawOperand = other; + if (Operation* broadcastOp = other.getDefiningOp()) { + if (broadcastOp->getName().getStringRef() == "linalg.broadcast") { + rawOperand = broadcastOp->getOperand(0); + } + } + return success(); +} + +template +LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op) { + Operation* linearOp = nullptr; + Value scale_val; + if (failed(findLinearOpAndOperand(op, linearOp, scale_val))) return failure(); + + Value weights; + int64_t weightOperandIdx = -1; + int64_t matchDim = -1; + + llvm::TypeSwitch(linearOp) + .template Case([&](auto op) { + weights = op.getOperand(1); + weightOperandIdx = 1; + }) + .template Case([&](auto op) { + weights = op.getOperand(0); + weightOperandIdx = 0; + }) + .template Case([&](auto op) { + weights = op.getOperand(1); + weightOperandIdx = 1; + matchDim = 0; + }) + .template Case([&](auto op) { + weights = op.getOperand(1); + weightOperandIdx = 1; + matchDim = 3; + }) + .template Case([&](auto op) { + weights = op.getOperand(1); + weightOperandIdx = 1; + matchDim = 2; + }) + .Default([](auto) {}); + + if (!weights) return failure(); + + auto weightsType = cast(weights.getType()); + auto scaleValType = cast(scale_val.getType()); + + if (scaleValType.getRank() != 1) return failure(); + + // Determine which dimension of the weight matrix we need to broadcast the + // scalar along. 
+ if (matchDim == -1) { + for (int i = 0; i < weightsType.getRank(); ++i) { + if (weightsType.getDimSize(i) == scaleValType.getDimSize(0)) { + matchDim = i; + break; + } + } + } + if (matchDim == -1) return failure(); + + SmallVector addedDims; + for (int i = 0; i < weightsType.getRank(); ++i) { + if (i != matchDim) { + addedDims.push_back(i); + } + } + + auto emptyOp = tensor::EmptyOp::create(rewriter, linearOp->getLoc(), + weightsType.getShape(), + weightsType.getElementType()); + + auto broadcastOp = linalg::BroadcastOp::create( + rewriter, linearOp->getLoc(), scale_val, emptyOp.getResult(), addedDims); + + auto scaledWeights = + OpTy::create(rewriter, op.getLoc(), weights, broadcastOp.getResults()[0]); + + IRMapping bvm; + Operation* newLinearOp = rewriter.clone(*linearOp, bvm); + newLinearOp->setOperand(weightOperandIdx, scaledWeights); + + rewriter.replaceOp(op, newLinearOp->getResults()); + return success(); +} + +template +LogicalResult fuseAddOrSubIntoLinearOp(PatternRewriter& rewriter, OpTy op) { + Operation* linearOp = nullptr; + Value addend; + if (failed(findLinearOpAndOperand(op, linearOp, addend))) return failure(); + + auto destStyleOp = dyn_cast(linearOp); + if (!destStyleOp) return failure(); + + if (destStyleOp.getNumDpsInits() != 1) return failure(); + + auto outputType = + cast(destStyleOp.getDpsInitOperand(0)->get().getType()); + auto addendType = cast(addend.getType()); + + if (addendType.getRank() != 1) return failure(); + + int64_t matchDimAddend = -1; + for (int i = 0; i < outputType.getRank(); ++i) { + if (outputType.getDimSize(i) == addendType.getDimSize(0)) { + matchDimAddend = i; + break; + } + } + if (matchDimAddend == -1) return failure(); + + SmallVector addedDimsAddend; + for (int i = 0; i < outputType.getRank(); ++i) { + if (i != matchDimAddend) { + addedDimsAddend.push_back(i); + } + } + + auto emptyOutputOp = tensor::EmptyOp::create(rewriter, linearOp->getLoc(), + outputType.getShape(), + outputType.getElementType()); + + auto 
broadcastAddendOp = + linalg::BroadcastOp::create(rewriter, linearOp->getLoc(), addend, + emptyOutputOp.getResult(), addedDimsAddend); + + Value existingOuts = destStyleOp.getDpsInitOperand(0)->get(); + Value newOuts; + if (existingOuts.getDefiningOp()) { + if constexpr (std::is_same_v) { + newOuts = arith::NegFOp::create(rewriter, op.getLoc(), + broadcastAddendOp.getResults()[0]); + } else { + newOuts = broadcastAddendOp.getResults()[0]; + } + } else { + newOuts = OpTy::create(rewriter, op.getLoc(), existingOuts, + broadcastAddendOp.getResults()[0]); + } + + IRMapping bvm; + Operation* newLinearOp = rewriter.clone(*linearOp, bvm); + unsigned initOperandIdx = + destStyleOp.getDpsInitOperand(0)->getOperandNumber(); + newLinearOp->setOperand(initOperandIdx, newOuts); + + rewriter.replaceOp(op, newLinearOp->getResults()); + return success(); +} + +struct FuseScaleIntoLinearOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::MulFOp op, + PatternRewriter& rewriter) const override { + return fuseScaleOrDivIntoLinearOp(rewriter, op); + } +}; + +struct FuseDivIntoLinearOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::DivFOp op, + PatternRewriter& rewriter) const override { + return fuseScaleOrDivIntoLinearOp(rewriter, op); + } +}; + +struct FuseAddIntoLinearOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter& rewriter) const override { + return fuseAddOrSubIntoLinearOp(rewriter, op); + } +}; + +struct FuseSubIntoLinearOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::SubFOp op, + PatternRewriter& rewriter) const override { + return fuseAddOrSubIntoLinearOp(rewriter, op); + } +}; + +} // namespace + +struct LinalgFuseLinearOps + : public impl::LinalgFuseLinearOpsBase { + void runOnOperation() 
override { + MLIRContext* context = &getContext(); + auto module = getOperation(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createLinalgFuseLinearOpsPass() { + return std::make_unique(); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h new file mode 100644 index 0000000000..6b4ad798cc --- /dev/null +++ b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h @@ -0,0 +1,20 @@ +#ifndef LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_H_ +#define LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +std::unique_ptr createLinalgFuseLinearOpsPass(); + +#define GEN_PASS_DECL +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_H_ diff --git a/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.td b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.td new file mode 100644 index 0000000000..84b13e2715 --- /dev/null +++ b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.td @@ -0,0 +1,31 @@ +#ifndef LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_TD_ +#define LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_TD_ + +include "mlir/Pass/PassBase.td" + +def LinalgFuseLinearOps : Pass<"linalg-fuse-linear-ops", "ModuleOp"> { + let summary = "Fuse arith ops into preceding linear operations."; + let description = [{ + Fuses elementary arithmetic operations into preceding linalg operations when possible. 
+ + In particular, this pass fuses: + + - arith.addf/subf into the `outs` of a preceding linalg op that supports + DestinationStyleOpInterface. + - arith.mulf/divf into the weights of an opt-in list of preceding linalg + ops. For convolutions, only fusion by broadcasting along the channel + dimension is supported. + + In the context of a neural network, this pass generalizes the task of + folding a BatchNorm layer into a preceding dense or convolutional layer. + + (* example filepath=tests/Transforms/linalg_fuse_linear_ops/doctest.mlir *) + }]; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", + "mlir::arith::ArithDialect", + "mlir::tensor::TensorDialect" + ]; +} + +#endif // LIB_TRANSFORMS_LINALGFUSELINEAROPS_LINALGFUSELINEAROPS_TD_ diff --git a/tests/Examples/common/batchnorm1d.mlir b/tests/Examples/common/batchnorm1d.mlir new file mode 100644 index 0000000000..352c2a2053 --- /dev/null +++ b/tests/Examples/common/batchnorm1d.mlir @@ -0,0 +1,66 @@ +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, 0)> +module { + func.func @batchnorm1d(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor<1x3x16xf32> {secret.secret}) -> tensor<1x3x16xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %cst_1 = arith.constant dense_resource : tensor<3xf32> + %cst_2 = arith.constant dense_resource : tensor<3xf32> + %cst_3 = arith.constant 1.000000e-05 : f64 + %0 = tensor.empty() : tensor<3xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : tensor<3xf32>) outs(%0 : tensor<3xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.truncf %cst_3 : f64 to f32 + %10 = arith.addf %in, %9 : f32 + linalg.yield %10 : f32 + } -> tensor<3xf32> + %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%1 : tensor<3xf32>) outs(%0 : tensor<3xf32>) { + ^bb0(%in: f32, %out: f32): + 
%9 = math.sqrt %in : f32 + linalg.yield %9 : f32 + } -> tensor<3xf32> + %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%2 : tensor<3xf32>) outs(%0 : tensor<3xf32>) { + ^bb0(%in: f32, %out: f32): + %9 = arith.cmpf one, %in, %cst : f32 + cf.assert %9, "unimplemented: tensor with zero element" + %10 = arith.divf %cst_0, %in : f32 + linalg.yield %10 : f32 + } -> tensor<3xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 1] : tensor<3xf32> into tensor<3x1xf32> + %expanded_4 = tensor.expand_shape %3 [[0, 1]] output_shape [3, 1] : tensor<3xf32> into tensor<3x1xf32> + %4 = tensor.empty() : tensor<1x3x16xf32> + %5 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg3, %expanded : tensor<1x3x16xf32>, tensor<3x1xf32>) outs(%4 : tensor<1x3x16xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %9 = arith.subf %in, %in_7 : f32 + linalg.yield %9 : f32 + } -> tensor<1x3x16xf32> + %6 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %expanded_4 : tensor<1x3x16xf32>, tensor<3x1xf32>) outs(%4 : tensor<1x3x16xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %9 = arith.mulf %in, %in_7 : f32 + linalg.yield %9 : f32 + } -> tensor<1x3x16xf32> + %expanded_5 = tensor.expand_shape %cst_2 [[0, 1]] output_shape [3, 1] : tensor<3xf32> into tensor<3x1xf32> + %7 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %expanded_5 : tensor<1x3x16xf32>, tensor<3x1xf32>) outs(%4 : tensor<1x3x16xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %9 = arith.mulf %in, %in_7 : f32 + linalg.yield %9 : f32 + } -> tensor<1x3x16xf32> + %expanded_6 = tensor.expand_shape %cst_1 [[0, 1]] output_shape [3, 1] : tensor<3xf32> into tensor<3x1xf32> + %8 = linalg.generic {indexing_maps = [#map1, #map2, #map1], iterator_types = ["parallel", "parallel", 
"parallel"]} ins(%7, %expanded_6 : tensor<1x3x16xf32>, tensor<3x1xf32>) outs(%4 : tensor<1x3x16xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %9 = arith.addf %in, %in_7 : f32 + linalg.yield %9 : f32 + } -> tensor<1x3x16xf32> + return %8 : tensor<1x3x16xf32> + } +} + +{-# + dialect_resources: { + builtin: { + torch_tensor_3_torch.float32_1: "0x04000000000000000000000000000000", + torch_tensor_3_torch.float32: "0x040000000000803F0000803F0000803F" + } + } +#-} diff --git a/tests/Examples/lattigo/ckks/batchnorm1d/BUILD b/tests/Examples/lattigo/ckks/batchnorm1d/BUILD new file mode 100644 index 0000000000..22a7bd1390 --- /dev/null +++ b/tests/Examples/lattigo/ckks/batchnorm1d/BUILD @@ -0,0 +1,22 @@ +load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib") +load("@rules_go//go:def.bzl", "go_test") + +package(default_applicable_licenses = ["@heir//:license"]) + +heir_lattigo_lib( + name = "batchnorm1d", + go_library_name = "batchnorm1d", + heir_opt_flags = [ + "--annotate-module=backend=lattigo scheme=ckks", + "--linalg-canonicalizations", + "--torch-linalg-to-ckks=ciphertext-degree=2048 split-preprocessing=1", + "--scheme-to-lattigo", + ], + mlir_src = "@heir//tests/Examples/common:batchnorm1d.mlir", +) + +go_test( + name = "batchnorm1d_test", + srcs = ["batchnorm1d_test.go"], + embed = [":batchnorm1d"], +) diff --git a/tests/Examples/lattigo/ckks/batchnorm1d/batchnorm1d_test.go b/tests/Examples/lattigo/ckks/batchnorm1d/batchnorm1d_test.go new file mode 100644 index 0000000000..cb65cd66e6 --- /dev/null +++ b/tests/Examples/lattigo/ckks/batchnorm1d/batchnorm1d_test.go @@ -0,0 +1,25 @@ +package batchnorm1d + +import ( + "testing" +) + +func TestBatchNorm1d(t *testing.T) { + evaluator, params, ecd, enc, dec := batchnorm1d__configure() + + arg0 := []float32{0.1, 0.2, 0.3} // mean + arg1 := []float32{0.4, 0.5, 0.6} // var + arg2 := []int64{0} + arg3 := make([]float32, 48) // data + for i := 0; i < 48; i++ { + arg3[i] = float32(i) * 0.1 + } + + ct3 := 
batchnorm1d__encrypt__arg3(evaluator, params, ecd, enc, arg3) + + resultCt := batchnorm1d(evaluator, params, ecd, arg0, arg1, arg2, ct3) + + result := batchnorm1d__decrypt__result0(evaluator, params, ecd, dec, resultCt) + + t.Logf("Result size: %d", len(result)) +} diff --git a/tests/Examples/lattigo/ckks/matvec_batchnorm/BUILD b/tests/Examples/lattigo/ckks/matvec_batchnorm/BUILD new file mode 100644 index 0000000000..407abdc4df --- /dev/null +++ b/tests/Examples/lattigo/ckks/matvec_batchnorm/BUILD @@ -0,0 +1,23 @@ +load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib") +load("@rules_go//go:def.bzl", "go_test") + +package(default_applicable_licenses = ["@heir//:license"]) + +heir_lattigo_lib( + name = "matvec_batchnorm", + go_library_name = "matvecbatchnorm", + heir_opt_flags = [ + "--annotate-module=backend=lattigo scheme=ckks", + "--linalg-canonicalizations", + "--torch-linalg-to-ckks=ciphertext-degree=2048", + "--scheme-to-lattigo", + ], + mlir_src = "matvec_batchnorm.mlir", + split_preprocessing = False, +) + +go_test( + name = "matvec_batchnorm_test", + srcs = ["matvec_batchnorm_test.go"], + embed = [":matvecbatchnorm"], +) diff --git a/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm.mlir b/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm.mlir new file mode 100644 index 0000000000..d64131b565 --- /dev/null +++ b/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm.mlir @@ -0,0 +1,32 @@ +module { + func.func @matvec_batchnorm(%arg0: tensor<4xf32> {secret.secret}, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]> : tensor<4x4xf32> + %cst_0 = tensor.empty() : tensor<4xf32> + + // Matvec + %0 = linalg.matvec ins(%cst, %arg0 : tensor<4x4xf32>, tensor<4xf32>) outs(%cst_0 : tensor<4xf32>) -> tensor<4xf32> + + // Scale + %1 = tensor.empty() : tensor<4xf32> + %2 = linalg.generic { + 
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%0, %arg1 : tensor<4xf32>, tensor<4xf32>) outs(%1 : tensor<4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + linalg.yield %3 : f32 + } -> tensor<4xf32> + + // Shift + %3 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%2, %arg2 : tensor<4xf32>, tensor<4xf32>) outs(%1 : tensor<4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<4xf32> + + return %3 : tensor<4xf32> + } +} diff --git a/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm_test.go b/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm_test.go new file mode 100644 index 0000000000..b32908f247 --- /dev/null +++ b/tests/Examples/lattigo/ckks/matvec_batchnorm/matvec_batchnorm_test.go @@ -0,0 +1,25 @@ +package matvecbatchnorm + +import ( + "fmt" + "testing" +) + +func TestMatvecBatchNorm(t *testing.T) { + evaluator, params, ecd, enc, dec := matvec_batchnorm__configure() + arg0 := []float32{1, 2, 3, 4} + arg1 := []float32{2, 2, 2, 2} // scale + arg2 := []float32{1, 1, 1, 1} // shift + + ct0 := matvec_batchnorm__encrypt__arg0(evaluator, params, ecd, enc, arg0) + resultCt := matvec_batchnorm(evaluator, params, ecd, ct0, arg1, arg2) + result := matvec_batchnorm__decrypt__result0(evaluator, params, ecd, dec, resultCt) + fmt.Println(result) + + expected := []float32{61, 141, 221, 301} // (W*arg0)*scale + shift for the constant 4x4 matrix in matvec_batchnorm.mlir + for i, v := range result { + if d := v - expected[i]; !(d > -0.01 && d < 0.01) { // CKKS decryption is approximate, so compare within a tolerance; the negated form also flags NaN + t.Errorf("Expected %v, got %v at index %d", expected[i], v, i) + } + } +} diff --git a/tests/Transforms/linalg_canonicalizations/batchnorm1d.mlir b/tests/Transforms/linalg_canonicalizations/batchnorm1d.mlir new file mode 100644 index 0000000000..f675337329 --- /dev/null +++
b/tests/Transforms/linalg_canonicalizations/batchnorm1d.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s + +// In this test, the axes of arg0 (1x3x16) do not align with the axes of arg1 +// (3x1) and so the pattern identifies that it can collapse the shape of arg1 +// before re-broadcasting it to match arg0. +#map_3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map_2d = affine_map<(d0, d1, d2) -> (d1, 0)> + +// CHECK: @test_materialize_broadcast +func.func @test_materialize_broadcast( + %arg0: tensor<1x3x16xf32>, + %arg1: tensor<3x1xf32> +) -> tensor<1x3x16xf32> { + %empty = tensor.empty() : tensor<1x3x16xf32> + + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG1:.*]] {{\[\[}}0, 1{{\]\]}} : tensor<3x1xf32> into tensor<3xf32> + // CHECK: %[[BROADCAST:.*]] = linalg.broadcast ins(%[[COLLAPSED]] : tensor<3xf32>) outs(%{{.*}} : tensor<1x3x16xf32>) dimensions = [0, 2] + // CHECK: %[[SUB:.*]] = arith.subf %[[ARG0:.*]], %[[BROADCAST]] : tensor<1x3x16xf32> + // CHECK: return %[[SUB]] : tensor<1x3x16xf32> + %result = linalg.generic { + indexing_maps = [#map_3d, #map_2d, #map_3d], + iterator_types = ["parallel", "parallel", "parallel"] + } ins(%arg0, %arg1 : tensor<1x3x16xf32>, tensor<3x1xf32>) outs(%empty : tensor<1x3x16xf32>) { + ^bb0(%in: f32, %in_elt: f32, %out: f32): + %sub = arith.subf %in, %in_elt : f32 + linalg.yield %sub : f32 + } -> tensor<1x3x16xf32> + + return %result : tensor<1x3x16xf32> +} + +// ----- diff --git a/tests/Transforms/linalg_canonicalizations/materialize_broadcast.mlir b/tests/Transforms/linalg_canonicalizations/materialize_broadcast.mlir new file mode 100644 index 0000000000..e9cee1bc86 --- /dev/null +++ b/tests/Transforms/linalg_canonicalizations/materialize_broadcast.mlir @@ -0,0 +1,75 @@ +// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s + +// CHECK: @basic_broadcast +func.func @basic_broadcast(%arg0: tensor<3xf32>, %arg1: tensor<3x16xf32>) -> 
tensor<3x16xf32> { + %cst = tensor.empty() : tensor<3x16xf32> + // CHECK: %[[EXPANDED:.*]] = linalg.broadcast ins(%arg0 : tensor<3xf32>) outs({{.*}}) dimensions = [1] + // CHECK: arith.addf {{.*}}, {{.*}} : tensor<3x16xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%arg0, %arg1 : tensor<3xf32>, tensor<3x16xf32>) outs(%cst : tensor<3x16xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.addf %in, %in_0 : f32 + linalg.yield %1 : f32 + } -> tensor<3x16xf32> + return %0 : tensor<3x16xf32> +} + +// ----- + +// CHECK: @multiple_broadcast +func.func @multiple_broadcast(%arg0: tensor<3xf32>, %arg1: tensor<16xf32>, %arg2: tensor<3x16xf32>) -> tensor<3x16xf32> { + %cst = tensor.empty() : tensor<3x16xf32> + // CHECK: %[[EXPANDED1:.*]] = linalg.broadcast ins(%arg0 : tensor<3xf32>) outs({{.*}}) dimensions = [1] + // CHECK: %[[EXPANDED2:.*]] = linalg.broadcast ins(%arg1 : tensor<16xf32>) outs({{.*}}) dimensions = [0] + // CHECK: %[[SUM:.*]] = arith.addf {{.*}}, {{.*}} + // CHECK: arith.addf {{.*}}, {{.*}} : tensor<3x16xf32> + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%arg0, %arg1, %arg2 : tensor<3xf32>, tensor<16xf32>, tensor<3x16xf32>) outs(%cst : tensor<3x16xf32>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %1 = arith.addf %in, %in_0 : f32 + %2 = arith.addf %1, %in_1 : f32 + linalg.yield %2 : f32 + } -> tensor<3x16xf32> + return %0 : tensor<3x16xf32> +} + +// ----- + +// CHECK: @single_input_no_broadcast +func.func @single_input_no_broadcast(%arg0: tensor<3xf32>) -> tensor<3x16xf32> { + %cst = tensor.empty() : tensor<3x16xf32> + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#map, #map1] + // CHECK-NOT: arith.addf + 
%0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%arg0 : tensor<3xf32>) outs(%cst : tensor<3x16xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<3x16xf32> + return %0 : tensor<3x16xf32> +} + +// ----- + +// CHECK: @reduction_no_broadcast +func.func @reduction_no_broadcast(%arg0: tensor<3x16xf32>) -> tensor<3xf32> { + %cst = tensor.empty() : tensor<3xf32> + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#map, #map1] + // CHECK-SAME: iterator_types = ["parallel", "reduction"] + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } ins(%arg0 : tensor<3x16xf32>) outs(%cst : tensor<3xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<3xf32> + return %0 : tensor<3xf32> +} diff --git a/tests/Transforms/linalg_fuse_linear_ops/BUILD b/tests/Transforms/linalg_fuse_linear_ops/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Transforms/linalg_fuse_linear_ops/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Transforms/linalg_fuse_linear_ops/doctest.mlir b/tests/Transforms/linalg_fuse_linear_ops/doctest.mlir new file mode 100644 index 0000000000..3bd3afb186 --- /dev/null +++ b/tests/Transforms/linalg_fuse_linear_ops/doctest.mlir @@ -0,0 +1,21 @@ +// RUN: heir-opt --linalg-fuse-linear-ops --split-input-file %s | FileCheck %s + +// CHECK: func.func @fuse_matmul +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<3x4xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<4xf32>) 
outs(%[[EMPTY_W]] : tensor<3x4xf32>) dimensions = [0] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg1, %[[BROADCAST_W]] : tensor<3x4xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : tensor<2x4xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%[[EMPTY_OUT]] : tensor<2x4xf32>) dimensions = [0] +// CHECK: %[[RESULT:.*]] = linalg.matmul ins(%arg0, %[[SCALED_W]] : tensor<2x3xf32>, tensor<3x4xf32>) outs(%[[BROADCAST_OUT]] : tensor<2x4xf32>) +// CHECK: return %[[RESULT]] +func.func @fuse_matmul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = tensor.empty() : tensor<2x4xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%2 : tensor<2x4xf32>) dimensions = [0] + %3 = arith.mulf %1, %broadcasted : tensor<2x4xf32> + %4 = tensor.empty() : tensor<2x4xf32> + %broadcasted_0 = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%4 : tensor<2x4xf32>) dimensions = [0] + %5 = arith.addf %3, %broadcasted_0 : tensor<2x4xf32> + return %5 : tensor<2x4xf32> +} diff --git a/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir b/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir new file mode 100644 index 0000000000..ccf90ee165 --- /dev/null +++ b/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir @@ -0,0 +1,101 @@ +// RUN: heir-opt --linalg-fuse-linear-ops --split-input-file %s | FileCheck %s + +// CHECK: func.func @fuse_matmul +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<3x4xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%[[EMPTY_W]] : tensor<3x4xf32>) dimensions = [0] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg1, %[[BROADCAST_W]] : tensor<3x4xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : 
tensor<2x4xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%[[EMPTY_OUT]] : tensor<2x4xf32>) dimensions = [0] +// CHECK: %[[RESULT:.*]] = linalg.matmul ins(%arg0, %[[SCALED_W]] : tensor<2x3xf32>, tensor<3x4xf32>) outs(%[[BROADCAST_OUT]] : tensor<2x4xf32>) +// CHECK: return %[[RESULT]] +func.func @fuse_matmul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = tensor.empty() : tensor<2x4xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%2 : tensor<2x4xf32>) dimensions = [0] + %3 = arith.mulf %1, %broadcasted : tensor<2x4xf32> + %4 = tensor.empty() : tensor<2x4xf32> + %broadcasted_0 = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%4 : tensor<2x4xf32>) dimensions = [0] + %5 = arith.addf %3, %broadcasted_0 : tensor<2x4xf32> + return %5 : tensor<2x4xf32> +} + +// ----- + +// CHECK: func.func @fuse_matvec +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<2x3xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<2xf32>) outs(%[[EMPTY_W]] : tensor<2x3xf32>) dimensions = [1] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg0, %[[BROADCAST_W]] : tensor<2x3xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<2xf32>) outs(%[[EMPTY_OUT]] : tensor<2xf32>) dimensions = [] +// CHECK: %[[RESULT:.*]] = linalg.matvec ins(%[[SCALED_W]], %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%[[BROADCAST_OUT]] : tensor<2xf32>) +// CHECK: return %[[RESULT]] +func.func @fuse_matvec(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %1 = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%0 : 
tensor<2xf32>) -> tensor<2xf32> + %2 = arith.mulf %1, %arg2 : tensor<2xf32> + %3 = arith.addf %2, %arg3 : tensor<2xf32> + return %3 : tensor<2xf32> +} + +// ----- + +// CHECK: func.func @fuse_vecmat +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<3x4xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%[[EMPTY_W]] : tensor<3x4xf32>) dimensions = [0] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg1, %[[BROADCAST_W]] : tensor<3x4xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : tensor<4xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%[[EMPTY_OUT]] : tensor<4xf32>) dimensions = [] +// CHECK: %[[RESULT:.*]] = linalg.vecmat ins(%arg0, %[[SCALED_W]] : tensor<3xf32>, tensor<3x4xf32>) outs(%[[BROADCAST_OUT]] : tensor<4xf32>) +// CHECK: return %[[RESULT]] +func.func @fuse_vecmat(%arg0: tensor<3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> tensor<4xf32> { + %0 = tensor.empty() : tensor<4xf32> + %1 = linalg.vecmat ins(%arg0, %arg1 : tensor<3xf32>, tensor<3x4xf32>) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + %2 = arith.mulf %1, %arg2 : tensor<4xf32> + %3 = arith.addf %2, %arg3 : tensor<4xf32> + return %3 : tensor<4xf32> +} + +// ----- + +// CHECK: func.func @fuse_conv_2d +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<8x3x3x3xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<8xf32>) outs(%[[EMPTY_W]] : tensor<8x3x3x3xf32>) dimensions = [1, 2, 3] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg1, %[[BROADCAST_W]] : tensor<8x3x3x3xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : tensor<1x8x14x14xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<8xf32>) outs(%[[EMPTY_OUT]] : tensor<1x8x14x14xf32>) dimensions = [0, 2, 3] +// CHECK: %[[RESULT:.*]] = linalg.conv_2d_nchw_fchw ins(%arg0, %[[SCALED_W]] : tensor<1x3x16x16xf32>, tensor<8x3x3x3xf32>) outs(%[[BROADCAST_OUT]] : tensor<1x8x14x14xf32>) +// CHECK: return 
%[[RESULT]] +func.func @fuse_conv_2d(%arg0: tensor<1x3x16x16xf32>, %arg1: tensor<8x3x3x3xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<1x8x14x14xf32> { + %0 = tensor.empty() : tensor<1x8x14x14xf32> + %1 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<1x3x16x16xf32>, tensor<8x3x3x3xf32>) outs(%0 : tensor<1x8x14x14xf32>) -> tensor<1x8x14x14xf32> + %2 = tensor.empty() : tensor<1x8x14x14xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<8xf32>) outs(%2 : tensor<1x8x14x14xf32>) dimensions = [0, 2, 3] + %3 = arith.mulf %1, %broadcasted : tensor<1x8x14x14xf32> + %4 = tensor.empty() : tensor<1x8x14x14xf32> + %broadcasted_0 = linalg.broadcast ins(%arg3 : tensor<8xf32>) outs(%4 : tensor<1x8x14x14xf32>) dimensions = [0, 2, 3] + %5 = arith.addf %3, %broadcasted_0 : tensor<1x8x14x14xf32> + return %5 : tensor<1x8x14x14xf32> +} + +// ----- + +// CHECK: func.func @fuse_matmul_with_bias +// CHECK: %[[EMPTY_W:.*]] = tensor.empty() : tensor<3x4xf32> +// CHECK: %[[BROADCAST_W:.*]] = linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%[[EMPTY_W]] : tensor<3x4xf32>) dimensions = [0] +// CHECK: %[[SCALED_W:.*]] = arith.mulf %arg1, %[[BROADCAST_W]] : tensor<3x4xf32> +// CHECK: %[[EMPTY_OUT:.*]] = tensor.empty() : tensor<2x4xf32> +// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%[[EMPTY_OUT]] : tensor<2x4xf32>) dimensions = [0] +// CHECK: %[[NEW_OUTS:.*]] = arith.addf %arg4, %[[BROADCAST_OUT]] : tensor<2x4xf32> +// CHECK: %[[RESULT:.*]] = linalg.matmul ins(%arg0, %[[SCALED_W]] : tensor<2x3xf32>, tensor<3x4xf32>) outs(%[[NEW_OUTS]] : tensor<2x4xf32>) +// CHECK: return %[[RESULT]] +func.func @fuse_matmul_with_bias(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>, %arg4: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%arg4 : tensor<2x4xf32>) -> tensor<2x4xf32> + %1 = tensor.empty() : tensor<2x4xf32> + %broadcasted = 
linalg.broadcast ins(%arg2 : tensor<4xf32>) outs(%1 : tensor<2x4xf32>) dimensions = [0] + %2 = arith.mulf %0, %broadcasted : tensor<2x4xf32> + %3 = tensor.empty() : tensor<2x4xf32> + %broadcasted_0 = linalg.broadcast ins(%arg3 : tensor<4xf32>) outs(%3 : tensor<2x4xf32>) dimensions = [0] + %4 = arith.addf %2, %broadcasted_0 : tensor<2x4xf32> + return %4 : tensor<2x4xf32> +} diff --git a/tools/BUILD b/tools/BUILD index 4e4f28f86b..e3d6cdb49b 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -158,6 +158,7 @@ cc_binary( "@heir//lib/Transforms/LayoutPropagation", "@heir//lib/Transforms/LayoutPropagation:InterfaceImpl", "@heir//lib/Transforms/LinalgCanonicalizations", + "@heir//lib/Transforms/LinalgFuseLinearOps", "@heir//lib/Transforms/LowerPolynomialEval", "@heir//lib/Transforms/LowerUnpack", "@heir//lib/Transforms/MemrefToArith:ExpandCopy", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index d35cd4115e..45937c3f75 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -104,6 +104,7 @@ #include "lib/Transforms/LayoutPropagation/InterfaceImpl.h" #include "lib/Transforms/LayoutPropagation/LayoutPropagation.h" #include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h" +#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h" #include "lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.h" #include "lib/Transforms/LowerUnpack/LowerUnpack.h" #include "lib/Transforms/OperationBalancer/OperationBalancer.h" @@ -355,6 +356,7 @@ int main(int argc, char** argv) { registerLayoutPropagationPasses(); registerLayoutOptimizationPasses(); registerLinalgCanonicalizationsPasses(); + registerLinalgFuseLinearOpsPasses(); registerReductionCanonicalizationsPasses(); registerFoldConstantTensorsPasses(); registerLowerPolynomialEvalPasses();