diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h b/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h index 0d88a3ea482b..eaa0cdcc93cd 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h @@ -818,6 +818,35 @@ inline std::function getDerefedType(size_t Idx) }; } +inline std::function +getReplacedTypeNameForDerefExpr(size_t Idx) { + return [=](const CallExpr *C) -> std::string { + if (Idx >= C->getNumArgs()) + return ""; + + const auto *ArgExpr = C->getArg(Idx); + QualType ArgExprType = ArgExpr->getType(); + + if (auto *CSCE = + dyn_cast(ArgExpr->IgnoreImplicitAsWritten())) + ArgExprType = CSCE->getTypeAsWritten(); + + while (const auto *ET = dyn_cast(ArgExprType)) { + ArgExprType = ET->getNamedType(); + if (const auto *TDT = dyn_cast(ArgExprType)) { + if (isRedeclInCUDAHeader(TDT)) + break; + ArgExprType = TDT->getDecl()->getUnderlyingType(); + } + } + + ArgExprType = DerefQualType(ArgExprType); + return ArgExprType.isNull() + ? "" + : DpctGlobalInfo::getReplacedTypeName(ArgExprType); + }; +} + inline std::function getTemplateArg(size_t Idx) { return [=](const CallExpr *C) -> std::string { std::string TemplateArgStr = ""; diff --git a/clang/lib/DPCT/RulesLang/APINamesMemory.inc b/clang/lib/DPCT/RulesLang/APINamesMemory.inc index 38099432ab9b..9901c17a6cc2 100644 --- a/clang/lib/DPCT/RulesLang/APINamesMemory.inc +++ b/clang/lib/DPCT/RulesLang/APINamesMemory.inc @@ -217,7 +217,8 @@ ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cudaHostGetDevicePointer", ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cuMemHostGetDevicePointer_v2", DEREF(makeCallArgCreatorWithCall(0)), - CAST(getDerefedType(0), ARG(1)))) + CAST(getReplacedTypeNameForDerefExpr(0), + ARG(1)))) ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY( checkIsUSM(), diff --git a/clang/test/dpct/USM-restricted.cu b/clang/test/dpct/USM-restricted.cu index eaa19966d85f..47f818efc5e9 100644 --- a/clang/test/dpct/USM-restricted.cu +++ b/clang/test/dpct/USM-restricted.cu @@ -378,6 +378,10 @@ void foo() { // CHECK: MY_SAFE_CALL(DPCT_CHECK_ERROR(*D_ptr = (dpct::device_ptr)h_A)); MY_SAFE_CALL(cuMemHostGetDevicePointer(D_ptr, h_A, 0)); + unsigned long long addr; + // CHECK: *(dpct::device_ptr *)&addr = (dpct::device_ptr)h_A; + cuMemHostGetDevicePointer((CUdeviceptr *)&addr, h_A, 0); + cudaHostRegister(h_A, size, 0); // CHECK: errorCode = 0; errorCode = cudaHostRegister(h_A, size, 0);