Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Skip lambda in host function scope #2518

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4675,8 +4675,10 @@ void DeviceFunctionDeclRule::runRule(
if (FD->isTemplateInstantiation())
return;

if (FD->hasAttr<CUDADeviceAttr>() &&
FD->getAttr<CUDADeviceAttr>()->isImplicit())
// We need to skip lambdas in host code, but cannot skip lambdas in device code.
if (const FunctionDecl *OuterMostFD = findTheOuterMostFunctionDecl(FD);
OuterMostFD && (!OuterMostFD->hasAttr<CUDADeviceAttr>() &&
!OuterMostFD->hasAttr<CUDAGlobalAttr>()))
return;

const auto &FTL = FD->getFunctionTypeLoc();
Expand Down Expand Up @@ -4711,8 +4713,10 @@ void DeviceFunctionDeclRule::runRule(
DpctGlobalInfo::getRunRound() == 1))
return;

if (FD->hasAttr<CUDADeviceAttr>() &&
FD->getAttr<CUDADeviceAttr>()->isImplicit())
// We need to skip lambdas in host code, but cannot skip lambdas in device code.
if (const FunctionDecl *OuterMostFD = findTheOuterMostFunctionDecl(FD);
OuterMostFD && (!OuterMostFD->hasAttr<CUDADeviceAttr>() &&
!OuterMostFD->hasAttr<CUDAGlobalAttr>()))
return;

if (FD->isVariadic()) {
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/DPCT/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "RulesMathLib/MapNamesRandom.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/Basic/SourceLocation.h"
Expand Down Expand Up @@ -2708,6 +2709,19 @@ findTheOuterMostCompoundStmtUntilMeetControlFlowNodes(const CallExpr *CE) {
return LatestCS;
}

/// Walks outward through the declaration contexts enclosing \p D and returns
/// the outermost one whose kind is Decl::Function.
///
/// \p D itself is never a candidate: the walk starts at D->getDeclContext().
///
/// \param D Declaration to start from; may be null.
/// \returns The last matching context met on the way to the root context, or
///          nullptr when \p D is null or no enclosing context matches.
const FunctionDecl *findTheOuterMostFunctionDecl(const clang::Decl *D) {
  if (!D)
    return nullptr;

  const FunctionDecl *OuterMost = nullptr;
  for (const DeclContext *DC = D->getDeclContext(); DC; DC = DC->getParent()) {
    // NOTE(review): the kind comparison is exact, so contexts of any other
    // kind (presumably including member functions, which have their own Decl
    // kind) are not recorded -- confirm this is intended.
    if (DC->getDeclKind() != Decl::Function)
      continue;
    // Keep overwriting: the match closest to the root context wins.
    OuterMost = dyn_cast<FunctionDecl>(DC);
  }
  return OuterMost;
}

bool isInMacroDefinition(SourceLocation BeginLoc, SourceLocation EndLoc) {
auto Range = getDefinitionRange(BeginLoc, EndLoc);
auto ItBegin = dpct::DpctGlobalInfo::getExpansionRangeToMacroRecord().find(
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ const clang::NamedDecl *getNamedDecl(const clang::Type *TypePtr);
const clang::LambdaExpr *
getImmediateOuterLambdaExpr(const clang::FunctionDecl *FuncDecl);
const DeclRefExpr *getAddressedRef(const Expr *E);
// Returns the outermost declaration context enclosing D whose kind is
// Decl::Function, or nullptr if D is null or no such context exists.
// D itself is not considered; the search starts at D's declaration context.
const clang::FunctionDecl *findTheOuterMostFunctionDecl(const clang::Decl *D);

// Source Range & location, offset.
clang::SourceRange getScopeInsertRange(const clang::MemberExpr *ME);
Expand Down
48 changes: 48 additions & 0 deletions clang/test/dpct/kernel_without_name-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,51 @@ template <class T> __global__ void foo_kernel1(const T *a) {
__shared__ T shmem[100];
}
#undef DISPATCH

// CHECK: void foo_kernel2(uint8_t *dpct_local) {
// CHECK-NEXT: auto smem = (int *)dpct_local;
// CHECK-NEXT: char *out_cached = reinterpret_cast<char *>(smem);
// CHECK-NEXT: }
// Test input: kernel that declares an extern __shared__ array and reinterprets it as char *.
__global__ void foo_kernel2() {
extern __shared__ int smem[];
char *out_cached = reinterpret_cast<char *>(smem);
}

// CHECK: void run_foo2() {
// CHECK-NEXT: [&] {
// CHECK-NEXT: dpct::get_in_order_queue().submit(
// CHECK-NEXT: [&](sycl::handler &cgh) {
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(256), cgh);
// CHECK-EMPTY:
// CHECK-NEXT: cgh.parallel_for(
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: foo_kernel2(dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: }();
// CHECK-NEXT: }
// Test input: host function that launches foo_kernel2 from inside an immediately invoked lambda.
void run_foo2() {
[&] {
foo_kernel2<<<1, 1, 256>>>();
}();
}

// CHECK: void foo_kernel3(const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: auto lambda1 = [&](const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: item_ct1.get_local_id(2);
// CHECK-NEXT: };
// CHECK-NEXT: auto lambda2 = [&](const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: lambda1(item_ct1);
// CHECK-NEXT: };
// CHECK-NEXT: lambda2(item_ct1);
// CHECK-NEXT: }
// Test input: kernel with nested lambdas; lambda1 reads threadIdx.x and lambda2 calls lambda1.
__global__ void foo_kernel3() {
auto lambda1 = [&]() {
threadIdx.x;
};
auto lambda2 = [&]() {
lambda1();
};
lambda2();
}