Skip to content

Commit

Permalink
Rename help function to getAllowedGenericOpOrGeneralizeNamedOp
Browse files Browse the repository at this point in the history
Signed-off-by: jerryyin <[email protected]>
  • Loading branch information
jerryyin committed Jan 21, 2025
1 parent acac51b commit d686eef
Showing 1 changed file with 7 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,19 @@ static void specializeGenericTransposeOp(RewriterBase &rewriter,
/// generalizing is allowed. Otherwise if the `op` is a linalg::GenericOp,
/// then just return the generic op.
static FailureOr<linalg::GenericOp>
getGenericOpOrGeneralizeContraction(RewriterBase &rewriter, Operation *op,
bool allowGeneralizing) {
getAllowedGenericOpOrGenerializeNamedOp(RewriterBase &rewriter, Operation *op,
bool allowGeneralizing) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp) {
return failure();
}
// TODO: Right now this is restricted to contractions due to fragility around
// handling of convolutions.
if (!isa<linalg::GenericOp>(linalgOp) &&
!(allowGeneralizing && linalg::isaContractionOpInterface(linalgOp))) {
if (linalg::isaConvolutionOpInterface(linalgOp)) {
return failure();
}

// If this is generic op but comply to convolution op interface, assume
// it is from ConvertConvToChannelsLast pass and skip.
// FuseTransposeWithProducerLinalgOp will fuse the transpose into successive
// convolution, negating the effect from the filter layout conversion from
// that pass.
if (isa<linalg::GenericOp>(linalgOp) &&
linalg::isaConvolutionOpInterface(linalgOp)) {
if (!isa<linalg::GenericOp>(linalgOp) &&
!(allowGeneralizing && linalg::isaContractionOpInterface(linalgOp))) {
return failure();
}

Expand Down Expand Up @@ -221,7 +214,7 @@ class FuseTransposeWithProducerLinalgOp
}

int64_t resultIndex = result.getResultNumber();
auto maybeGenericOp = getGenericOpOrGeneralizeContraction(
auto maybeGenericOp = getAllowedGenericOpOrGenerializeNamedOp(
rewriter, result.getOwner(), allowGeneralizing);
if (failed(maybeGenericOp)) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -628,7 +621,7 @@ class FuseTransposeWithLinalgOpConsumer
// To do the fusion, we can simply apply the permutation of the transpose
// to the results of the associated input's indexing map, and then forward
// the input to the transpose to the consumer generic.
auto maybeGenericOp = getGenericOpOrGeneralizeContraction(
auto maybeGenericOp = getAllowedGenericOpOrGenerializeNamedOp(
rewriter, linalgOp, allowGeneralizing);
if (failed(maybeGenericOp)) {
return failure();
Expand Down

0 comments on commit d686eef

Please sign in to comment.