Skip to content

Commit

Permalink
[Codegen] Harden yielding logic in TileDispatchUsingForall (iree-org#…
Browse files Browse the repository at this point in the history
…19212)

Currently the TileDispatchUsingForall pass chooses to yield replacements
for fused producers based on dominance relative to the tiling root. Now
that this pass includes consumer fusion, this can create situations
where a single operation is consuming two loop yielded values, blocking
fusion.

This patch disables yielding of values when there is more than one
tilable consumer. As a result, it's common to have remaining producer
fusion opportunities after fusing in consumers, so we add another
iteration of producer fusion. In the future the fusion worklist needs to
include both producer and consumer fusions to enable covering all cases.
  • Loading branch information
qedawkins authored Nov 21, 2024
1 parent dd4c91c commit 81ca183
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 7 deletions.
100 changes: 93 additions & 7 deletions compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,12 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
// Pass implementation.
//===---------------------------------------------------------------------===//

static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {

// Fuse all consumers of the given `tiledOp` into the surrounding scf.forall.
// Returns a list of new `tensor.extract_slice` ops with new fusion
// opportunities, as well as the new surrounding `scf.forall` (because consumer
// fusion replaces the loop).
static std::pair<std::queue<Operation *>, scf::ForallOp>
fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
auto addCandidateSlices =
[](Operation *fusedOp,
std::queue<tensor::ParallelInsertSliceOp> &candidates) {
Expand All @@ -283,6 +287,8 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
std::queue<tensor::ParallelInsertSliceOp> candidates;
addCandidateSlices(tiledOp, candidates);

std::queue<Operation *> newFusionOpportunities;
scf::ForallOp newLoop = tiledOp->getParentOfType<scf::ForallOp>();
while (!candidates.empty()) {

// Traverse the slices in BFS fashion.
Expand All @@ -301,11 +307,63 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
rewriter.replaceOp(fusedResult->origConsumerOperand->getOwner(),
fusedResult->tiledOps.front());

// The result of the fused conumers might themselved be slices of
// The result of the fused consumers might themselves be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
candidates);

// Add the list of new producer fusion opportunities.
for (auto tiledOp : fusedResult.value().tiledOps) {
for (auto operand : tiledOp->getOperands()) {
if (auto sliceProducer =
operand.getDefiningOp<tensor::ExtractSliceOp>()) {
if (llvm::isa_and_present<TilingInterface>(
sliceProducer.getSource().getDefiningOp())) {
newFusionOpportunities.push(sliceProducer);
}
}
}
// Store the new loop for follow up producer fusion.
newLoop = tiledOp->getParentOfType<scf::ForallOp>();
}
}
return std::make_pair(newFusionOpportunities, newLoop);
}

// Walk `worklist` (a queue of `tensor.extract_slice` ops) and greedily fuse
// the producer of each slice into `forallOp`, consulting the fusion control
// function in `options` before each fusion. Slices created by a successful
// fusion are appended to the worklist so producer fusion proceeds
// transitively.
static void fuseProducersOfSlices(RewriterBase &rewriter,
                                  std::queue<Operation *> &worklist,
                                  scf::SCFTileAndFuseOptions &options,
                                  scf::ForallOp forallOp) {
  // All fusions happen into the single surrounding forall loop.
  SmallVector<LoopLikeOpInterface> loops = {
      cast<LoopLikeOpInterface>(&*forallOp)};
  while (!worklist.empty()) {
    // Process the slices in BFS order.
    auto slice = cast<tensor::ExtractSliceOp>(worklist.front());
    worklist.pop();

    // Only producers implementing the `TilingInterface` can be fused.
    auto producer = slice.getSource().getDefiningOp<TilingInterface>();
    if (!producer)
      continue;

    // Skip candidates rejected by the user-provided fusion control function.
    std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> controlResult =
        options.fusionControlFn(slice, cast<OpResult>(slice.getSource()),
                                /*destinationInitArg=*/false);
    if (!controlResult)
      continue;

    std::optional<scf::SCFFuseProducerOfSliceResult> fused =
        scf::tileAndFuseProducerOfSlice(rewriter, slice, loops);
    if (!fused)
      continue;

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these slices to the worklist.
    for (auto generatedSlice : fused->generatedSlices)
      worklist.push(generatedSlice);
  }
}

Expand All @@ -319,6 +377,7 @@ static void collectTiledAndFusedOps(Operation *rootOp,
result.insert(rootOp);
while (!worklist.empty()) {
Operation *current = worklist.pop_back_val();
// Collect all tilable producers.
for (OpOperand &operand : current->getOpOperands()) {
Operation *producer = operand.get().getDefiningOp();
if (!producer || !isa<TilingInterface>(producer) ||
Expand All @@ -327,6 +386,16 @@ static void collectTiledAndFusedOps(Operation *rootOp,
worklist.push_back(producer);
result.insert(producer);
}
// Collect all tilable consumers.
for (auto user : current->getUsers()) {
if (result.count(user)) {
continue;
}
if (isa<TilingInterface>(user)) {
worklist.push_back(user);
result.insert(user);
}
}
}
}

Expand All @@ -352,9 +421,18 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {

llvm::DenseSet<Operation *> yieldReplacementsFor;
for (auto op : tiledAndFusedOps) {
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
return dominanceInfo.properlyDominates(tilableOp, user);
})) {
// Yield a replacement if:
// a) All users of fused op are dominated by the tiling root.
// b) There is at most a single tiled user. If there is more than one
// then yielding a replacement may result in multiple incompatible
// consumer fusions.
if (llvm::any_of(op->getUsers(),
[&](Operation *user) {
return dominanceInfo.properlyDominates(tilableOp, user);
}) &&
(llvm::count_if(op->getUsers(), [&](Operation *user) {
return tiledAndFusedOps.contains(user);
}) < 2)) {
yieldReplacementsFor.insert(op);
}
}
Expand Down Expand Up @@ -429,7 +507,15 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
}

if (rootTiledOp) {
fuseConsumers(rewriter, rootTiledOp);
auto [newFusionOpportunities, newLoop] =
fuseConsumers(rewriter, rootTiledOp);

// Because we restrict to at most a single tilable consumer for yielding
// a replacement, no new fusion opportunities will yield a replacement,
// meaning there is no need to run consumer fusion again afterwards.
// TODO: run producer and consumer fusion in one worklist.
fuseProducersOfSlices(rewriter, newFusionOpportunities,
tileAndFuseOptions, newLoop);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,156 @@ func.func @multi_result(%arg0: tensor<64x128xf32>, %arg1: tensor<128x256xf32>, %
// CHECK: tensor.parallel_insert_slice
// CHECK: tensor.parallel_insert_slice
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Softmax-like pattern: the row-wise max (%11) has two tilable consumers
// (%13 and %14), so the pass must not yield a loop replacement for it;
// everything should fuse into a single distribution loop instead.
func.func @multi_use_producer_no_yield_replacement(%7: tensor<12x197x197xf32>) -> tensor<12x197x197xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant -3.40282347E+38 : f32
  %8 = tensor.empty() : tensor<12x197x197xf32>
  %9 = tensor.empty() : tensor<12x197xf32>
  // Init for the running max (-FLT_MAX).
  %10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<12x197xf32>) -> tensor<12x197xf32>
  // Row-wise max reduction; consumed by both %13 and %14 below.
  %11 = linalg.generic {
    indexing_maps = [#map, #map1],
    iterator_types = ["parallel", "parallel", "reduction"]
  } ins(%7 : tensor<12x197x197xf32>) outs(%10 : tensor<12x197xf32>) {
  ^bb0(%in: f32, %out: f32):
    %15 = arith.maxnumf %in, %out : f32
    linalg.yield %15 : f32
  } -> tensor<12x197xf32>
  %12 = linalg.fill ins(%cst : f32) outs(%9 : tensor<12x197xf32>) -> tensor<12x197xf32>
  // Tiling root (carries the lowering_config): sum of exp(x - max).
  %13 = linalg.generic {
    indexing_maps = [#map, #map1, #map1],
    iterator_types = ["parallel", "parallel", "reduction"]
  } ins(%7, %11 : tensor<12x197x197xf32>, tensor<12x197xf32>)
    outs(%12 : tensor<12x197xf32>) attrs = {
    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4, 8, 0]]>} {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %15 = arith.subf %in, %in_1 : f32
    %16 = math.exp %15 : f32
    %17 = arith.addf %16, %out : f32
    linalg.yield %17 : f32
  } -> tensor<12x197xf32>
  // Final elementwise op consuming both the max (%11) and the sum (%13).
  %14:2 = linalg.generic {
    indexing_maps = [#map, #map1, #map1, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%7, %11, %13 : tensor<12x197x197xf32>, tensor<12x197xf32>, tensor<12x197xf32>)
    outs(%8, %8 : tensor<12x197x197xf32>, tensor<12x197x197xf32>) {
  ^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32):
    %15 = arith.subf %in, %in_1 : f32
    %16 = math.exp %15 : f32
    %17 = arith.divf %16, %in_2 : f32
    linalg.yield %16, %17 : f32, f32
  } -> (tensor<12x197x197xf32>, tensor<12x197x197xf32>)
  return %14#1 : tensor<12x197x197xf32>
}

// CHECK-LABEL: func @multi_use_producer_no_yield_replacement(
// CHECK: %[[RESULT:.+]] = scf.forall
// CHECK: %[[MAX:.+]] = linalg.generic
// CHECK: arith.maxnumf
// CHECK: %[[EXPSUM:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.*}}, %[[MAX]]
// CHECK: arith.subf
// CHECK: math.exp
// CHECK: arith.addf
// CHECK: %[[EXPDIV:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.*}}, %[[MAX]], %[[EXPSUM]]
// CHECK: arith.subf
// CHECK: math.exp
// CHECK: arith.divf
// CHECK: return %[[RESULT]]

// -----

// Fusion of the following graph, root marked with [brackets].
// A
// / \
// B [C]
// \ /
// D
#map = affine_map<(d0) -> (d0)>
func.func @diamond_graph(%0: tensor<12xf32>, %1: tensor<12xf32>) -> tensor<12xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %2 = tensor.empty() : tensor<12xf32>
  // A: shared top of the diamond, consumed by both branches below.
  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
  } ins(%0, %1 : tensor<12xf32>, tensor<12xf32>) outs(%2 : tensor<12xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %7 = arith.addf %in, %in_0 : f32
    linalg.yield %7 : f32
  } -> tensor<12xf32>
  // [C]: the tiling root (carries the lowering_config); consumes A.
  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
  } ins(%3, %0 : tensor<12xf32>, tensor<12xf32>) outs(%2 : tensor<12xf32>) attrs = {
    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4]]>} {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %8 = arith.addf %in, %in_1 : f32
    linalg.yield %8 : f32
  } -> tensor<12xf32>
  // B: sibling branch, also consumes A.
  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
  } ins(%3, %1 : tensor<12xf32>, tensor<12xf32>) outs(%2 : tensor<12xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %9 = arith.addf %in, %in_0 : f32
    linalg.yield %9 : f32
  } -> tensor<12xf32>
  // D: joins both branches of the diamond.
  %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
  } ins(%4, %5 : tensor<12xf32>, tensor<12xf32>) outs(%2 : tensor<12xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %10 = arith.addf %in, %in_0 : f32
    linalg.yield %10 : f32
  } -> tensor<12xf32>
  return %6 : tensor<12xf32>
}

// CHECK-LABEL: func @diamond_graph(
// CHECK: %[[RESULT:.+]] = scf.forall
// CHECK: %[[TOP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[IN0_SLICE:.+]], %[[IN1_SLICE:.+]]
// CHECK-DAG: %[[LEFT:.+]] = linalg.generic {{.*}} ins(%[[TOP]], %[[IN0_SLICE]]
// CHECK-DAG: %[[RIGHT:.+]] = linalg.generic {{.*}} ins(%[[TOP]], %[[IN1_SLICE]]
// CHECK: linalg.generic {{.*}} ins(%[[LEFT]], %[[RIGHT]]
// CHECK: return %[[RESULT]]

// -----

// Fusion of the following graph, root marked with [brackets].
// [A] B
// \ /
// C
#map = affine_map<(d0) -> (d0)>
func.func @v_shaped_graph(%0: tensor<12xf32>, %1: tensor<12xf32>) -> tensor<12xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %2 = tensor.empty() : tensor<12xf32>
  // [A]: the tiling root (carries the lowering_config).
  %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]
  } ins(%0 : tensor<12xf32>) outs(%2 : tensor<12xf32>) attrs = {
    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4]]>} {
  ^bb0(%in: f32, %out: f32):
    %6 = math.sqrt %in : f32
    linalg.yield %6 : f32
  } -> tensor<12xf32>
  // B: independent of the root; only reachable through the joining consumer.
  %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]
  } ins(%1 : tensor<12xf32>) outs(%2 : tensor<12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %7 = math.sqrt %in : f32
    linalg.yield %7 : f32
  } -> tensor<12xf32>
  // C: joins A and B.
  %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
  } ins(%3, %4 : tensor<12xf32>, tensor<12xf32>) outs(%2 : tensor<12xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %8 = arith.addf %in, %in_0 : f32
    linalg.yield %8 : f32
  } -> tensor<12xf32>
  return %5 : tensor<12xf32>
}

// CHECK-LABEL: func @v_shaped_graph(
// CHECK-SAME: %[[IN0:[A-Za-z0-9]+]]: tensor<12xf32>
// CHECK-SAME: %[[IN1:[A-Za-z0-9]+]]: tensor<12xf32>
// CHECK: %[[RESULT:.+]] = scf.forall
// CHECK-DAG: %[[SLICE0:.+]] = tensor.extract_slice %[[IN0]]
// CHECK-DAG: %[[SLICE1:.+]] = tensor.extract_slice %[[IN1]]
// CHECK-DAG: %[[LEFT:.+]] = linalg.generic {{.*}} ins(%[[SLICE0]]
// CHECK-DAG: %[[RIGHT:.+]] = linalg.generic {{.*}} ins(%[[SLICE1]]
// CHECK: linalg.generic {{.*}} ins(%[[LEFT]], %[[RIGHT]]
// CHECK: return %[[RESULT]]

0 comments on commit 81ca183

Please sign in to comment.