Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Func] Delete DecomposeCallGraphTypes.cpp #117424

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

DecomposeCallGraphTypes.cpp was a workaround around missing 1:N support in the dialect conversion. Now that 1:N support was added, the workaround can be deleted. The test remains in place, as an example for how to write such a transformation with the dialect conversion framework.

Note for LLVM integration: If you are using DecomposeCallGraphTypes.cpp, switch to the patterns that are used in TestDecomposeCallGraphTypes.cpp.

@llvmbot
Copy link
Member

llvmbot commented Nov 23, 2024

@llvm/pr-subscribers-mlir-func

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

DecomposeCallGraphTypes.cpp was a workaround around missing 1:N support in the dialect conversion. Now that 1:N support was added, the workaround can be deleted. The test remains in place, as an example for how to write such a transformation with the dialect conversion framework.

Note for LLVM integration: If you are using DecomposeCallGraphTypes.cpp, switch to the patterns that are used in TestDecomposeCallGraphTypes.cpp.


Full diff: https://github.com/llvm/llvm-project/pull/117424.diff

5 Files Affected:

  • (removed) mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h (-34)
  • (modified) mlir/lib/Dialect/Func/Transforms/CMakeLists.txt (-1)
  • (removed) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (-136)
  • (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+2-5)
  • (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+4-2)
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
deleted file mode 100644
index 1be406bf3adf92..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Conversion patterns for decomposing types along call graph edges. That is,
-// decomposing types for calls, returns, and function args.
-//
-// TODO: Make this handle dialect-defined functions, calls, and returns.
-// Currently, the generic interfaces aren't sophisticated enough for the
-// types of mutations that we are doing here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
-
-namespace mlir {
-
-/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `TypeConverter`.
-void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
-                                             const TypeConverter &typeConverter,
-                                             RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index f8fb1f436a95b1..6384d25ee70273 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRFuncTransforms
-  DecomposeCallGraphTypes.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
   OneToNFuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
deleted file mode 100644
index 03be00328bda33..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForFuncArgs
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand function arguments according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForFuncArgs
-    : public OpConversionPattern<func::FuncOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto functionType = op.getFunctionType();
-
-    // Convert function arguments using the provided TypeConverter.
-    TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
-    for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
-      SmallVector<Type, 2> decomposedTypes;
-      if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
-        return failure();
-      if (!decomposedTypes.empty())
-        conversion.addInputs(argType.index(), decomposedTypes);
-    }
-
-    // If the SignatureConversion doesn't apply, bail out.
-    if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
-                                           &conversion)))
-      return failure();
-
-    // Update the signature of the function.
-    SmallVector<Type, 2> newResultTypes;
-    if (failed(typeConverter->convertTypes(functionType.getResults(),
-                                           newResultTypes)))
-      return failure();
-    rewriter.modifyOpInPlace(op, [&] {
-      op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                          newResultTypes));
-    });
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForReturnOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand return operands according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForReturnOp
-    : public OpConversionPattern<ReturnOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand call op operands and results according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-
-    // Create the operands list of the new `CallOp`.
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-
-    // Create the new result types for the new `CallOp` and track the number of
-    // replacement types for each original op result.
-    SmallVector<Type, 2> newResultTypes;
-    SmallVector<unsigned> expandedResultSizes;
-    for (Type resultType : op.getResultTypes()) {
-      unsigned oldSize = newResultTypes.size();
-      if (failed(typeConverter->convertType(resultType, newResultTypes)))
-        return failure();
-      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
-    }
-
-    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
-                                               newResultTypes, newOperands);
-
-    // Build a replacement value for each result to replace its uses.
-    SmallVector<ValueRange> replacedValues;
-    replacedValues.reserve(op.getNumResults());
-    unsigned startIdx = 0;
-    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      ValueRange repl =
-          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
-      replacedValues.push_back(repl);
-      startIdx += expandedResultSizes[i];
-    }
-    rewriter.replaceOpWithMultiple(op, replacedValues);
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateDecomposeCallGraphTypesPatterns(
-    MLIRContext *context, const TypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
-           DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
-}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 9e7759bef6d8fd..d531960aa285d1 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -124,12 +124,9 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    // For a return, all operands go to the results of the parent, so
-    // rewrite them all.
-    rewriter.modifyOpInPlace(op,
-                             [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c58ae6ee0..15c8bac61e38b0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -9,7 +9,7 @@
 #include "TestDialect.h"
 #include "TestOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -142,7 +142,9 @@ struct TestDecomposeCallGraphTypes
     typeConverter.addArgumentMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
-    populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
 
     if (failed(applyPartialConversion(module, target, std::move(patterns))))
       return signalPassFailure();

Copy link

github-actions bot commented Nov 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/delete_decompose_call_graph branch 2 times, most recently from 5d6e8e4 to 4e4a5c8 Compare November 23, 2024 07:41
Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM!

Apply suggestions from code review

Co-authored-by: Markus Böck <[email protected]>

address comments

[WIP] 1:N conversion pattern

update test cases

Update mlir/lib/Transforms/Utils/DialectConversion.cpp

Co-authored-by: Markus Böck <[email protected]>

Update mlir/lib/Transforms/Utils/DialectConversion.cpp

Co-authored-by: Markus Böck <[email protected]>

address comments

rollback unresolved materializations properly
@matthias-springer matthias-springer force-pushed the users/matthias-springer/delete_decompose_call_graph branch from 4e4a5c8 to 5f1f245 Compare November 29, 2024 05:48
Base automatically changed from users/matthias-springer/1n_pattern to main November 30, 2024 00:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants