Skip to content

Commit

Permalink
[SYCL] Don't use legacy ANSI-only Windows API for loading plugins (in…
Browse files Browse the repository at this point in the history
…tel#10943)

Currently to load PI plugins we use legacy ANSI-only versions of Windows
API like GetModuleFileNameA, PathRemoveFileSpecA etc. Problem is that if
path containing PI plugins has any non-ANSI symbols then PI plugins are
not found and not loaded.

In this patch get rid of legacy API calls, for example, use
GetModuleFileName instead of GetModuleFileNameA. GetModuleFileName is an
alias which automatically selects the ANSI or Unicode version of this
function.

Another difference is that GetModuleFileName and other similar aliases
work with wchar_t to be able to handle unicode on Windows (in contrast
to legacy GetModuleFileNameA which works with char_t).
So, use std::filesystem:path to work with library paths for convenience
(instead of storing path in std::string or std::wstring) because it
allows to handle paths without caring about format, can be constructed
from string/wstring/.. and can be converted to string/wstring ...
DPCPP is supported on some linux systems where default compiler is gcc
7.5 which doesn't provide `<filesystem>` support. On Windows, minimal
supported version of Visual Studio is 2019 where `<filesystem`> is
available (supported since Visual Studio 2017 version 15.7).
That's why use filesystem::path only on Windows for now, added TODO to
do the same on Linux when matrix support changes.
  • Loading branch information
againull authored Aug 25, 2023
1 parent 006b882 commit 5c30815
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 77 deletions.
84 changes: 38 additions & 46 deletions sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
// similar approach.

#include <cassert>
#include <filesystem>

#ifdef _WIN32

Expand All @@ -43,22 +44,6 @@

// ------------------------------------

static constexpr const char *DirSep = "\\";

// cribbed from sycl/source/detail/os_util.cpp
std::string getDirName(const char *Path) {
std::string Tmp(Path);
// Remove trailing directory separators
Tmp.erase(Tmp.find_last_not_of("/\\") + 1, std::string::npos);

size_t pos = Tmp.find_last_of("/\\");
if (pos != std::string::npos)
return Tmp.substr(0, pos);

// If no directory separator is present return initial path like dirname does
return Tmp;
}

// cribbed from sycl/source/detail/os_util.cpp
// TODO: Just inline it.
using OSModuleHandle = intptr_t;
Expand All @@ -80,20 +65,18 @@ static OSModuleHandle getOSModuleHandle(const void *VirtAddr) {

// cribbed from sycl/source/detail/os_util.cpp
/// Returns an absolute path where the object was found.
std::string getCurrentDSODir() {
char Path[MAX_PATH];
Path[0] = '\0';
Path[sizeof(Path) - 1] = '\0';
std::wstring getCurrentDSODir() {
wchar_t Path[MAX_PATH];
auto Handle = getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODir));
DWORD Ret = GetModuleFileNameA(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle),
reinterpret_cast<LPSTR>(&Path), sizeof(Path));
DWORD Ret = GetModuleFileName(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
sizeof(Path));
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
assert(Ret > 0 && "GetModuleFileNameA failed");
assert(Ret > 0 && "GetModuleFileName failed");
(void)Ret;

BOOL RetCode = PathRemoveFileSpecA(reinterpret_cast<LPSTR>(&Path));
assert(RetCode && "PathRemoveFileSpecA failed");
BOOL RetCode = PathRemoveFileSpec(Path);
assert(RetCode && "PathRemoveFileSpec failed");
(void)RetCode;

return Path;
Expand Down Expand Up @@ -121,7 +104,7 @@ std::string getCurrentDSODir() {

// ------------------------------------

using MapT = std::map<std::string, void *>;
using MapT = std::map<std::filesystem::path, void *>;

MapT &getDllMap() {
static MapT dllMap;
Expand All @@ -141,55 +124,60 @@ void preloadLibraries() {
//
UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS);
// Exclude current directory from DLL search path
if (!SetDllDirectoryA("")) {
if (!SetDllDirectory(L"")) {
assert(false && "Failed to update DLL search path");
}

// this path duplicates sycl/detail/pi.cpp:initializePlugins
const std::string LibSYCLDir = getCurrentDSODir() + DirSep;
std::filesystem::path LibSYCLDir(getCurrentDSODir());

MapT &dllMap = getDllMap();

std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME;
dllMap.emplace(ocl_path, LoadLibraryExA(ocl_path.c_str(), NULL, NULL));
auto ocl_path = LibSYCLDir / __SYCL_OPENCL_PLUGIN_NAME;
dllMap.emplace(ocl_path,
LoadLibraryEx(ocl_path.wstring().c_str(), NULL, NULL));

std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME;
dllMap.emplace(l0_path, LoadLibraryExA(l0_path.c_str(), NULL, NULL));
auto l0_path = LibSYCLDir / __SYCL_LEVEL_ZERO_PLUGIN_NAME;
dllMap.emplace(l0_path, LoadLibraryEx(l0_path.wstring().c_str(), NULL, NULL));

std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME;
dllMap.emplace(cuda_path, LoadLibraryExA(cuda_path.c_str(), NULL, NULL));
auto cuda_path = LibSYCLDir / __SYCL_CUDA_PLUGIN_NAME;
dllMap.emplace(cuda_path,
LoadLibraryEx(cuda_path.wstring().c_str(), NULL, NULL));

std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
dllMap.emplace(esimd_path, LoadLibraryExA(esimd_path.c_str(), NULL, NULL));
auto esimd_path = LibSYCLDir / __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
dllMap.emplace(esimd_path,
LoadLibraryEx(esimd_path.wstring().c_str(), NULL, NULL));

std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME;
dllMap.emplace(hip_path, LoadLibraryExA(hip_path.c_str(), NULL, NULL));
auto hip_path = LibSYCLDir / __SYCL_HIP_PLUGIN_NAME;
dllMap.emplace(hip_path,
LoadLibraryEx(hip_path.wstring().c_str(), NULL, NULL));

std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
dllMap.emplace(ur_path, LoadLibraryExA(ur_path.c_str(), NULL, NULL));
auto ur_path = LibSYCLDir / __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
dllMap.emplace(ur_path, LoadLibraryEx(ur_path.wstring().c_str(), NULL, NULL));

std::string nativecpu_path = LibSYCLDir + __SYCL_NATIVE_CPU_PLUGIN_NAME;
auto nativecpu_path = LibSYCLDir / __SYCL_NATIVE_CPU_PLUGIN_NAME;
dllMap.emplace(nativecpu_path,
LoadLibraryExA(nativecpu_path.c_str(), NULL, NULL));
LoadLibraryEx(nativecpu_path.wstring().c_str(), NULL, NULL));

// Restore system error handling.
(void)SetErrorMode(SavedMode);
if (!SetDllDirectoryA(nullptr)) {
if (!SetDllDirectory(nullptr)) {
assert(false && "Failed to restore DLL search path");
}
}

/// windows_pi.cpp:loadOsPluginLibrary() calls this to get the DLL loaded
/// earlier.
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
__declspec(dllexport) void *getPreloadedPlugin(
const std::filesystem::path &PluginPath) {

MapT &dllMap = getDllMap();

auto match = dllMap.find(PluginPath); // result might be nullptr (not found),
// which is perfectly valid.
if (match == dllMap.end()) {
// unit testing? return nullptr (not found) rather than risk asserting below
if (PluginPath.find("unittests") != std::string::npos)
if (PluginPath.string().find("unittests") != std::string::npos)
return nullptr;

// Otherwise, asking for something we don't know about at all, is an issue.
Expand All @@ -200,6 +188,10 @@ __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
return match->second;
}

__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
return getPreloadedPlugin(std::filesystem::path(PluginPath));
}

BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module
DWORD fdwReason, // reason for calling function
LPVOID lpReserved) // reserved
Expand Down
4 changes: 4 additions & 0 deletions sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
#pragma once

#ifdef _WIN32
#include <filesystem>
#include <string>

__declspec(dllexport) void *getPreloadedPlugin(
const std::filesystem::path &PluginPath);
// TODO: Remove this version during ABI breakage window
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath);
#endif
19 changes: 2 additions & 17 deletions sycl/source/detail/os_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#elif defined(__SYCL_RT_OS_WINDOWS)

#include <detail/windows_os_utils.hpp>

#include <Windows.h>
#include <direct.h>
#include <malloc.h>
Expand Down Expand Up @@ -139,23 +141,6 @@ std::string OSUtil::getDirName(const char *Path) {
}

#elif defined(__SYCL_RT_OS_WINDOWS)
// TODO: Just inline it.
using OSModuleHandle = intptr_t;
static constexpr OSModuleHandle ExeModuleHandle = -1;
static OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
HMODULE PhModule;
DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
auto LpModuleAddr = reinterpret_cast<LPCSTR>(VirtAddr);
if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) {
// Expect the caller to check for zero and take
// necessary action
return 0;
}
if (PhModule == GetModuleHandleA(nullptr))
return ExeModuleHandle;
return reinterpret_cast<OSModuleHandle>(PhModule);
}

/// Returns an absolute path where the object was found.
// pi_win_proxy_loader.dll uses this same logic. If it is changed
Expand Down
35 changes: 21 additions & 14 deletions sycl/source/detail/pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <sstream>
#include <stddef.h>
#include <string>
#include <tuple>

#ifdef XPTI_ENABLE_INSTRUMENTATION
// Include the headers necessary for emitting
Expand Down Expand Up @@ -435,47 +436,53 @@ std::vector<PluginPtr> &initialize() {
return GlobalHandler::instance().getPlugins();
}

// Implementation of this function is OS specific. Please see windows_pi.cpp and
// posix_pi.cpp.
// TODO: refactor code when support matrix for DPCPP changes and <filesystem> is
// available on all supported systems.
std::vector<std::tuple<std::string, backend, void *>>
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames);

static void initializePlugins(std::vector<PluginPtr> &Plugins) {
std::vector<std::pair<std::string, backend>> PluginNames = findPlugins();
const std::vector<std::pair<std::string, backend>> PluginNames =
findPlugins();

if (PluginNames.empty() && trace(PI_TRACE_ALL))
std::cerr << "SYCL_PI_TRACE[all]: "
<< "No Plugins Found." << std::endl;

const std::string LibSYCLDir =
sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep;
// Get library handles for the list of plugins.
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins =
loadPlugins(std::move(PluginNames));

for (unsigned int I = 0; I < PluginNames.size(); I++) {
for (auto [Name, Backend, Library] : LoadedPlugins) {
std::shared_ptr<PiPlugin> PluginInformation = std::make_shared<PiPlugin>(
PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING,
/*Targets=*/nullptr, /*FunctionPointers=*/{}});

void *Library = loadPlugin(LibSYCLDir + PluginNames[I].first);

if (!Library) {
if (trace(PI_TRACE_ALL)) {
std::cerr << "SYCL_PI_TRACE[all]: "
<< "Check if plugin is present. "
<< "Failed to load plugin: " << PluginNames[I].first
<< std::endl;
<< "Failed to load plugin: " << Name << std::endl;
}
continue;
}

if (!bindPlugin(Library, PluginInformation)) {
if (trace(PI_TRACE_ALL)) {
std::cerr << "SYCL_PI_TRACE[all]: "
<< "Failed to bind PI APIs to the plugin: "
<< PluginNames[I].first << std::endl;
<< "Failed to bind PI APIs to the plugin: " << Name
<< std::endl;
}
continue;
}
PluginPtr &NewPlugin = Plugins.emplace_back(std::make_shared<plugin>(
PluginInformation, PluginNames[I].second, Library));
PluginPtr &NewPlugin = Plugins.emplace_back(
std::make_shared<plugin>(PluginInformation, Backend, Library));
if (trace(TraceLevel::PI_TRACE_BASIC))
std::cerr << "SYCL_PI_TRACE[basic]: "
<< "Plugin found and successfully loaded: "
<< PluginNames[I].first << " [ PluginVersion: "
<< "Plugin found and successfully loaded: " << Name
<< " [ PluginVersion: "
<< NewPlugin->getPiPlugin().PluginVersion << " ]" << std::endl;
}

Expand Down
16 changes: 16 additions & 0 deletions sycl/source/detail/posix_pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
return dlsym(Library, FunctionName.c_str());
}

// Load plugins corresponding to provided list of plugin names.
std::vector<std::tuple<std::string, backend, void *>>
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames) {
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins;
const std::string LibSYCLDir =
sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep;

for (auto &PluginName : PluginNames) {
void *Library = loadOsPluginLibrary(LibSYCLDir + PluginName.first);
LoadedPlugins.push_back(std::make_tuple(
std::move(PluginName.first), std::move(PluginName.second), Library));
}

return LoadedPlugins;
}

} // namespace detail::pi
} // namespace _V1
} // namespace sycl
28 changes: 28 additions & 0 deletions sycl/source/detail/windows_os_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//==-- windows_os_utils.hpp - Header file with common utils for Windows --==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===--------------------------------------------------------------------===//

#pragma once

#include <shlwapi.h>

using OSModuleHandle = intptr_t;
constexpr OSModuleHandle ExeModuleHandle = -1;
inline OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
HMODULE PhModule;
DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
auto LpModuleAddr = reinterpret_cast<LPCSTR>(VirtAddr);
if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) {
// Expect the caller to check for zero and take
// necessary action
return 0;
}
if (PhModule == GetModuleHandleA(nullptr))
return ExeModuleHandle;
return reinterpret_cast<OSModuleHandle>(PhModule);
}
35 changes: 35 additions & 0 deletions sycl/source/detail/windows_pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
//
//===----------------------------------------------------------------------===//

#include <sycl/backend.hpp>
#include <sycl/detail/defines.hpp>

#include <cassert>
#include <string>
#include <windows.h>
#include <winreg.h>

#include "detail/windows_os_utils.hpp"
#include "pi_win_proxy_loader.hpp"

namespace sycl {
Expand Down Expand Up @@ -66,6 +68,39 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
GetProcAddress((HMODULE)Library, FunctionName.c_str()));
}

static std::filesystem::path getCurrentDSODirPath() {
wchar_t Path[MAX_PATH];
auto Handle =
getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODirPath));
DWORD Ret = GetModuleFileName(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
sizeof(Path));
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
assert(Ret > 0 && "GetModuleFileName failed");
(void)Ret;

BOOL RetCode = PathRemoveFileSpec(Path);
assert(RetCode && "PathRemoveFileSpec failed");
(void)RetCode;

return std::filesystem::path(Path);
}

// Load plugins corresponding to provided list of plugin names.
std::vector<std::tuple<std::string, backend, void *>>
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames) {
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins;
const std::filesystem::path LibSYCLDir = getCurrentDSODirPath();

for (auto &PluginName : PluginNames) {
void *Library = getPreloadedPlugin(LibSYCLDir / PluginName.first);
LoadedPlugins.push_back(std::make_tuple(
std::move(PluginName.first), std::move(PluginName.second), Library));
}

return LoadedPlugins;
}

} // namespace pi
} // namespace detail
} // namespace _V1
Expand Down

0 comments on commit 5c30815

Please sign in to comment.