Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Add pattern to fuse tensor.collapse_shape into forall producer #19295

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Expand All @@ -23,7 +21,6 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
Expand Down Expand Up @@ -331,6 +328,24 @@ struct FuseTilableForallConsumers final
}
};

struct FuseCollapseShapeConsumers final
: OpRewritePattern<tensor::CollapseShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
PatternRewriter &rewriter) const override {
auto forallOp = collapseOp.getSrc().getDefiningOp<scf::ForallOp>();
if (!forallOp) {
return rewriter.notifyMatchFailure(collapseOp, "No forall op producer");
}

if (failed(fuseCollapseShapeIntoProducerForall(rewriter, forallOp,
collapseOp))) {
return failure();
}
return success();
}
};

void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
MLIRContext *context = &getContext();

Expand Down Expand Up @@ -375,6 +390,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
patterns.add<FuseTilableDestinationProducers>(context);
patterns.add<FuseUnitLoopDestination>(context);
patterns.add<FuseTilableForallConsumers>(context);
patterns.add<FuseCollapseShapeConsumers>(context);
populateSwapExtractWithExpandPattern(patterns);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,3 +546,57 @@ func.func @fuse_imperfectly_aligned_unpack(%arg0: tensor<5x31xf16>, %arg1: index
// CHECK: linalg.copy
// CHECK: scf.forall.in_parallel
// CHECK: return

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_non_contiguous_collapse_shape(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
%0 = tensor.empty() : tensor<8x8xf32>
%1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
%2 = affine.apply #map(%arg1)
%extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
%extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
%3 = linalg.copy ins(%extracted_slice : tensor<2x7xf32>) outs(%extracted_slice_0 : tensor<2x7xf32>) -> tensor<2x7xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 7] [1, 1] : tensor<2x7xf32> into tensor<8x8xf32>
}
} {mapping = [#gpu.thread<x>]}
%collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_non_contiguous_collapse_shape
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2x7xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
%0 = tensor.empty() : tensor<8x8xf32>
%1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
%2 = affine.apply #map(%arg1)
%extracted_slice = tensor.extract_slice %arg0[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
%extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
%3 = linalg.copy ins(%extracted_slice : tensor<8xf32>) outs(%extracted_slice_0 : tensor<8xf32>) -> tensor<8xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg2[%2, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<8x8xf32>
}
} {mapping = [#gpu.thread<x>]}
%collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_collapse_shape_rank_reduced
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<8xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,54 @@ void transform_dialect::FuseForallOp::getEffects(
transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// FuseCollapseShapeIntoForallOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::FuseCollapseShapeIntoForallOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
auto producers = state.getPayloadOps(getProducer());
auto consumers = state.getPayloadOps(getConsumer());

int64_t numProducers = llvm::range_size(producers);
int64_t numConsumers = llvm::range_size(consumers);
if (numProducers != 1 || numConsumers != 1) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"More than one producer or consumer");
}

auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
if (!producer) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"Non-forall producer");
}
auto consumer = dyn_cast<tensor::CollapseShapeOp>(*consumers.begin());
if (!consumer) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"Non-collapse_shape consumer");
}

FailureOr<scf::ForallOp> fusedForallOp =
GPU::fuseCollapseShapeIntoProducerForall(rewriter, producer, consumer);
if (failed(fusedForallOp)) {
return mlir::emitSilenceableFailure(state.getTopLevel(),
"failed to fuse collapse_shape op");
}

results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
return DiagnosedSilenceableFailure::success();
}

void transform_dialect::FuseCollapseShapeIntoForallOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getProducerMutable(), effects);
transform::consumesHandle(getConsumerMutable(), effects);
transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}

} // namespace mlir::iree_compiler::IREE

void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,38 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def FuseCollapseShapeIntoForallOp : Op<Transform_Dialect, "iree.fuse_collapse_shape_into_forall",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Fuses a consumer tensor.collapse_shape op into a producer scf.forall op.
The users of the block argument for the corresponding forall output operand
should be only a tensor.parallel_insert_slice op, and tensor.extract_slice
ops that extract an equivalent subset. After the fusion, the output of the
forall will be collapsed, and all users of this block arg will also be
collapsed. Additional tensor.expand_shape ops will be inserted after any
tensor.extract_slice users inside the forall so that types match. Similarly,
a tensor.collapse_shape will be inserted before the
tensor.parallel_insert_slice.

#### Return modes
Emits a definite failure if either the producer is not an scf.forall op or
if the consumer is not a tensor.collapse_shape op.
}];

let arguments = (
ins TransformHandleTypeInterface:$producer,
TransformHandleTypeInterface:$consumer
);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat = [{
$consumer `into` $producer attr-dict
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_lit_test_suite(
"drop_multi_mma_unit_dims.mlir",
"lower_multi_mma.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_collapse_shape_with_forall.mlir",
"transform_fuse_forall.mlir",
"transform_lower_barrier_region.mlir",
"vectorize_iree_gpu_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
"drop_multi_mma_unit_dims.mlir"
"lower_multi_mma.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_collapse_shape_with_forall.mlir"
"transform_fuse_forall.mlir"
"transform_lower_barrier_region.mlir"
"unroll_multi_mma.mlir"
Expand Down
Loading
Loading