From f54689d45185975442d7dbcf8ace2b42e2fbff96 Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Tue, 4 Jun 2024 18:34:03 +0800 Subject: [PATCH] Move dim3 ctor rule to EA. --- clang/lib/DPCT/ASTTraversal.cpp | 57 ++---------- clang/lib/DPCT/ASTTraversal.h | 3 - clang/lib/DPCT/ExprAnalysis.cpp | 101 ++++++++++++++++---- clang/lib/DPCT/TextModification.cpp | 137 ---------------------------- clang/lib/DPCT/TextModification.h | 33 ------- 5 files changed, 94 insertions(+), 237 deletions(-) diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 7396c5a6e95a..2491a8887a8d 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -2990,26 +2990,15 @@ void ReplaceDim3CtorRule::registerMatcher(MatchFinder &MF) { argumentCountIs(1), unless(hasAncestor(cxxConstructExpr( hasType(namedDecl(hasName("dim3"))))))) - .bind("dim3Top"), - this); - - MF.addMatcher(cxxConstructExpr( - hasType(namedDecl(hasName("dim3"))), argumentCountIs(3), - anyOf(hasParent(varDecl()), hasParent(exprWithCleanups())), - unless(hasParent(initListExpr())), - unless(hasAncestor( - cxxConstructExpr(hasType(namedDecl(hasName("dim3"))))))) - .bind("dim3CtorDecl"), + .bind("111"), this); MF.addMatcher(cxxConstructExpr(hasType(namedDecl(hasName("dim3"))), argumentCountIs(3), unless(hasParent(initListExpr())), - unless(hasParent(varDecl())), - unless(hasParent(exprWithCleanups())), unless(hasAncestor(cxxConstructExpr( hasType(namedDecl(hasName("dim3"))))))) - .bind("dim3CtorNoDecl"), + .bind("111"), this); MF.addMatcher( @@ -3020,41 +3009,15 @@ void ReplaceDim3CtorRule::registerMatcher(MatchFinder &MF) { this); } -ReplaceDim3Ctor *ReplaceDim3CtorRule::getReplaceDim3Modification( - const MatchFinder::MatchResult &Result) { - if (auto Ctor = getNodeAsType(Result, "dim3CtorDecl")) { - if(getParentKernelCall(Ctor)) - return nullptr; - // dim3 a; or dim3 a(1); - return new ReplaceDim3Ctor(Ctor, true /*isDecl*/); - } else if (auto Ctor = - getNodeAsType(Result, "dim3CtorNoDecl")) { - if(getParentKernelCall(Ctor)) - return nullptr; - // deflt = dim3(3); - return new ReplaceDim3Ctor(Ctor, false /*isDecl*/); - } else if (auto Ctor = getNodeAsType(Result, "dim3Top")) { - if(getParentKernelCall(Ctor)) - return nullptr; - // dim3 d3_6_3 = dim3(ceil(test.x + NUM), NUM + test.y, NUM + test.z + NUM); - if (auto A = ReplaceDim3Ctor::getConstructExpr(Ctor->getArg(0))) { - // strip the top CXXConstructExpr, if there's a CXXConstructExpr further - // down - return new ReplaceDim3Ctor(Ctor, A); - } else { - // Copy constructor case: dim3 a(copyfrom) - // No replacements are needed - return nullptr; - } - } - - return nullptr; -} - void ReplaceDim3CtorRule::runRule(const MatchFinder::MatchResult &Result) { - ReplaceDim3Ctor *R = getReplaceDim3Modification(Result); - if (R) { - emplaceTransformation(R); + if (auto Ctor = getNodeAsType(Result, "111")) { + if (getParentKernelCall(Ctor)) + return; + ExprAnalysis EA; + EA.analyze(Ctor); + emplaceTransformation(EA.getReplacement()); + EA.applyAllSubExprRepl(); + return; } if (auto TL = getNodeAsType(Result, "dim3Type")) { diff --git a/clang/lib/DPCT/ASTTraversal.h b/clang/lib/DPCT/ASTTraversal.h index 694fc93988a2..c66f2defabb5 100644 --- a/clang/lib/DPCT/ASTTraversal.h +++ b/clang/lib/DPCT/ASTTraversal.h @@ -581,9 +581,6 @@ class VectorTypeOperatorRule }; class ReplaceDim3CtorRule : public NamedMigrationRule { - ReplaceDim3Ctor *getReplaceDim3Modification( - const ast_matchers::MatchFinder::MatchResult &Result); - public: void registerMatcher(ast_matchers::MatchFinder &MF) override; void runRule(const ast_matchers::MatchFinder::MatchResult &Result); diff --git a/clang/lib/DPCT/ExprAnalysis.cpp b/clang/lib/DPCT/ExprAnalysis.cpp index b5aec8619769..3017048ff95b 100644 --- a/clang/lib/DPCT/ExprAnalysis.cpp +++ b/clang/lib/DPCT/ExprAnalysis.cpp @@ -16,6 +16,8 @@ #include "DNNAPIMigration.h" #include "MemberExprRewriter.h" #include "TypeLocRewriters.h" +#include "Utility.h" +#include "clang/AST/Decl.h" #include "clang/AST/DeclTemplate.h" #include "clang/AST/Expr.h" #include "clang/AST/ExprConcepts.h" @@ -632,12 +634,39 @@ void ExprAnalysis::analyzeExpr(const CXXTemporaryObjectExpr *Temp) { analyzeExpr(static_cast(Temp)); } +const CXXConstructExpr *getConstructExpr(const Expr *E) { + if (const auto *C = dyn_cast_or_null(E)) { + return C; + } else if (isa(E)) { + return getConstructExpr( + dyn_cast(E)->getSubExpr()); + } else if (isa(E)) { + return getConstructExpr(dyn_cast(E)->getSubExpr()); + } else { + return nullptr; + } +} + void ExprAnalysis::analyzeExpr(const CXXConstructExpr *Ctor) { if (Ctor->getConstructor()->getDeclName().getAsString() == "dim3") { + // strip the top CXXConstructExpr, if there's a CXXConstructExpr further + // down + if (Ctor->getNumArgs() == 1) { + Ctor = getConstructExpr(Ctor->getArg(0)); + } + if (Ctor == nullptr) { + return; + } + auto Parents = dpct::DpctGlobalInfo::getContext().getParents(*Ctor); + auto IsDecl = + Parents.size() == 1 && (Parents[0].get() != nullptr || + Parents[0].get() != nullptr); std::string ArgsString; llvm::raw_string_ostream OS(ArgsString); - DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range", 3) - << "("; + if (!IsDecl) + DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range", + 3) + << "("; ArgumentAnalysis A; std::string ArgStr = ""; for (auto Arg : Ctor->arguments()) { @@ -645,23 +674,61 @@ void ExprAnalysis::analyzeExpr(const CXXConstructExpr *Ctor) { ArgStr = ", " + A.getReplacedString() + ArgStr; } ArgStr.replace(0, 2, ""); - OS << ArgStr << ")"; + OS << ArgStr; + if (!IsDecl) + OS << ")"; OS.flush(); - // Special handling for implicit ctor. - // #define GET_BLOCKS(a) a - // dim3 A = GET_BLOCKS(1); - // Result if using SM.getExpansionRange: - // sycl::range<3> A = sycl::range<3>(1, 1, GET_BLOCKS(1)); - // Result if using addReplacement(E): - // #define GET_BLOCKS(a) sycl::range<3>(1, 1, a) - // sycl::range<3> A = GET_BLOCKS(1); - if (Ctor->getParenOrBraceRange().isInvalid() && isOuterMostMacro(Ctor)) { - return addReplacement( - SM.getExpansionRange(Ctor->getBeginLoc()).getBegin(), - SM.getExpansionRange(Ctor->getEndLoc()).getEnd(), ArgsString); - } - addReplacement(Ctor, ArgsString); + CharSourceRange CSR; + if (IsDecl) { + SourceRange SR = Ctor->getParenOrBraceRange(); + if (SR.isInvalid()) { + // convert to spelling location if the dim3 constructor is in a macro + // otherwise, Lexer::getLocForEndOfToken returns invalid source location + auto CtorLoc = Ctor->getLocation().isMacroID() + ? SM.getSpellingLoc(Ctor->getLocation()) + : Ctor->getLocation(); + // dim3 a; + // MACRO(... dim3 a; ...) + auto CtorEndLoc = Lexer::getLocForEndOfToken( + CtorLoc, 0, SM, DpctGlobalInfo::getContext().getLangOpts()); + CSR = CharSourceRange(SourceRange(CtorEndLoc, CtorEndLoc), false); + ArgsString = "(" + ArgsString + ")"; + } else { + SourceRange SR1 = + SourceRange(SR.getBegin().getLocWithOffset(1), SR.getEnd()); + CSR = CharSourceRange(SR1, false); + } + } else { + // adjust the statement to replace if top-level constructor includes the + // variable being defined + const Stmt *S = Ctor; + if (!S) { + return; + } + if (S->getBeginLoc().isMacroID() && !isOuterMostMacro(S)) { + auto Range = getDefinitionRange(S->getBeginLoc(), S->getEndLoc()); + auto Begin = Range.getBegin(); + auto End = Range.getEnd(); + End = End.getLocWithOffset(Lexer::MeasureTokenLength( + End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts())); + CSR = CharSourceRange::getTokenRange(Begin, End); + } else { + // Use getStmtExpansionSourceRange(S) to support cases like + // dim3 a = MACRO; + auto Range = getStmtExpansionSourceRange(S); + auto Begin = Range.getBegin(); + auto End = Range.getEnd(); + CSR = CharSourceRange::getTokenRange( + Begin, + End.getLocWithOffset(Lexer::MeasureTokenLength( + End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts()))); + } + } + auto Range = getDefinitionRange(CSR.getBegin(), CSR.getEnd()); + auto Length = SM.getDecomposedLoc(Range.getEnd()).second - + SM.getDecomposedLoc(Range.getBegin()).second; + addReplacement(Range.getBegin(), Length, ArgsString); return; } for (auto It = Ctor->arg_begin(); It != Ctor->arg_end(); It++) { diff --git a/clang/lib/DPCT/TextModification.cpp b/clang/lib/DPCT/TextModification.cpp index 0010962e66da..a11c402a744c 100644 --- a/clang/lib/DPCT/TextModification.cpp +++ b/clang/lib/DPCT/TextModification.cpp @@ -523,135 +523,6 @@ ReplaceInclude::getReplacement(const ASTContext &Context) const { this); } -void ReplaceDim3Ctor::setRange() { - auto &SM = DpctGlobalInfo::getSourceManager(); - if (isDecl) { - SourceRange SR = Ctor->getParenOrBraceRange(); - if (SR.isInvalid()) { - // convert to spelling location if the dim3 constructor is in a macro - // otherwise, Lexer::getLocForEndOfToken returns invalid source location - auto CtorLoc = Ctor->getLocation().isMacroID() - ? SM.getSpellingLoc(Ctor->getLocation()) - : Ctor->getLocation(); - // dim3 a; - // MACRO(... dim3 a; ...) - auto CtorEndLoc = Lexer::getLocForEndOfToken( - CtorLoc, 0, SM, DpctGlobalInfo::getContext().getLangOpts()); - CSR = CharSourceRange(SourceRange(CtorEndLoc, CtorEndLoc), false); - } else { - SourceRange SR1 = - SourceRange(SR.getBegin().getLocWithOffset(1), SR.getEnd()); - CSR = CharSourceRange(SR1, false); - } - } else { - // adjust the statement to replace if top-level constructor includes the - // variable being defined - const Stmt *S = getReplaceStmt(Ctor); - if (!S) { - return; - } - if (S->getBeginLoc().isMacroID() && !isOuterMostMacro(S)) { - auto Range = getDefinitionRange(S->getBeginLoc(), S->getEndLoc()); - auto Begin = Range.getBegin(); - auto End = Range.getEnd(); - End = End.getLocWithOffset(Lexer::MeasureTokenLength( - End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts())); - CSR = CharSourceRange::getTokenRange(Begin, End); - } else { - // Use getStmtExpansionSourceRange(S) to support cases like - // dim3 a = MACRO; - auto Range = getStmtExpansionSourceRange(S); - auto Begin = Range.getBegin(); - auto End = Range.getEnd(); - CSR = CharSourceRange::getTokenRange( - Begin, - End.getLocWithOffset(Lexer::MeasureTokenLength( - End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts()))); - } - } -} - -ReplaceInclude *ReplaceDim3Ctor::getEmpty() { - return new ReplaceInclude(CSR, ""); -} - -// Strips possible Materialize and Cast operators from CXXConstructor -const CXXConstructExpr *ReplaceDim3Ctor::getConstructExpr(const Expr *E) { - if (auto C = dyn_cast_or_null(E)) { - return C; - } else if (isa(E)) { - return getConstructExpr( - dyn_cast(E)->getSubExpr()); - } else if (isa(E)) { - return getConstructExpr(dyn_cast(E)->getSubExpr()); - } else { - return nullptr; - } -} - -// Returns the full replacement string for the CXXConstructorExpr -std::string -ReplaceDim3Ctor::getSyclRangeCtor(const CXXConstructExpr *Ctor) const { - ExprAnalysis Analysis(Ctor); - return Analysis.getReplacedString(); -} - -const Stmt *ReplaceDim3Ctor::getReplaceStmt(const Stmt *S) const { - if (auto Ctor = dyn_cast_or_null(S)) { - if (Ctor->getNumArgs() == 1) { - return getConstructExpr(Ctor->getArg(0)); - } - } - return S; -} - -std::string ReplaceDim3Ctor::getReplaceString() const { - if (isDecl) { - // Get the new parameter list for the replaced constructor, without the - // parens - std::string ReplacedString; - llvm::raw_string_ostream OS(ReplacedString); - ArgumentAnalysis AA; - std::string ArgStr = ""; - for (auto Arg : Ctor->arguments()) { - AA.analyze(Arg); - ArgStr = ", " + AA.getReplacedString() + ArgStr; - } - ArgStr.replace(0, 2, ""); - OS << ArgStr; - OS.flush(); - if (Ctor->getParenOrBraceRange().isInvalid()) { - // dim3 = a; - ReplacedString = "(" + ReplacedString + ")"; - } - return ReplacedString; - } else { - std::string S; - if (FinalCtor) { - S = getSyclRangeCtor(FinalCtor); - } else { - S = getSyclRangeCtor(Ctor); - } - return S; - } -} - -std::shared_ptr -ReplaceDim3Ctor::getReplacement(const ASTContext &Context) const { - if (this->isIgnoreTM()) - return nullptr; - // Use getDefinitionRange in general cases, - // For cases like dim3 a = MACRO; - // CSR is already set to the expansion range. - auto &SM = dpct::DpctGlobalInfo::getSourceManager(); - ReplacementString = getReplaceString(); - auto Range = getDefinitionRange(CSR.getBegin(), CSR.getEnd()); - auto Length = SM.getDecomposedLoc(Range.getEnd()).second - - SM.getDecomposedLoc(Range.getBegin()).second; - return std::make_shared(SM, Range.getBegin(), Length, - getReplaceString(), this); -} - std::shared_ptr InsertComment::getReplacement(const ASTContext &Context) const { if (this->isIgnoreTM()) @@ -930,14 +801,6 @@ void ReplaceInclude::print(llvm::raw_ostream &OS, ASTContext &Context, printReplacement(OS, T); } -void ReplaceDim3Ctor::print(llvm::raw_ostream &OS, ASTContext &Context, - const bool PrintDetail) const { - printHeader(OS, getID(), PrintDetail ? getParentRuleName() : StringRef()); - printLocation(OS, CSR.getBegin(), Context, PrintDetail); - Ctor->printPretty(OS, nullptr, PrintingPolicy(Context.getLangOpts())); - printReplacement(OS, ReplacementString); -} - void InsertComment::print(llvm::raw_ostream &OS, ASTContext &Context, const bool PrintDetail) const { printHeader(OS, getID(), PrintDetail ? getParentRuleName() : StringRef()); diff --git a/clang/lib/DPCT/TextModification.h b/clang/lib/DPCT/TextModification.h index b828d9c1883d..6acca115e5d6 100644 --- a/clang/lib/DPCT/TextModification.h +++ b/clang/lib/DPCT/TextModification.h @@ -529,39 +529,6 @@ class ReplaceInclude : public TextModification { const bool PrintDetail = true) const override; }; -/// Replace Dim3 constructors -class ReplaceDim3Ctor : public TextModification { - bool isDecl; - const CXXConstructExpr *Ctor; - const CXXConstructExpr *FinalCtor; - CharSourceRange CSR; - mutable std::string ReplacementString; - - void setRange(); - const Stmt *getReplaceStmt(const Stmt *S) const; - std::string getSyclRangeCtor(const CXXConstructExpr *Ctor) const; - std::string getReplaceString() const; - -public: - ReplaceDim3Ctor(const CXXConstructExpr *_Ctor, bool _isDecl = false) - : TextModification(TMID::ReplaceDim3Ctor, G2), isDecl(_isDecl), - Ctor(_Ctor), FinalCtor(nullptr) { - setRange(); - } - ReplaceDim3Ctor(const CXXConstructExpr *_Ctor, - const CXXConstructExpr *_FinalCtor) - : TextModification(TMID::ReplaceDim3Ctor, G2), isDecl(false), Ctor(_Ctor), - FinalCtor(_FinalCtor) { - setRange(); - } - static const CXXConstructExpr *getConstructExpr(const Expr *E); - ReplaceInclude *getEmpty(); - std::shared_ptr - getReplacement(const ASTContext &Context) const override; - void print(llvm::raw_ostream &OS, ASTContext &Context, - const bool PrintDetail = true) const override; -}; - class InsertBeforeStmt : public TextModification { const Stmt *S; std::string T;