Skip to content

Commit

Permalink
Move dim3 ctor rule to EA.
Browse files Browse the repository at this point in the history
  • Loading branch information
tangjj11 committed Jun 4, 2024
1 parent c96bebb commit f54689d
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 237 deletions.
57 changes: 10 additions & 47 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2990,26 +2990,15 @@ void ReplaceDim3CtorRule::registerMatcher(MatchFinder &MF) {
argumentCountIs(1),
unless(hasAncestor(cxxConstructExpr(
hasType(namedDecl(hasName("dim3")))))))
.bind("dim3Top"),
this);

MF.addMatcher(cxxConstructExpr(
hasType(namedDecl(hasName("dim3"))), argumentCountIs(3),
anyOf(hasParent(varDecl()), hasParent(exprWithCleanups())),
unless(hasParent(initListExpr())),
unless(hasAncestor(
cxxConstructExpr(hasType(namedDecl(hasName("dim3")))))))
.bind("dim3CtorDecl"),
.bind("111"),
this);

MF.addMatcher(cxxConstructExpr(hasType(namedDecl(hasName("dim3"))),
argumentCountIs(3),
unless(hasParent(initListExpr())),
unless(hasParent(varDecl())),
unless(hasParent(exprWithCleanups())),
unless(hasAncestor(cxxConstructExpr(
hasType(namedDecl(hasName("dim3")))))))
.bind("dim3CtorNoDecl"),
.bind("111"),
this);

MF.addMatcher(
Expand All @@ -3020,41 +3009,15 @@ void ReplaceDim3CtorRule::registerMatcher(MatchFinder &MF) {
this);
}

ReplaceDim3Ctor *ReplaceDim3CtorRule::getReplaceDim3Modification(
const MatchFinder::MatchResult &Result) {
if (auto Ctor = getNodeAsType<CXXConstructExpr>(Result, "dim3CtorDecl")) {
if(getParentKernelCall(Ctor))
return nullptr;
// dim3 a; or dim3 a(1);
return new ReplaceDim3Ctor(Ctor, true /*isDecl*/);
} else if (auto Ctor =
getNodeAsType<CXXConstructExpr>(Result, "dim3CtorNoDecl")) {
if(getParentKernelCall(Ctor))
return nullptr;
// deflt = dim3(3);
return new ReplaceDim3Ctor(Ctor, false /*isDecl*/);
} else if (auto Ctor = getNodeAsType<CXXConstructExpr>(Result, "dim3Top")) {
if(getParentKernelCall(Ctor))
return nullptr;
// dim3 d3_6_3 = dim3(ceil(test.x + NUM), NUM + test.y, NUM + test.z + NUM);
if (auto A = ReplaceDim3Ctor::getConstructExpr(Ctor->getArg(0))) {
// strip the top CXXConstructExpr, if there's a CXXConstructExpr further
// down
return new ReplaceDim3Ctor(Ctor, A);
} else {
// Copy constructor case: dim3 a(copyfrom)
// No replacements are needed
return nullptr;
}
}

return nullptr;
}

void ReplaceDim3CtorRule::runRule(const MatchFinder::MatchResult &Result) {
ReplaceDim3Ctor *R = getReplaceDim3Modification(Result);
if (R) {
emplaceTransformation(R);
if (auto Ctor = getNodeAsType<CXXConstructExpr>(Result, "111")) {
if (getParentKernelCall(Ctor))
return;
ExprAnalysis EA;
EA.analyze(Ctor);
emplaceTransformation(EA.getReplacement());
EA.applyAllSubExprRepl();
return;
}

if (auto TL = getNodeAsType<TypeLoc>(Result, "dim3Type")) {
Expand Down
3 changes: 0 additions & 3 deletions clang/lib/DPCT/ASTTraversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -581,9 +581,6 @@ class VectorTypeOperatorRule
};

class ReplaceDim3CtorRule : public NamedMigrationRule<ReplaceDim3CtorRule> {
ReplaceDim3Ctor *getReplaceDim3Modification(
const ast_matchers::MatchFinder::MatchResult &Result);

public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
Expand Down
101 changes: 84 additions & 17 deletions clang/lib/DPCT/ExprAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "DNNAPIMigration.h"
#include "MemberExprRewriter.h"
#include "TypeLocRewriters.h"
#include "Utility.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprConcepts.h"
Expand Down Expand Up @@ -632,36 +634,101 @@ void ExprAnalysis::analyzeExpr(const CXXTemporaryObjectExpr *Temp) {
analyzeExpr(static_cast<const CXXConstructExpr *>(Temp));
}

const CXXConstructExpr *getConstructExpr(const Expr *E) {
if (const auto *C = dyn_cast_or_null<CXXConstructExpr>(E)) {
return C;
} else if (isa<MaterializeTemporaryExpr>(E)) {
return getConstructExpr(
dyn_cast<MaterializeTemporaryExpr>(E)->getSubExpr());
} else if (isa<CastExpr>(E)) {
return getConstructExpr(dyn_cast<CastExpr>(E)->getSubExpr());
} else {
return nullptr;
}
}

void ExprAnalysis::analyzeExpr(const CXXConstructExpr *Ctor) {
if (Ctor->getConstructor()->getDeclName().getAsString() == "dim3") {
// strip the top CXXConstructExpr, if there's a CXXConstructExpr further
// down
if (Ctor->getNumArgs() == 1) {
Ctor = getConstructExpr(Ctor->getArg(0));
}
if (Ctor == nullptr) {
return;
}
auto Parents = dpct::DpctGlobalInfo::getContext().getParents(*Ctor);
auto IsDecl =
Parents.size() == 1 && (Parents[0].get<VarDecl>() != nullptr ||
Parents[0].get<ExprWithCleanups>() != nullptr);
std::string ArgsString;
llvm::raw_string_ostream OS(ArgsString);
DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range", 3)
<< "(";
if (!IsDecl)
DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range",
3)
<< "(";
ArgumentAnalysis A;
std::string ArgStr = "";
for (auto Arg : Ctor->arguments()) {
A.analyze(Arg);
ArgStr = ", " + A.getReplacedString() + ArgStr;
}
ArgStr.replace(0, 2, "");
OS << ArgStr << ")";
OS << ArgStr;
if (!IsDecl)
OS << ")";
OS.flush();

// Special handling for implicit ctor.
// #define GET_BLOCKS(a) a
// dim3 A = GET_BLOCKS(1);
// Result if using SM.getExpansionRange:
// sycl::range<3> A = sycl::range<3>(1, 1, GET_BLOCKS(1));
// Result if using addReplacement(E):
// #define GET_BLOCKS(a) sycl::range<3>(1, 1, a)
// sycl::range<3> A = GET_BLOCKS(1);
if (Ctor->getParenOrBraceRange().isInvalid() && isOuterMostMacro(Ctor)) {
return addReplacement(
SM.getExpansionRange(Ctor->getBeginLoc()).getBegin(),
SM.getExpansionRange(Ctor->getEndLoc()).getEnd(), ArgsString);
}
addReplacement(Ctor, ArgsString);
CharSourceRange CSR;
if (IsDecl) {
SourceRange SR = Ctor->getParenOrBraceRange();
if (SR.isInvalid()) {
// convert to spelling location if the dim3 constructor is in a macro
// otherwise, Lexer::getLocForEndOfToken returns invalid source location
auto CtorLoc = Ctor->getLocation().isMacroID()
? SM.getSpellingLoc(Ctor->getLocation())
: Ctor->getLocation();
// dim3 a;
// MACRO(... dim3 a; ...)
auto CtorEndLoc = Lexer::getLocForEndOfToken(
CtorLoc, 0, SM, DpctGlobalInfo::getContext().getLangOpts());
CSR = CharSourceRange(SourceRange(CtorEndLoc, CtorEndLoc), false);
ArgsString = "(" + ArgsString + ")";
} else {
SourceRange SR1 =
SourceRange(SR.getBegin().getLocWithOffset(1), SR.getEnd());
CSR = CharSourceRange(SR1, false);
}
} else {
// adjust the statement to replace if top-level constructor includes the
// variable being defined
const Stmt *S = Ctor;
if (!S) {
return;
}
if (S->getBeginLoc().isMacroID() && !isOuterMostMacro(S)) {
auto Range = getDefinitionRange(S->getBeginLoc(), S->getEndLoc());
auto Begin = Range.getBegin();
auto End = Range.getEnd();
End = End.getLocWithOffset(Lexer::MeasureTokenLength(
End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts()));
CSR = CharSourceRange::getTokenRange(Begin, End);
} else {
// Use getStmtExpansionSourceRange(S) to support cases like
// dim3 a = MACRO;
auto Range = getStmtExpansionSourceRange(S);
auto Begin = Range.getBegin();
auto End = Range.getEnd();
CSR = CharSourceRange::getTokenRange(
Begin,
End.getLocWithOffset(Lexer::MeasureTokenLength(
End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts())));
}
}
auto Range = getDefinitionRange(CSR.getBegin(), CSR.getEnd());
auto Length = SM.getDecomposedLoc(Range.getEnd()).second -
SM.getDecomposedLoc(Range.getBegin()).second;
addReplacement(Range.getBegin(), Length, ArgsString);
return;
}
for (auto It = Ctor->arg_begin(); It != Ctor->arg_end(); It++) {
Expand Down
137 changes: 0 additions & 137 deletions clang/lib/DPCT/TextModification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,135 +523,6 @@ ReplaceInclude::getReplacement(const ASTContext &Context) const {
this);
}

void ReplaceDim3Ctor::setRange() {
auto &SM = DpctGlobalInfo::getSourceManager();
if (isDecl) {
SourceRange SR = Ctor->getParenOrBraceRange();
if (SR.isInvalid()) {
// convert to spelling location if the dim3 constructor is in a macro
// otherwise, Lexer::getLocForEndOfToken returns invalid source location
auto CtorLoc = Ctor->getLocation().isMacroID()
? SM.getSpellingLoc(Ctor->getLocation())
: Ctor->getLocation();
// dim3 a;
// MACRO(... dim3 a; ...)
auto CtorEndLoc = Lexer::getLocForEndOfToken(
CtorLoc, 0, SM, DpctGlobalInfo::getContext().getLangOpts());
CSR = CharSourceRange(SourceRange(CtorEndLoc, CtorEndLoc), false);
} else {
SourceRange SR1 =
SourceRange(SR.getBegin().getLocWithOffset(1), SR.getEnd());
CSR = CharSourceRange(SR1, false);
}
} else {
// adjust the statement to replace if top-level constructor includes the
// variable being defined
const Stmt *S = getReplaceStmt(Ctor);
if (!S) {
return;
}
if (S->getBeginLoc().isMacroID() && !isOuterMostMacro(S)) {
auto Range = getDefinitionRange(S->getBeginLoc(), S->getEndLoc());
auto Begin = Range.getBegin();
auto End = Range.getEnd();
End = End.getLocWithOffset(Lexer::MeasureTokenLength(
End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts()));
CSR = CharSourceRange::getTokenRange(Begin, End);
} else {
// Use getStmtExpansionSourceRange(S) to support cases like
// dim3 a = MACRO;
auto Range = getStmtExpansionSourceRange(S);
auto Begin = Range.getBegin();
auto End = Range.getEnd();
CSR = CharSourceRange::getTokenRange(
Begin,
End.getLocWithOffset(Lexer::MeasureTokenLength(
End, SM, dpct::DpctGlobalInfo::getContext().getLangOpts())));
}
}
}

ReplaceInclude *ReplaceDim3Ctor::getEmpty() {
return new ReplaceInclude(CSR, "");
}

// Strips possible Materialize and Cast operators from CXXConstructor
const CXXConstructExpr *ReplaceDim3Ctor::getConstructExpr(const Expr *E) {
if (auto C = dyn_cast_or_null<CXXConstructExpr>(E)) {
return C;
} else if (isa<MaterializeTemporaryExpr>(E)) {
return getConstructExpr(
dyn_cast<MaterializeTemporaryExpr>(E)->getSubExpr());
} else if (isa<CastExpr>(E)) {
return getConstructExpr(dyn_cast<CastExpr>(E)->getSubExpr());
} else {
return nullptr;
}
}

// Returns the full replacement string for the CXXConstructorExpr
std::string
ReplaceDim3Ctor::getSyclRangeCtor(const CXXConstructExpr *Ctor) const {
ExprAnalysis Analysis(Ctor);
return Analysis.getReplacedString();
}

const Stmt *ReplaceDim3Ctor::getReplaceStmt(const Stmt *S) const {
if (auto Ctor = dyn_cast_or_null<CXXConstructExpr>(S)) {
if (Ctor->getNumArgs() == 1) {
return getConstructExpr(Ctor->getArg(0));
}
}
return S;
}

std::string ReplaceDim3Ctor::getReplaceString() const {
if (isDecl) {
// Get the new parameter list for the replaced constructor, without the
// parens
std::string ReplacedString;
llvm::raw_string_ostream OS(ReplacedString);
ArgumentAnalysis AA;
std::string ArgStr = "";
for (auto Arg : Ctor->arguments()) {
AA.analyze(Arg);
ArgStr = ", " + AA.getReplacedString() + ArgStr;
}
ArgStr.replace(0, 2, "");
OS << ArgStr;
OS.flush();
if (Ctor->getParenOrBraceRange().isInvalid()) {
// dim3 = a;
ReplacedString = "(" + ReplacedString + ")";
}
return ReplacedString;
} else {
std::string S;
if (FinalCtor) {
S = getSyclRangeCtor(FinalCtor);
} else {
S = getSyclRangeCtor(Ctor);
}
return S;
}
}

std::shared_ptr<ExtReplacement>
ReplaceDim3Ctor::getReplacement(const ASTContext &Context) const {
if (this->isIgnoreTM())
return nullptr;
// Use getDefinitionRange in general cases,
// For cases like dim3 a = MACRO;
// CSR is already set to the expansion range.
auto &SM = dpct::DpctGlobalInfo::getSourceManager();
ReplacementString = getReplaceString();
auto Range = getDefinitionRange(CSR.getBegin(), CSR.getEnd());
auto Length = SM.getDecomposedLoc(Range.getEnd()).second -
SM.getDecomposedLoc(Range.getBegin()).second;
return std::make_shared<ExtReplacement>(SM, Range.getBegin(), Length,
getReplaceString(), this);
}

std::shared_ptr<ExtReplacement>
InsertComment::getReplacement(const ASTContext &Context) const {
if (this->isIgnoreTM())
Expand Down Expand Up @@ -930,14 +801,6 @@ void ReplaceInclude::print(llvm::raw_ostream &OS, ASTContext &Context,
printReplacement(OS, T);
}

void ReplaceDim3Ctor::print(llvm::raw_ostream &OS, ASTContext &Context,
const bool PrintDetail) const {
printHeader(OS, getID(), PrintDetail ? getParentRuleName() : StringRef());
printLocation(OS, CSR.getBegin(), Context, PrintDetail);
Ctor->printPretty(OS, nullptr, PrintingPolicy(Context.getLangOpts()));
printReplacement(OS, ReplacementString);
}

void InsertComment::print(llvm::raw_ostream &OS, ASTContext &Context,
const bool PrintDetail) const {
printHeader(OS, getID(), PrintDetail ? getParentRuleName() : StringRef());
Expand Down
Loading

0 comments on commit f54689d

Please sign in to comment.