Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "lib/Transforms/LayoutOptimization/LayoutOptimization.h"
#include "lib/Transforms/LayoutPropagation/LayoutPropagation.h"
#include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h"
#include "lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.h"
#include "lib/Transforms/OperationBalancer/OperationBalancer.h"
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
#include "lib/Transforms/PopulateScale/PopulateScale.h"
Expand Down Expand Up @@ -541,6 +542,7 @@ void linalgPreprocessingBuilder(OpPassManager& manager) {
manager.addPass(createInlineActivations());
manager.addPass(createActivationCanonicalizations());
manager.addPass(createLinalgCanonicalizations());
manager.addPass(createLinalgFuseLinearOpsPass());
manager.addPass(createDropUnitDims());
manager.addPass(createFoldConstantTensors());
manager.addPass(createCanonicalizerPass());
Expand Down
1 change: 1 addition & 0 deletions lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ cc_library(
"@heir//lib/Transforms/LayoutOptimization",
"@heir//lib/Transforms/LayoutPropagation",
"@heir//lib/Transforms/LinalgCanonicalizations",
"@heir//lib/Transforms/LinalgFuseLinearOps",
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration",
"@heir//lib/Transforms/OperationBalancer",
"@heir//lib/Transforms/OptimizeRelinearization",
Expand Down
14 changes: 9 additions & 5 deletions lib/Target/Lattigo/LattigoEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ LogicalResult LattigoEmitter::translate(Operation& op) {
// Arith ops
.Case<arith::ConstantOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::IndexCastOp, arith::ExtFOp,
arith::RemSIOp, arith::AddIOp, arith::AddFOp, arith::AndIOp,
arith::SubIOp, arith::SubFOp, arith::MaxSIOp, arith::MinSIOp,
arith::MulIOp, arith::MulFOp, arith::DivSIOp, arith::DivFOp,
arith::NegFOp, arith::OrIOp, arith::XOrIOp, arith::CmpIOp,
arith::CmpFOp, arith::SelectOp>(
arith::TruncFOp, arith::RemSIOp, arith::AddIOp, arith::AddFOp,
arith::AndIOp, arith::SubIOp, arith::SubFOp, arith::MaxSIOp,
arith::MinSIOp, arith::MulIOp, arith::MulFOp, arith::DivSIOp,
arith::DivFOp, arith::NegFOp, arith::OrIOp, arith::XOrIOp,
arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(
[&](auto op) { return printOperation(op); })
// SCF ops
.Case<scf::IfOp, scf::ForOp, scf::YieldOp>(
Expand Down Expand Up @@ -709,6 +709,10 @@ LogicalResult LattigoEmitter::printOperation(arith::ExtFOp op) {
return typecast(op.getOperand(), op.getResult());
}

LogicalResult LattigoEmitter::printOperation(arith::TruncFOp op) {
return typecast(op.getOperand(), op.getResult());
}

LogicalResult LattigoEmitter::printOperation(arith::FloorDivSIOp op) {
imports.insert(std::string(kMathImport));

Expand Down
1 change: 1 addition & 0 deletions lib/Target/Lattigo/LattigoEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class LattigoEmitter {
LogicalResult printOperation(::mlir::arith::ConstantOp op);
LogicalResult printOperation(::mlir::arith::DivSIOp op);
LogicalResult printOperation(::mlir::arith::DivFOp op);
LogicalResult printOperation(::mlir::arith::TruncFOp op);
LogicalResult printOperation(::mlir::arith::ExtFOp op);
LogicalResult printOperation(::mlir::arith::ExtSIOp op);
LogicalResult printOperation(::mlir::arith::ExtUIOp op);
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/LinalgCanonicalizations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
Expand Down
192 changes: 190 additions & 2 deletions lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
#include <utility>

#include "lib/Utils/TensorUtils.h"
#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallBitVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project
#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand Down Expand Up @@ -729,6 +732,190 @@ struct RewriteAvgPoolAsConv2D
}
};

static SmallVector<int64_t> getBroadcastDimensions(AffineMap map,
int64_t numDims) {
llvm::SmallDenseSet<unsigned> usedDims;
for (auto expr : map.getResults()) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
usedDims.insert(dimExpr.getPosition());
}
}
SmallVector<int64_t> addedDims;
for (int i = 0; i < numDims; ++i) {
if (usedDims.find(i) == usedDims.end()) {
addedDims.push_back(i);
}
}
return addedDims;
}

/// A rewrite pattern that materializes broadcasts for broadcasting operands in
/// linalg.generic ops with parallel iterators.
///
/// This pattern matches linalg.generic ops where all iterator types are
/// parallel, and at least one operand has a broadcasting indexing map (i.e.,
/// the map drops dimensions, mapping a larger iteration space to a smaller
/// operand space). It creates a linalg.broadcast op for each such operand to
/// materialize the broadcast, making the operand match the output shape.
/// This allows subsequent patterns (like LinalgGenericToElementwise) to convert
/// the op to elementwise operations.
struct MaterializeBroadcasts : public OpRewritePattern<linalg::GenericOp> {
public:
MaterializeBroadcasts(MLIRContext* context)
: OpRewritePattern<linalg::GenericOp>(context) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter& rewriter) const override {
// Only handle ops with multiple inputs to avoid infinite loops when
// materializing broadcasts. Single-input broadcast ops don't need to be
// converted to elementwise ops.
if (genericOp.getNumDpsInputs() <= 1) return failure();

for (auto iteratorType : genericOp.getIteratorTypesArray()) {
if (iteratorType != utils::IteratorType::parallel) {
return failure();
}
}

auto indexingMaps = genericOp.getIndexingMapsArray();
bool madeChanges = false;
SmallVector<Value> newInputs;
SmallVector<AffineMap> newMaps;

int64_t numDims = genericOp.getNumLoops();

for (int64_t i = 0; i < genericOp.getNumDpsInputs(); ++i) {
OpOperand* operand = genericOp.getDpsInputOperand(i);
AffineMap map = indexingMaps[i];
Value value = operand->get();

if (map.isIdentity()) {
newInputs.push_back(value);
newMaps.push_back(map);
continue;
}

if (map.getNumResults() < numDims) {
madeChanges = true;
auto materializedValue = materializeBroadcastForOperand(
rewriter, genericOp, value, map, numDims);
if (failed(materializedValue)) {
return failure();
}
newInputs.push_back(*materializedValue);
newMaps.push_back(rewriter.getMultiDimIdentityMap(numDims));
} else {
newInputs.push_back(value);
newMaps.push_back(map);
}
}

if (!madeChanges) return failure();

for (int64_t i = 0; i < genericOp.getNumDpsInits(); ++i) {
newMaps.push_back(indexingMaps[genericOp.getNumDpsInputs() + i]);
}

auto newGenericOp = linalg::GenericOp::create(
rewriter, genericOp.getLoc(), genericOp.getResultTypes(), newInputs,
genericOp.getDpsInits(), newMaps, genericOp.getIteratorTypesArray());

rewriter.inlineRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
newGenericOp.getRegion().begin());
rewriter.replaceOp(genericOp, newGenericOp.getResults());

return success();
}

private:
LogicalResult tryCollapseUnitDims(PatternRewriter& rewriter, Location loc,
Value& value, AffineMap& map,
int64_t targetRank) const {
auto inputType = cast<RankedTensorType>(value.getType());
if (inputType.getRank() <= targetRank) return success();

SmallVector<int64_t> dimsToDrop;
for (unsigned j = 0; j < map.getNumResults(); ++j) {
auto expr = map.getResult(j);
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
if (inputType.getShape()[j] == 1) {
dimsToDrop.push_back(j);
}
}
}

int64_t newRank = inputType.getRank() - dimsToDrop.size();
if (dimsToDrop.empty() || newRank > targetRank) {
return failure();
}

auto reassociation =
getReassociationForReshapeAtDim(inputType.getRank(), dimsToDrop);

SmallVector<int64_t> targetShape;
for (int64_t k = 0; k < inputType.getRank(); ++k) {
if (!llvm::is_contained(dimsToDrop, k)) {
targetShape.push_back(inputType.getShape()[k]);
}
}

auto collapsedType =
RankedTensorType::get(targetShape, inputType.getElementType());

auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, value, reassociation);

value = collapseOp.getResult();
map = map.dropResults(dimsToDrop);
return success();
}

FailureOr<Value> materializeBroadcastForOperand(PatternRewriter& rewriter,
linalg::GenericOp genericOp,
Value value, AffineMap map,
int64_t numDims) const {
SmallVector<int64_t> addedDims = getBroadcastDimensions(map, numDims);
int64_t targetRank = numDims - addedDims.size();

if (failed(tryCollapseUnitDims(rewriter, genericOp.getLoc(), value, map,
targetRank))) {
return failure();
}

auto refOutput = genericOp.getDpsInitOperand(0)->get();
auto refOutputType = cast<RankedTensorType>(refOutput.getType());

auto emptyOp = tensor::EmptyOp::create(rewriter, genericOp.getLoc(),
refOutputType.getShape(),
refOutputType.getElementType());

auto broadcastOp = linalg::BroadcastOp::create(
rewriter, genericOp.getLoc(), value, emptyOp.getResult(), addedDims);

return broadcastOp.getResults()[0];
}
};

struct DropCfAssertInLinalg : public OpRewritePattern<linalg::GenericOp> {
public:
DropCfAssertInLinalg(MLIRContext* context)
: OpRewritePattern<linalg::GenericOp>(context) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter& rewriter) const override {
bool madeChanges = false;
auto* body = genericOp.getBody();
for (auto& op : llvm::make_early_inc_range(body->getOperations())) {
if (auto assertOp = dyn_cast<cf::AssertOp>(op)) {
rewriter.eraseOp(assertOp);
madeChanges = true;
}
}
if (madeChanges) return success();
return failure();
}
};

struct LinalgCanonicalizations
: public impl::LinalgCanonicalizationsBase<LinalgCanonicalizations> {
void runOnOperation() override {
Expand All @@ -740,7 +927,8 @@ struct LinalgCanonicalizations
FoldConstantBroadcast, FoldBroadcastExtractSlice,
LinalgMapToElementwise, LinalgGenericToElementwise,
BroadcastToExpandShape, RewriteTransposedVecmat,
RewriteTransposedMatvec, RewriteAvgPoolAsConv2D>(context);
RewriteTransposedMatvec, RewriteAvgPoolAsConv2D,
MaterializeBroadcasts, DropCfAssertInLinalg>(context);

// Run pattern matching and conversion
// TODO (#1221): Investigate whether folding (default: on) can be skipped
Expand Down
33 changes: 33 additions & 0 deletions lib/Transforms/LinalgFuseLinearOps/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "LinalgFuseLinearOps",
srcs = ["LinalgFuseLinearOps.cpp"],
hdrs = [
"LinalgFuseLinearOps.h",
],
deps = [
":pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
],
)

add_heir_transforms(
generated_target_name = "pass_inc_gen",
pass_name = "LinalgFuseLinearOps",
)
Loading
Loading