Skip to content

Commit

Permalink
Fix precompiledTargetModule tests
Browse files Browse the repository at this point in the history
In the SPIR-V backend of Slang, compiling a shader
that contains some modules with precompiled target
blobs will produce only a "glue" SPIR-V output which
needs to be linked with the assorted precompiled
blobs to be complete.

Closes #6170
  • Loading branch information
cheneym2 committed Jan 31, 2025
1 parent a5b1aa0 commit 3f2ed2a
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 33 deletions.
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ if(SLANG_ENABLE_GFX)
EXPORT_SET_NAME SlangTargets
FOLDER gfx
)
target_link_libraries(gfx PUBLIC SPIRV-Tools-link)
set(modules_dest_dir $<TARGET_FILE_DIR:slang-test>)
add_custom_target(
copy-gfx-slang-modules
Expand Down
7 changes: 4 additions & 3 deletions tools/gfx/d3d12/d3d12-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ using namespace Slang;

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob> > kernelCodes)
{
ShaderBinary shaderBin;
shaderBin.stage = entryPointInfo->getStage();
shaderBin.entryPointInfo = entryPointInfo;
assert(kernelCodes.getCount() == 1); // Only one kernel code is supported for now
shaderBin.code.addRange(
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
(Index)kernelCode->getBufferSize());
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
(Index)kernelCodes[0]->getBufferSize());
m_shaders.add(_Move(shaderBin));
return SLANG_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/d3d12/d3d12-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob> > kernelCodes) override;
};

} // namespace d3d12
Expand Down
10 changes: 6 additions & 4 deletions tools/gfx/metal/metal-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ ShaderProgramImpl::~ShaderProgramImpl() {}

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
{
Module module;
module.stage = entryPointInfo->getStage();
module.entryPointName = entryPointInfo->getNameOverride();
module.code = kernelCode;
assert(kernelCodes.getCount() == 1);
module.code = kernelCodes[0];


dispatch_data_t data = dispatch_data_create(
kernelCode->getBufferPointer(),
kernelCode->getBufferSize(),
kernelCodes[0]->getBufferPointer(),
kernelCodes[0]->getBufferSize(),
dispatch_get_main_queue(),
NULL);
NS::Error* error;
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/metal/metal-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
Slang::List<ComPtr<ISlangBlob> > kernelCodes) override;
};


Expand Down
113 changes: 94 additions & 19 deletions tools/gfx/renderer-shared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,26 +1109,101 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
SlangInt entryPointIndex)
{
auto stage = entryPointInfo->getStage();

List<ComPtr<ISlangBlob> > kernelCodes;
ComPtr<ISlangBlob> kernelCode;
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
kernelCode.writeRef(),
diagnostics.writeRef());
if (diagnostics)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
kernelCode.writeRef(),
diagnostics.writeRef());
if (diagnostics)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(compileResult);
kernelCodes.add(kernelCode);
}

// If target precompilation was used, kernelCode may only represent the
// glue code holding together the bits of precompiled target IR.
// Collect those dependency target IRs too.
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
if (entryPointComponent->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
{
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
if (dependencyCount > 0)
{
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
{
ComPtr<slang::IModule> dependencyModule;
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = componentPrecompileService->getModuleDependency(
dependencyIndex,
dependencyModule.writeRef(),
diagnosticsBlob.writeRef());
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(result);
}

ComPtr<slang::IBlob> spirv;
{
ComPtr<slang::IBlob> diagnosticsBlob;
SlangResult result = SLANG_OK;
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
result = dependencyModule->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)precompileService.writeRef());
if (result == SLANG_OK)
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = precompileService->getPrecompiledTargetCode(
SLANG_SPIRV,
spirv.writeRef(),
diagnosticsBlob.writeRef());
if (result == SLANG_OK)
{
kernelCodes.add(spirv);
}
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
}
SLANG_RETURN_ON_FAIL(result);
}
}
}
}
SLANG_RETURN_ON_FAIL(compileResult);
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));

SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
return SLANG_OK;
};

Expand Down Expand Up @@ -1160,10 +1235,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)

Result ShaderProgramBase::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
{
SLANG_UNUSED(entryPointInfo);
SLANG_UNUSED(kernelCode);
SLANG_UNUSED(kernelCodes);
return SLANG_OK;
}

Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/renderer-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
Slang::Result compileShaders(RendererBase* device);
virtual Slang::Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
Slang::ComPtr<ISlangBlob> kernelCode);
Slang::List<Slang::ComPtr<ISlangBlob> > kernelCodes);

virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
findTypeByName(const char* name) override
Expand Down
1 change: 1 addition & 0 deletions tools/gfx/vulkan/vk-pipeline-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ Result PipelineStateImpl::createVKComputePipelineState()

VkComputePipelineCreateInfo computePipelineInfo = {
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
assert(programImpl->m_stageCreateInfos.getCount() == 1);
computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;

Expand Down
55 changes: 52 additions & 3 deletions tools/gfx/vulkan/vk-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "vk-device.h"
#include "vk-util.h"

#include "external/spirv-tools/include/spirv-tools/linker.hpp"

namespace gfx
{

Expand Down Expand Up @@ -69,19 +71,66 @@ VkPipelineShaderStageCreateInfo ShaderProgramImpl::compileEntryPoint(
return shaderStageCreateInfo;
}

static ComPtr<ISlangBlob> LinkWithSPIRVTools(List<ComPtr<ISlangBlob> > kernelCodes)
{
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
spvtools::LinkerOptions options;
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
const char* source,
const spv_position_t& position,
const char* message)
{
printf("SPIRV-TOOLS: %s\n", message);
printf("SPIRV-TOOLS: %s\n", source);
printf("SPIRV-TOOLS: %d:%d\n", position.index, position.column);
};
context.SetMessageConsumer(consumer);
std::vector<uint32_t*> binaries;
std::vector<size_t> binary_sizes;
for (auto kernelCode : kernelCodes)
{
binaries.push_back((uint32_t*)kernelCode->getBufferPointer());
binary_sizes.push_back(kernelCode->getBufferSize() / sizeof(uint32_t));
}

std::vector<uint32_t> linked_binary;

spvtools::Link(
context,
binaries.data(),
binary_sizes.data(),
binaries.size(),
&linked_binary,
options);

// Create a blob to hold the linked binary
ComPtr<ISlangBlob> linkedKernelCode;

// Replace kernel code with linked binary
// Creates a new blob with the linked binary
linkedKernelCode = RawBlob::create(linked_binary.data(), linked_binary.size() * sizeof(uint32_t));

return linkedKernelCode;
}
Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob>> kernelCodes)
{
m_codeBlobs.add(kernelCode);
//for (auto kernelCode : kernelCodes)
// m_codeBlobs.add(kernelCode);

ComPtr<ISlangBlob> linkedKernel = LinkWithSPIRVTools(kernelCodes);
m_codeBlobs.add(linkedKernel);

VkShaderModule shaderModule;
auto realEntryPointName = entryPointInfo->getNameOverride();
const char* spirvBinaryEntryPointName = "main";
m_stageCreateInfos.add(compileEntryPoint(
spirvBinaryEntryPointName,
kernelCode,
linkedKernel,
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
shaderModule));

m_entryPointNames.add(realEntryPointName);
m_modules.add(shaderModule);
return SLANG_OK;
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/vulkan/vk-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob> > kernelCodes) override;
};


Expand Down

0 comments on commit 3f2ed2a

Please sign in to comment.