Skip to content

Commit

Permalink
improve availability check and move it to testing
Browse files Browse the repository at this point in the history
  • Loading branch information
skallweitNV committed Nov 12, 2024
1 parent 8960fc2 commit 7de88c3
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 50 deletions.
2 changes: 0 additions & 2 deletions include/slang-rhi.h
Original file line number Diff line number Diff line change
Expand Up @@ -2409,8 +2409,6 @@ class IRHI

virtual SLANG_NO_THROW bool SLANG_MCALL isDeviceTypeSupported(DeviceType type) = 0;

virtual SLANG_NO_THROW bool SLANG_MCALL isDeviceTypeAvailable(DeviceType type) = 0;

/// Gets a list of available adapters for a given device type.
virtual SLANG_NO_THROW Result SLANG_MCALL getAdapters(DeviceType type, ISlangBlob** outAdaptersBlob) = 0;

Expand Down
47 changes: 0 additions & 47 deletions src/rhi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ class RHI : public IRHI
virtual const FormatInfo& getFormatInfo(Format format) override { return s_formatInfoMap.get(format); }
virtual const char* getDeviceTypeName(DeviceType type) override;
virtual bool isDeviceTypeSupported(DeviceType type) override;
virtual bool isDeviceTypeAvailable(DeviceType type) override;

Result getAdapters(DeviceType type, ISlangBlob** outAdaptersBlob) override;
Result createDevice(const DeviceDesc& desc, IDevice** outDevice) override;
Expand Down Expand Up @@ -283,12 +282,7 @@ bool RHI::isDeviceTypeSupported(DeviceType type)
case DeviceType::Metal:
return SLANG_RHI_ENABLE_METAL;
case DeviceType::CPU:
#if SLANG_LINUX_FAMILY
// Known issues with CPU backend on linux.
return false;
#else
return SLANG_RHI_ENABLE_CPU;
#endif
case DeviceType::CUDA:
#if SLANG_RHI_ENABLE_CUDA
return rhiCudaApiInit();
Expand All @@ -302,47 +296,6 @@ bool RHI::isDeviceTypeSupported(DeviceType type)
}
}

bool RHI::isDeviceTypeAvailable(DeviceType type)
{
if (!isDeviceTypeSupported(type))
return false;

// Try creating a device.
ComPtr<IDevice> device;
DeviceDesc desc;
desc.deviceType = type;
SLANG_RETURN_FALSE_ON_FAIL(createDevice(desc, device.writeRef()));

// Try compiling a trivial shader.
ComPtr<slang::ISession> session = device->getSlangSession();
if (!session)
return false;

const char* source = "[shader(\"compute\")] [numthreads(1,1,1)] void main(uint3 tid : SV_DispatchThreadID) {}";
slang::IModule* module = session->loadModuleFromSourceString("test", "test", source);
if (!module)
return false;

ComPtr<slang::IEntryPoint> entryPoint;
SLANG_RETURN_FALSE_ON_FAIL(module->findEntryPointByName("main", entryPoint.writeRef()));

std::vector<slang::IComponentType*> componentTypes;
componentTypes.push_back(module);
componentTypes.push_back(entryPoint);
ComPtr<slang::IComponentType> composedProgram;
SLANG_RETURN_FALSE_ON_FAIL(
session->createCompositeComponentType(componentTypes.data(), componentTypes.size(), composedProgram.writeRef())
);

ComPtr<slang::IComponentType> linkedProgram;
SLANG_RETURN_FALSE_ON_FAIL(composedProgram->link(linkedProgram.writeRef()));

ComPtr<slang::IBlob> code;
SLANG_RETURN_FALSE_ON_FAIL(linkedProgram->getEntryPointCode(0, 0, code.writeRef()));

return true;
}

Result RHI::getAdapters(DeviceType type, ISlangBlob** outAdaptersBlob)
{
std::vector<AdapterInfo> adapters;
Expand Down
94 changes: 93 additions & 1 deletion tests/testing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,105 @@ inline const char* deviceTypeToString(DeviceType deviceType)
}
}

inline bool checkDeviceTypeAvailable(DeviceType deviceType, bool verbose = true)
{
#define RETURN_NOT_AVAILABLE(msg) \
if (verbose) \
MESSAGE(doctest::String(deviceTypeToString(deviceType)), " is not available (", doctest::String(msg), ")"); \
return false;

if (verbose)
MESSAGE("Checking for ", doctest::String(deviceTypeToString(deviceType)));


if (!rhi::getRHI()->isDeviceTypeSupported(deviceType))
RETURN_NOT_AVAILABLE("Device type not supported");

#if SLANG_LINUX_FAMILY
if (deviceType == DeviceType::CPU)
// Known issues with CPU backend on linux.
RETURN_NOT_AVAILABLE("CPU backend not supported on linux");
#endif

// Try creating a device.
ComPtr<IDevice> device;
DeviceDesc desc;
desc.deviceType = deviceType;
if (!SLANG_SUCCEEDED(rhi::getRHI()->createDevice(desc, device.writeRef())))
RETURN_NOT_AVAILABLE("Failed to create device");

// Try compiling a trivial shader.
ComPtr<slang::ISession> session = device->getSlangSession();
if (!session)
return false;

// Load shader module.
slang::IModule* module = nullptr;
{
ComPtr<slang::IBlob> diagnostics;
const char* source = "[shader(\"compute\")] [numthreads(1,1,1)] void main(uint3 tid : SV_DispatchThreadID) {}";
slang::IModule* module = session->loadModuleFromSourceString("test", "test", source, diagnostics.writeRef());
if (verbose && diagnostics)
MESSAGE(doctest::String((const char*)diagnostics->getBufferPointer()));
if (!module)
RETURN_NOT_AVAILABLE("Failed to load module");
}

ComPtr<slang::IEntryPoint> entryPoint;
if (!SLANG_SUCCEEDED(module->findEntryPointByName("main", entryPoint.writeRef())))
RETURN_NOT_AVAILABLE("Failed to find entry point");

ComPtr<slang::IComponentType> composedProgram;
{
ComPtr<slang::IBlob> diagnostics;
std::vector<slang::IComponentType*> componentTypes;
componentTypes.push_back(module);
componentTypes.push_back(entryPoint);
session->createCompositeComponentType(
componentTypes.data(),
componentTypes.size(),
composedProgram.writeRef(),
diagnostics.writeRef()
);
if (verbose && diagnostics)
MESSAGE(doctest::String((const char*)diagnostics->getBufferPointer()));
if (!composedProgram)
RETURN_NOT_AVAILABLE("Failed to create composite component type");
}

ComPtr<slang::IComponentType> linkedProgram;
{
ComPtr<slang::IBlob> diagnostics;
composedProgram->link(linkedProgram.writeRef(), diagnostics.writeRef());
if (verbose && diagnostics)
MESSAGE(doctest::String((const char*)diagnostics->getBufferPointer()));
if (!linkedProgram)
RETURN_NOT_AVAILABLE("Failed to link program");
}

ComPtr<slang::IBlob> code;
{
ComPtr<slang::IBlob> diagnostics;
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnostics.writeRef());
if (verbose && diagnostics)
MESSAGE(doctest::String((const char*)diagnostics->getBufferPointer()));
if (!code)
RETURN_NOT_AVAILABLE("Failed to get entry point code");
}

if (verbose)
MESSAGE(doctest::String(deviceTypeToString(deviceType)), " is available.");
return true;
}


bool isDeviceTypeAvailable(DeviceType deviceType)
{
static std::map<DeviceType, bool> available;
auto it = available.find(deviceType);
if (it == available.end())
{
available[deviceType] = getRHI()->isDeviceTypeAvailable(deviceType);
available[deviceType] = checkDeviceTypeAvailable(deviceType);
}
return available[deviceType];
}
Expand Down

0 comments on commit 7de88c3

Please sign in to comment.