Skip to content

Commit

Permalink
[SYCLomatic] Skip lambda in host function scope (#2518)
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 authored Nov 29, 2024
1 parent 644f264 commit 3ecb185
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 4 deletions.
12 changes: 8 additions & 4 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4675,8 +4675,10 @@ void DeviceFunctionDeclRule::runRule(
if (FD->isTemplateInstantiation())
return;

if (FD->hasAttr<CUDADeviceAttr>() &&
FD->getAttr<CUDADeviceAttr>()->isImplicit())
// We need skip lambda in host code, but cannot skip lambda in device code.
if (const FunctionDecl *OuterMostFD = findTheOuterMostFunctionDecl(FD);
OuterMostFD && (!OuterMostFD->hasAttr<CUDADeviceAttr>() &&
!OuterMostFD->hasAttr<CUDAGlobalAttr>()))
return;

const auto &FTL = FD->getFunctionTypeLoc();
Expand Down Expand Up @@ -4711,8 +4713,10 @@ void DeviceFunctionDeclRule::runRule(
DpctGlobalInfo::getRunRound() == 1))
return;

if (FD->hasAttr<CUDADeviceAttr>() &&
FD->getAttr<CUDADeviceAttr>()->isImplicit())
// We need skip lambda in host code, but cannot skip lambda in device code.
if (const FunctionDecl *OuterMostFD = findTheOuterMostFunctionDecl(FD);
OuterMostFD && (!OuterMostFD->hasAttr<CUDADeviceAttr>() &&
!OuterMostFD->hasAttr<CUDAGlobalAttr>()))
return;

if (FD->isVariadic()) {
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/DPCT/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "RulesMathLib/MapNamesRandom.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/Basic/SourceLocation.h"
Expand Down Expand Up @@ -2708,6 +2709,19 @@ findTheOuterMostCompoundStmtUntilMeetControlFlowNodes(const CallExpr *CE) {
return LatestCS;
}

const FunctionDecl *findTheOuterMostFunctionDecl(const clang::Decl *D) {
if (!D)
return nullptr;
const FunctionDecl *FD = nullptr;
const DeclContext *Ctx = D->getDeclContext();
while (Ctx) {
if (Ctx->getDeclKind() == Decl::Function)
FD = dyn_cast<FunctionDecl>(Ctx);
Ctx = Ctx->getParent();
}
return FD;
}

bool isInMacroDefinition(SourceLocation BeginLoc, SourceLocation EndLoc) {
auto Range = getDefinitionRange(BeginLoc, EndLoc);
auto ItBegin = dpct::DpctGlobalInfo::getExpansionRangeToMacroRecord().find(
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ const clang::NamedDecl *getNamedDecl(const clang::Type *TypePtr);
const clang::LambdaExpr *
getImmediateOuterLambdaExpr(const clang::FunctionDecl *FuncDecl);
const DeclRefExpr *getAddressedRef(const Expr *E);
const clang::FunctionDecl *findTheOuterMostFunctionDecl(const clang::Decl *D);

// Source Range & location, offset.
clang::SourceRange getScopeInsertRange(const clang::MemberExpr *ME);
Expand Down
48 changes: 48 additions & 0 deletions clang/test/dpct/kernel_without_name-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,51 @@ template <class T> __global__ void foo_kernel1(const T *a) {
__shared__ T shmem[100];
}
#undef DISPATCH

// CHECK: void foo_kernel2(uint8_t *dpct_local) {
// CHECK-NEXT: auto smem = (int *)dpct_local;
// CHECK-NEXT: char *out_cached = reinterpret_cast<char *>(smem);
// CHECK-NEXT: }
__global__ void foo_kernel2() {
extern __shared__ int smem[];
char *out_cached = reinterpret_cast<char *>(smem);
}

// CHECK: void run_foo2() {
// CHECK-NEXT: [&] {
// CHECK-NEXT: dpct::get_in_order_queue().submit(
// CHECK-NEXT: [&](sycl::handler &cgh) {
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(256), cgh);
// CHECK-EMPTY:
// CHECK-NEXT: cgh.parallel_for(
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: foo_kernel2(dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: }();
// CHECK-NEXT: }
void run_foo2() {
[&] {
foo_kernel2<<<1, 1, 256>>>();
}();
}

// CHECK: void foo_kernel3(const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: auto lambda1 = [&](const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: item_ct1.get_local_id(2);
// CHECK-NEXT: };
// CHECK-NEXT: auto lambda2 = [&](const sycl::nd_item<3> &item_ct1) {
// CHECK-NEXT: lambda1(item_ct1);
// CHECK-NEXT: };
// CHECK-NEXT: lambda2(item_ct1);
// CHECK-NEXT: }
__global__ void foo_kernel3() {
auto lambda1 = [&]() {
threadIdx.x;
};
auto lambda2 = [&]() {
lambda1();
};
lambda2();
}

0 comments on commit 3ecb185

Please sign in to comment.