Skip to content

Commit

Permalink
extract required spv binary from exe
Browse files Browse the repository at this point in the history
Signed-off-by: jinge90 <[email protected]>
  • Loading branch information
jinge90 committed Nov 25, 2024
1 parent 2976cde commit 331644e
Showing 1 changed file with 102 additions and 10 deletions.
112 changes: 102 additions & 10 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
Expand Down Expand Up @@ -1112,8 +1113,8 @@ ProgramManager::getProgramBuildLog(const ur_program_handle_t &Program,
// TODO device libraries may use scpecialization constants, manifest files, etc.
// To support that they need to be delivered in a different container - so that
// sycl_device_binary_struct can be created for each of them.
static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
ur_program_handle_t &Prog) {
static bool loadDeviceLibLegacy(const ContextImplPtr Context, const char *Name,
ur_program_handle_t &Prog) {
std::string LibSyclDir = OSUtil::getCurrentDSODir();
std::ifstream File(LibSyclDir + OSUtil::DirSep + Name,
std::ifstream::in | std::ifstream::binary);
Expand All @@ -1133,6 +1134,13 @@ static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
return Prog != nullptr;
}

static bool loadDeviceLib(const ContextImplPtr Context,
ur_program_handle_t &Prog,
const unsigned char *SPVBuffer, size_t SPVSize) {
Prog = createSpirvProgram(Context, SPVBuffer, SPVSize);
return Prog != nullptr;
}

// For each extension, a pair of library names. The first uses native support,
// the second emulates functionality in software.
static const std::map<DeviceLibExt, std::pair<const char *, const char *>>
Expand Down Expand Up @@ -1213,9 +1221,13 @@ static ur_result_t doCompile(const AdapterPtr &Adapter,
static ur_program_handle_t
loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
std::vector<ur_device_handle_t> &Devices,
bool UseNativeLib) {
bool UseNativeLib, bool LegacyMode = true,
const unsigned char *SPVBuffer = nullptr,
size_t SPVSize = 0) {

auto LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
const char *LibFileName = nullptr;
if (LegacyMode)
LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
auto LockedCache = Context->acquireCachedLibPrograms();
auto &CachedLibPrograms = LockedCache.get();
// Collect list of devices to compile the library for. Library was already
Expand Down Expand Up @@ -1252,10 +1264,20 @@ loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
bool IsProgramCreated = !URProgram;

// Create UR program for device lib if we don't have it yet.
if (!URProgram && !loadDeviceLib(Context, LibFileName, URProgram)) {
EraseProgramForDevices();
throw exception(make_error_code(errc::build),
std::string("Failed to load ") + LibFileName);
if (LegacyMode) {
if (!URProgram && !loadDeviceLibLegacy(Context, LibFileName, URProgram)) {
EraseProgramForDevices();
throw exception(make_error_code(errc::build),
std::string("Failed to load ") + LibFileName);
}
} else {
if (!URProgram && !loadDeviceLib(Context, URProgram, SPVBuffer, SPVSize)) {
EraseProgramForDevices();
const char *ExtStr = getDeviceLibExtensionStr(Extension);
throw exception(
make_error_code(errc::build),
std::string("Failed to load fallback device library for ") + ExtStr);
}
}

// Insert URProgram into the cache for all devices that we compiled it for.
Expand Down Expand Up @@ -1513,6 +1535,8 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
return ((DeviceLibReqMask & Mask) == Mask);
}

// TODO: Clear legacy getDeviceLibPrograms when developers upgrade to
// latest version compiler.
static std::vector<ur_program_handle_t>
getDeviceLibProgramsLegacy(const ContextImplPtr Context,
std::vector<ur_device_handle_t> &Devices,
Expand Down Expand Up @@ -1604,6 +1628,38 @@ getDeviceLibPrograms(const ContextImplPtr Context,
std::vector<ur_device_handle_t> &Devices,
const std::vector<const RTDeviceBinaryImage *> &Images) {
std::vector<ur_program_handle_t> Programs;
std::map<DeviceLibExt, bool> DeviceLibExtLoaded = {
{DeviceLibExt::cl_intel_devicelib_assert,
/* is fallback loaded? */ false},
{DeviceLibExt::cl_intel_devicelib_math, false},
{DeviceLibExt::cl_intel_devicelib_math_fp64, false},
{DeviceLibExt::cl_intel_devicelib_complex, false},
{DeviceLibExt::cl_intel_devicelib_complex_fp64, false},
{DeviceLibExt::cl_intel_devicelib_cstring, false},
{DeviceLibExt::cl_intel_devicelib_imf, false},
{DeviceLibExt::cl_intel_devicelib_imf_fp64, false},
{DeviceLibExt::cl_intel_devicelib_imf_bf16, false},
{DeviceLibExt::cl_intel_devicelib_bfloat16, false}};

// Check whether a specified extension is supported by ALL devices.
auto checkExtForDevices = [&Context, &Devices](const char *ExtStr) -> bool {
bool ExtAvailable = true;
for (auto SingleDevice : Devices) {
std::string DevExtList =
Context->getPlatformImpl()
->getDeviceImpl(SingleDevice)
->get_device_info_string(
UrInfoCode<info::device::extensions>::value);
if (DevExtList.npos == DevExtList.find(ExtStr)) {
ExtAvailable = false;
break;
}
}
return ExtAvailable;
};

const bool fp64Support = checkExtForDevices("cl_khr_fp64");

for (auto Img : Images) {
if (!Img)
continue;
Expand All @@ -1616,11 +1672,47 @@ getDeviceLibPrograms(const ContextImplPtr Context,
auto DeviceLibByteArray =
DeviceBinaryProperty(DeviceLibBinProp).asByteArray();
DeviceLibByteArray.dropBytes(8);
uint32_t DeviceLibExtReq =
DeviceLibExt DeviceLibExtReq = static_cast<DeviceLibExt>(
(static_cast<uint32_t>(DeviceLibByteArray[3]) << 24) |
(static_cast<uint32_t>(DeviceLibByteArray[2]) << 16) |
(static_cast<uint32_t>(DeviceLibByteArray[1]) << 8) |
DeviceLibByteArray[0];
DeviceLibByteArray[0]);
if (DeviceLibExtLoaded.count(DeviceLibExtReq) != 1) {
if constexpr (DbgProgMgr > 0) {
std::cerr << "Unknown DeviceLib extension("
<< static_cast<uint32_t>(DeviceLibExtReq) << ")!"
<< std::endl;
}
continue;
}

if (DeviceLibExtLoaded[DeviceLibExtReq])
continue;

if ((DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_complex_fp64 ||
DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_imf_fp64) &&
!fp64Support)
continue;

auto DeviceLibExtReqName = getDeviceLibExtensionStr(DeviceLibExtReq);
bool InhibitNativeImpl = false;
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
InhibitNativeImpl = strstr(Env, DeviceLibExtReqName) != nullptr;
}

bool ExtReqAvailable = checkExtForDevices(DeviceLibExtReqName);

// Load fallback device library only when 1) or 2) is met:
// 1. underlying device doesn't support the extension
// 2. user explicitly ask to inhibit usage of native support
if (!ExtReqAvailable || InhibitNativeImpl) {
DeviceLibByteArray.dropBytes(4);
Programs.push_back(loadDeviceLibFallback(
Context, DeviceLibExtReq, Devices,
/*UseNativeLib=*/false, false, DeviceLibByteArray.begin(),
DeviceLibByteArray.size()));
}
}
}
return Programs;
Expand Down

0 comments on commit 331644e

Please sign in to comment.