Skip to content
Snippets Groups Projects
Commit 281ee429 authored by Alexander Belyaev's avatar Alexander Belyaev
Browse files

[mlir] Add a pass to distribute linalg::TiledLoopOp.

Differential Revision: https://reviews.llvm.org/D103194
parent 51d334a8
No related branches found
No related tags found
No related merge requests found
...@@ -860,6 +860,13 @@ void populateLinalgConvGeneralizationPatterns( ...@@ -860,6 +860,13 @@ void populateLinalgConvGeneralizationPatterns(
RewritePatternSet &patterns, RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter()); LinalgTransformationFilter filter = LinalgTransformationFilter());
/// Linalg distribution patterns
//
/// Populates `patterns` with patterns to distribute linalg.tiled_loop.
void populateLinalgDistributeTiledLoopPattern(
RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
const LinalgTransformationFilter &marker);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op-specific patterns. // Op-specific patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
......
...@@ -184,6 +184,8 @@ struct ProcInfo { ...@@ -184,6 +184,8 @@ struct ProcInfo {
}; };
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>( using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>; OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
using OneDimProcInfoCallBackFn =
std::function<ProcInfo(OpBuilder &b, Location loc)>;
/// Options that allow distribution of loops generated in Linalg transforms to /// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops. /// processors while generating the loops.
...@@ -201,6 +203,11 @@ struct LinalgLoopDistributionOptions { ...@@ -201,6 +203,11 @@ struct LinalgLoopDistributionOptions {
/// applied. If the vector is less than the number of `scf.parallel` loops /// applied. If the vector is less than the number of `scf.parallel` loops
/// generated, then no distribution is applied. /// generated, then no distribution is applied.
SmallVector<DistributionMethod, 0> distributionMethod = {}; SmallVector<DistributionMethod, 0> distributionMethod = {};
/// The map keyed by the distribution type that contains callback functions
/// that return the Values for processor ID (`procId`), and number of
/// processors (`nprocs`) used to execute the parallel loops.
DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
}; };
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
......
...@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms ...@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
CodegenStrategy.cpp CodegenStrategy.cpp
ComprehensiveBufferize.cpp ComprehensiveBufferize.cpp
Detensorize.cpp Detensorize.cpp
Distribution.cpp
DropUnitDims.cpp DropUnitDims.cpp
ElementwiseToLinalg.cpp ElementwiseToLinalg.cpp
Fusion.cpp Fusion.cpp
......
//===- Distibution.cpp - linalg named ops to generic ops --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg distibution pass. It updates `tiled_loop`
// control variables depending on the distribution type.
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#define DEBUG_TYPE "linalg-distribution"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
namespace {
struct DistributeTiledLoopPattern
: public OpRewritePattern<linalg::TiledLoopOp> {
DistributeTiledLoopPattern(MLIRContext *context,
LinalgLoopDistributionOptions options,
LinalgTransformationFilter marker)
: OpRewritePattern<linalg::TiledLoopOp>(context), options(options),
marker(marker) {}
LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
PatternRewriter &rewriter) const override {
if (failed(marker.checkAndNotify(rewriter, op)))
return failure();
if (!op.distribution_types().hasValue())
return failure();
Location loc = op.getLoc();
SmallVector<Value, 2> newLowerBounds = op.lowerBound();
SmallVector<Value, 2> newUpperBounds = op.upperBound();
SmallVector<Value, 2> newSteps = op.step();
// Update bounds and steps.
auto distributionTypes = op.distribution_types().getValue();
for (int i = 0, e = op.getNumLoops(); i < e; ++i) {
StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
auto procInfoCallback = options.procInfoMap.find(type);
if (procInfoCallback == options.procInfoMap.end())
continue;
if (!isParallelIteratorType(op.iterator_types()[i])) {
op.emitOpError("only support for parallel loops is implemented");
return failure();
}
ProcInfo info = procInfoCallback->second(rewriter, loc);
updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
newLowerBounds[i], newUpperBounds[i],
newSteps[i]);
}
rewriter.updateRootInPlace(op, [&] {
op.setLowerBounds(newLowerBounds);
op.setUpperBounds(newUpperBounds);
op.setSteps(newSteps);
});
marker.replaceLinalgTransformationFilter(rewriter, op);
return success();
}
private:
LinalgLoopDistributionOptions options;
LinalgTransformationFilter marker;
};
} // namespace
void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
const LinalgTransformationFilter &marker) {
patterns.add<DistributeTiledLoopPattern>(patterns.getContext(), opts, marker);
}
// RUN: mlir-opt -test-linalg-distribution %s | FileCheck %s
func private @foo(%A: tensor<64x64xf32>,
%B: tensor<64x64xf32>) -> tensor<64x64xf32>
func @distribute_for_gpu(%A: tensor<64x64xf32>,
%B: tensor<64x64xf32>) -> tensor<64x64xf32> {
%c0 = constant 0 : index
%c16 = constant 16 : index
%c64 = constant 64 : index
%c24 = constant 24 : index
%0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c64, %c64) step (%c24, %c16)
ins (%A_ = %A: tensor<64x64xf32>) outs (%B_ = %B:tensor<64x64xf32>)
distribution ["block_x", "block_y"] {
%0 = call @foo(%A_, %B_)
: (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
linalg.yield %0 : tensor<64x64xf32>
}
return %0 : tensor<64x64xf32>
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 24)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK-LABEL: func @distribute_for_gpu
// CHECK: %[[C64:.*]] = constant 64 : index
// CHECK-DAG: %[[GPU_BLOCK_X:.*]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[GPU_GRID_DIM_X:.*]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[LB_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_BLOCK_X]]]
// CHECK-DAG: %[[STEP_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_GRID_DIM_X]]]
// CHECK-DAG: %[[GPU_BLOCK_Y:.*]] = "gpu.block_id"() {dimension = "y"}
// CHECK-DAG: %[[GPU_GRID_DIM_Y:.*]] = "gpu.grid_dim"() {dimension = "y"}
// CHECK-DAG: %[[LB_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_BLOCK_Y]]]
// CHECK-DAG: %[[STEP_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_GRID_DIM_Y]]]
// CHECK: linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[LB_I]], %[[LB_J]])
// CHECK-SAME: to (%[[C64]], %[[C64]]) step (%[[STEP_I]], %[[STEP_J]])
//===- TestLinalgDistribution.cpp - Test Linalg hoisting functions --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg hoisting functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
template <char dim>
static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
std::string d(1, dim);
StringAttr attr = b.getStringAttr(d);
Type indexType = b.getIndexType();
ProcInfo procInfo = {b.create<gpu::BlockIdOp>(loc, indexType, attr),
b.create<gpu::GridDimOp>(loc, indexType, attr)};
return procInfo;
}
static LinalgLoopDistributionOptions getDistributionOptions() {
LinalgLoopDistributionOptions opts;
opts.procInfoMap.insert(std::make_pair("block_x", getGpuBlockInfo<'x'>));
opts.procInfoMap.insert(std::make_pair("block_y", getGpuBlockInfo<'y'>));
return opts;
}
namespace {
struct TestLinalgDistribution
: public PassWrapper<TestLinalgDistribution, FunctionPass> {
TestLinalgDistribution() = default;
TestLinalgDistribution(const TestLinalgDistribution &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect>();
}
void runOnFunction() override;
};
} // namespace
void TestLinalgDistribution::runOnFunction() {
auto funcOp = getFunction();
OwningRewritePatternList distributeTiledLoopsPatterns(&getContext());
populateLinalgDistributeTiledLoopPattern(
distributeTiledLoopsPatterns, getDistributionOptions(),
LinalgTransformationFilter(
ArrayRef<Identifier>{},
{Identifier::get("distributed", funcOp.getContext())})
.addFilter([](Operation *op) {
return success(!op->getParentOfType<linalg::TiledLoopOp>());
}));
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(distributeTiledLoopsPatterns));
// Ensure we drop the marker in the end.
funcOp.walk([](LinalgOp op) {
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
}
namespace mlir {
namespace test {
void registerTestLinalgDistribution() {
PassRegistration<TestLinalgDistribution> testTestLinalgDistributionPass(
"test-linalg-distribution", "Test Linalg distribution.");
}
} // namespace test
} // namespace mlir
...@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass(); ...@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass(); void registerTestIRVisitorsPass();
void registerTestInterfaces(); void registerTestInterfaces();
void registerTestLinalgCodegenStrategy(); void registerTestLinalgCodegenStrategy();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion(); void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape(); void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms(); void registerTestLinalgFusionTransforms();
...@@ -156,6 +157,7 @@ void registerTestPasses() { ...@@ -156,6 +157,7 @@ void registerTestPasses() {
test::registerTestIRVisitorsPass(); test::registerTestIRVisitorsPass();
test::registerTestInterfaces(); test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy(); test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgDistribution();
test::registerTestLinalgElementwiseFusion(); test::registerTestLinalgElementwiseFusion();
test::registerTestPushExpandingReshape(); test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms(); test::registerTestLinalgFusionTransforms();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment