diff --git a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp index a45e4234ca681..da350c2952241 100644 --- a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp +++ b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp @@ -23,6 +23,7 @@ // similar approach. #include +#include #ifdef _WIN32 @@ -43,22 +44,6 @@ // ------------------------------------ -static constexpr const char *DirSep = "\\"; - -// cribbed from sycl/source/detail/os_util.cpp -std::string getDirName(const char *Path) { - std::string Tmp(Path); - // Remove trailing directory separators - Tmp.erase(Tmp.find_last_not_of("/\\") + 1, std::string::npos); - - size_t pos = Tmp.find_last_of("/\\"); - if (pos != std::string::npos) - return Tmp.substr(0, pos); - - // If no directory separator is present return initial path like dirname does - return Tmp; -} - // cribbed from sycl/source/detail/os_util.cpp // TODO: Just inline it. using OSModuleHandle = intptr_t; @@ -80,20 +65,18 @@ static OSModuleHandle getOSModuleHandle(const void *VirtAddr) { // cribbed from sycl/source/detail/os_util.cpp /// Returns an absolute path where the object was found. -std::string getCurrentDSODir() { - char Path[MAX_PATH]; - Path[0] = '\0'; - Path[sizeof(Path) - 1] = '\0'; +std::wstring getCurrentDSODir() { + wchar_t Path[MAX_PATH]; auto Handle = getOSModuleHandle(reinterpret_cast(&getCurrentDSODir)); - DWORD Ret = GetModuleFileNameA( - reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), - reinterpret_cast(&Path), sizeof(Path)); + DWORD Ret = GetModuleFileName( + reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), Path, + sizeof(Path)); assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?"); - assert(Ret > 0 && "GetModuleFileNameA failed"); + assert(Ret > 0 && "GetModuleFileName failed"); (void)Ret; - BOOL RetCode = PathRemoveFileSpecA(reinterpret_cast(&Path)); - assert(RetCode && "PathRemoveFileSpecA failed"); + BOOL RetCode = PathRemoveFileSpec(Path); + assert(RetCode && "PathRemoveFileSpec failed"); (void)RetCode; return Path; @@ -121,7 +104,7 @@ std::string getCurrentDSODir() { // ------------------------------------ -using MapT = std::map; +using MapT = std::map; MapT &getDllMap() { static MapT dllMap; @@ -141,47 +124,52 @@ void preloadLibraries() { // UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); // Exclude current directory from DLL search path - if (!SetDllDirectoryA("")) { + if (!SetDllDirectory(L"")) { assert(false && "Failed to update DLL search path"); } // this path duplicates sycl/detail/pi.cpp:initializePlugins - const std::string LibSYCLDir = getCurrentDSODir() + DirSep; + std::filesystem::path LibSYCLDir(getCurrentDSODir()); MapT &dllMap = getDllMap(); - std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME; - dllMap.emplace(ocl_path, LoadLibraryExA(ocl_path.c_str(), NULL, NULL)); + auto ocl_path = LibSYCLDir / __SYCL_OPENCL_PLUGIN_NAME; + dllMap.emplace(ocl_path, + LoadLibraryEx(ocl_path.wstring().c_str(), NULL, NULL)); - std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME; - dllMap.emplace(l0_path, LoadLibraryExA(l0_path.c_str(), NULL, NULL)); + auto l0_path = LibSYCLDir / __SYCL_LEVEL_ZERO_PLUGIN_NAME; + dllMap.emplace(l0_path, LoadLibraryEx(l0_path.wstring().c_str(), NULL, NULL)); - std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME; - dllMap.emplace(cuda_path, LoadLibraryExA(cuda_path.c_str(), NULL, NULL)); + auto cuda_path = LibSYCLDir / __SYCL_CUDA_PLUGIN_NAME; + dllMap.emplace(cuda_path, + LoadLibraryEx(cuda_path.wstring().c_str(), NULL, NULL)); - std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME; - dllMap.emplace(esimd_path, LoadLibraryExA(esimd_path.c_str(), NULL, NULL)); + auto esimd_path = LibSYCLDir / __SYCL_ESIMD_EMULATOR_PLUGIN_NAME; + dllMap.emplace(esimd_path, + LoadLibraryEx(esimd_path.wstring().c_str(), NULL, NULL)); - std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME; - dllMap.emplace(hip_path, LoadLibraryExA(hip_path.c_str(), NULL, NULL)); + auto hip_path = LibSYCLDir / __SYCL_HIP_PLUGIN_NAME; + dllMap.emplace(hip_path, + LoadLibraryEx(hip_path.wstring().c_str(), NULL, NULL)); - std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME; - dllMap.emplace(ur_path, LoadLibraryExA(ur_path.c_str(), NULL, NULL)); + auto ur_path = LibSYCLDir / __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME; + dllMap.emplace(ur_path, LoadLibraryEx(ur_path.wstring().c_str(), NULL, NULL)); - std::string nativecpu_path = LibSYCLDir + __SYCL_NATIVE_CPU_PLUGIN_NAME; + auto nativecpu_path = LibSYCLDir / __SYCL_NATIVE_CPU_PLUGIN_NAME; dllMap.emplace(nativecpu_path, - LoadLibraryExA(nativecpu_path.c_str(), NULL, NULL)); + LoadLibraryEx(nativecpu_path.wstring().c_str(), NULL, NULL)); // Restore system error handling. (void)SetErrorMode(SavedMode); - if (!SetDllDirectoryA(nullptr)) { + if (!SetDllDirectory(nullptr)) { assert(false && "Failed to restore DLL search path"); } } /// windows_pi.cpp:loadOsPluginLibrary() calls this to get the DLL loaded /// earlier. -__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { +__declspec(dllexport) void *getPreloadedPlugin( + const std::filesystem::path &PluginPath) { MapT &dllMap = getDllMap(); @@ -189,7 +177,7 @@ __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { // which is perfectly valid. if (match == dllMap.end()) { // unit testing? return nullptr (not found) rather than risk asserting below - if (PluginPath.find("unittests") != std::string::npos) + if (PluginPath.string().find("unittests") != std::string::npos) return nullptr; // Otherwise, asking for something we don't know about at all, is an issue. @@ -200,6 +188,10 @@ __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { return match->second; } +__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { + return getPreloadedPlugin(std::filesystem::path(PluginPath)); +} + BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module DWORD fdwReason, // reason for calling function LPVOID lpReserved) // reserved diff --git a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp index c1104a6d26c77..0c60d03e72433 100644 --- a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp +++ b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp @@ -9,7 +9,11 @@ #pragma once #ifdef _WIN32 +#include #include +__declspec(dllexport) void *getPreloadedPlugin( + const std::filesystem::path &PluginPath); +// TODO: Remove this version during ABI breakage window __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath); #endif diff --git a/sycl/source/detail/os_util.cpp b/sycl/source/detail/os_util.cpp index dde34762843f8..e9dd8078632a5 100644 --- a/sycl/source/detail/os_util.cpp +++ b/sycl/source/detail/os_util.cpp @@ -29,6 +29,8 @@ #elif defined(__SYCL_RT_OS_WINDOWS) +#include + #include #include #include @@ -139,23 +141,6 @@ std::string OSUtil::getDirName(const char *Path) { } #elif defined(__SYCL_RT_OS_WINDOWS) -// TODO: Just inline it. -using OSModuleHandle = intptr_t; -static constexpr OSModuleHandle ExeModuleHandle = -1; -static OSModuleHandle getOSModuleHandle(const void *VirtAddr) { - HMODULE PhModule; - DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | - GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT; - auto LpModuleAddr = reinterpret_cast(VirtAddr); - if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) { - // Expect the caller to check for zero and take - // necessary action - return 0; - } - if (PhModule == GetModuleHandleA(nullptr)) - return ExeModuleHandle; - return reinterpret_cast(PhModule); -} /// Returns an absolute path where the object was found. // pi_win_proxy_loader.dll uses this same logic. If it is changed diff --git a/sycl/source/detail/pi.cpp b/sycl/source/detail/pi.cpp index 33dfdaf005e41..047d6aa1e3bdb 100644 --- a/sycl/source/detail/pi.cpp +++ b/sycl/source/detail/pi.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #ifdef XPTI_ENABLE_INSTRUMENTATION // Include the headers necessary for emitting @@ -435,29 +436,35 @@ std::vector &initialize() { return GlobalHandler::instance().getPlugins(); } +// Implementation of this function is OS specific. Please see windows_pi.cpp and +// posix_pi.cpp. +// TODO: refactor code when support matrix for DPCPP changes and is +// available on all supported systems. +std::vector> +loadPlugins(const std::vector> &&PluginNames); + static void initializePlugins(std::vector &Plugins) { - std::vector> PluginNames = findPlugins(); + const std::vector> PluginNames = + findPlugins(); if (PluginNames.empty() && trace(PI_TRACE_ALL)) std::cerr << "SYCL_PI_TRACE[all]: " << "No Plugins Found." << std::endl; - const std::string LibSYCLDir = - sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep; + // Get library handles for the list of plugins. + std::vector> LoadedPlugins = + loadPlugins(std::move(PluginNames)); - for (unsigned int I = 0; I < PluginNames.size(); I++) { + for (auto [Name, Backend, Library] : LoadedPlugins) { std::shared_ptr PluginInformation = std::make_shared( PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING, /*Targets=*/nullptr, /*FunctionPointers=*/{}}); - void *Library = loadPlugin(LibSYCLDir + PluginNames[I].first); - if (!Library) { if (trace(PI_TRACE_ALL)) { std::cerr << "SYCL_PI_TRACE[all]: " << "Check if plugin is present. " - << "Failed to load plugin: " << PluginNames[I].first - << std::endl; + << "Failed to load plugin: " << Name << std::endl; } continue; } @@ -465,17 +472,17 @@ static void initializePlugins(std::vector &Plugins) { if (!bindPlugin(Library, PluginInformation)) { if (trace(PI_TRACE_ALL)) { std::cerr << "SYCL_PI_TRACE[all]: " - << "Failed to bind PI APIs to the plugin: " - << PluginNames[I].first << std::endl; + << "Failed to bind PI APIs to the plugin: " << Name + << std::endl; } continue; } - PluginPtr &NewPlugin = Plugins.emplace_back(std::make_shared( - PluginInformation, PluginNames[I].second, Library)); + PluginPtr &NewPlugin = Plugins.emplace_back( + std::make_shared(PluginInformation, Backend, Library)); if (trace(TraceLevel::PI_TRACE_BASIC)) std::cerr << "SYCL_PI_TRACE[basic]: " - << "Plugin found and successfully loaded: " - << PluginNames[I].first << " [ PluginVersion: " + << "Plugin found and successfully loaded: " << Name + << " [ PluginVersion: " << NewPlugin->getPiPlugin().PluginVersion << " ]" << std::endl; } diff --git a/sycl/source/detail/posix_pi.cpp b/sycl/source/detail/posix_pi.cpp index e72f4d8b0af2f..220727f3bb59a 100644 --- a/sycl/source/detail/posix_pi.cpp +++ b/sycl/source/detail/posix_pi.cpp @@ -48,6 +48,22 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { return dlsym(Library, FunctionName.c_str()); } +// Load plugins corresponding to provided list of plugin names. +std::vector> +loadPlugins(const std::vector> &&PluginNames) { + std::vector> LoadedPlugins; + const std::string LibSYCLDir = + sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep; + + for (auto &PluginName : PluginNames) { + void *Library = loadOsPluginLibrary(LibSYCLDir + PluginName.first); + LoadedPlugins.push_back(std::make_tuple( + std::move(PluginName.first), std::move(PluginName.second), Library)); + } + + return LoadedPlugins; +} + } // namespace detail::pi } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/windows_os_utils.hpp b/sycl/source/detail/windows_os_utils.hpp new file mode 100644 index 0000000000000..690fbba46371c --- /dev/null +++ b/sycl/source/detail/windows_os_utils.hpp @@ -0,0 +1,28 @@ +//==-- windows_os_utils.hpp - Header file with common utils for Windows --==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +using OSModuleHandle = intptr_t; +constexpr OSModuleHandle ExeModuleHandle = -1; +inline OSModuleHandle getOSModuleHandle(const void *VirtAddr) { + HMODULE PhModule; + DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT; + auto LpModuleAddr = reinterpret_cast(VirtAddr); + if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) { + // Expect the caller to check for zero and take + // necessary action + return 0; + } + if (PhModule == GetModuleHandleA(nullptr)) + return ExeModuleHandle; + return reinterpret_cast(PhModule); +} diff --git a/sycl/source/detail/windows_pi.cpp b/sycl/source/detail/windows_pi.cpp index 18ab05a52aa66..05ace8ff63863 100644 --- a/sycl/source/detail/windows_pi.cpp +++ b/sycl/source/detail/windows_pi.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include @@ -13,6 +14,7 @@ #include #include +#include "detail/windows_os_utils.hpp" #include "pi_win_proxy_loader.hpp" namespace sycl { @@ -66,6 +68,39 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { GetProcAddress((HMODULE)Library, FunctionName.c_str())); } +static std::filesystem::path getCurrentDSODirPath() { + wchar_t Path[MAX_PATH]; + auto Handle = + getOSModuleHandle(reinterpret_cast(&getCurrentDSODirPath)); + DWORD Ret = GetModuleFileName( + reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), Path, + sizeof(Path)); + assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?"); + assert(Ret > 0 && "GetModuleFileName failed"); + (void)Ret; + + BOOL RetCode = PathRemoveFileSpec(Path); + assert(RetCode && "PathRemoveFileSpec failed"); + (void)RetCode; + + return std::filesystem::path(Path); +} + +// Load plugins corresponding to provided list of plugin names. +std::vector> +loadPlugins(const std::vector> &&PluginNames) { + std::vector> LoadedPlugins; + const std::filesystem::path LibSYCLDir = getCurrentDSODirPath(); + + for (auto &PluginName : PluginNames) { + void *Library = getPreloadedPlugin(LibSYCLDir / PluginName.first); + LoadedPlugins.push_back(std::make_tuple( + std::move(PluginName.first), std::move(PluginName.second), Library)); + } + + return LoadedPlugins; +} + } // namespace pi } // namespace detail } // namespace _V1