Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix precompiledTargetModule tests #6236

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ if(SLANG_ENABLE_GFX)
EXPORT_SET_NAME SlangTargets
FOLDER gfx
)
target_link_libraries(gfx PUBLIC SPIRV-Tools-link)
set(modules_dest_dir $<TARGET_FILE_DIR:slang-test>)
add_custom_target(
copy-gfx-slang-modules
Expand Down
7 changes: 4 additions & 3 deletions tools/gfx/d3d12/d3d12-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ using namespace Slang;

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob> > kernelCodes)
{
ShaderBinary shaderBin;
shaderBin.stage = entryPointInfo->getStage();
shaderBin.entryPointInfo = entryPointInfo;
SLANG_ASSERT(kernelCodes.getCount() == 1);
shaderBin.code.addRange(
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
(Index)kernelCode->getBufferSize());
reinterpret_cast<const uint8_t*>(kernelCodes[0]->getBufferPointer()),
(Index)kernelCodes[0]->getBufferSize());
m_shaders.add(_Move(shaderBin));
return SLANG_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/d3d12/d3d12-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob> > kernelCodes) override;
};

} // namespace d3d12
Expand Down
10 changes: 6 additions & 4 deletions tools/gfx/metal/metal-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ ShaderProgramImpl::~ShaderProgramImpl() {}

Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
{
Module module;
module.stage = entryPointInfo->getStage();
module.entryPointName = entryPointInfo->getNameOverride();
module.code = kernelCode;
SLANG_ASSERT(kernelCodes.getCount() == 1);
module.code = kernelCodes[0];


dispatch_data_t data = dispatch_data_create(
kernelCode->getBufferPointer(),
kernelCode->getBufferSize(),
kernelCodes[0]->getBufferPointer(),
kernelCodes[0]->getBufferSize(),
dispatch_get_main_queue(),
NULL);
NS::Error* error;
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/metal/metal-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
Slang::List<ComPtr<ISlangBlob> > kernelCodes) override;
};


Expand Down
115 changes: 95 additions & 20 deletions tools/gfx/renderer-shared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,26 +1109,101 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
SlangInt entryPointIndex)
{
auto stage = entryPointInfo->getStage();
ComPtr<ISlangBlob> kernelCode;
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
kernelCode.writeRef(),
diagnostics.writeRef());
if (diagnostics)

List<ComPtr<ISlangBlob>> kernelCodes;
{
ComPtr<ISlangBlob> spirv;
ComPtr<ISlangBlob> diagnostics;
auto compileResult = device->getEntryPointCodeFromShaderCache(
entryPointComponent,
entryPointIndex,
0,
spirv.writeRef(),
diagnostics.writeRef());
if (diagnostics)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(compileResult);
kernelCodes.add(spirv);
}

// If target precompilation was used, kernelCode may only represent the
// glue code holding together the bits of precompiled target IR.
// Collect those dependency target IRs too.
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
if (entryPointComponent->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (compileResult != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnostics->getBufferPointer());
SlangInt dependencyCount = componentPrecompileService->getModuleDependencyCount();
if (dependencyCount > 0)
{
for (int dependencyIndex = 0; dependencyIndex < dependencyCount; dependencyIndex++)
{
ComPtr<slang::IModule> dependencyModule;
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = componentPrecompileService->getModuleDependency(
dependencyIndex,
dependencyModule.writeRef(),
diagnosticsBlob.writeRef());
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
SLANG_RETURN_ON_FAIL(result);
}

ComPtr<slang::IBlob> spirv;
{
ComPtr<slang::IBlob> diagnosticsBlob;
SlangResult result = SLANG_OK;
ComPtr<slang::IModulePrecompileService_Experimental> precompileService;
result = dependencyModule->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)precompileService.writeRef());
if (result == SLANG_OK)
{
ComPtr<slang::IBlob> diagnosticsBlob;
auto result = precompileService->getPrecompiledTargetCode(
SLANG_SPIRV,
spirv.writeRef(),
diagnosticsBlob.writeRef());
if (result == SLANG_OK)
{
kernelCodes.add(spirv);
}
if (diagnosticsBlob)
{
DebugMessageType msgType = DebugMessageType::Warning;
if (result != SLANG_OK)
msgType = DebugMessageType::Error;
getDebugCallback()->handleMessage(
msgType,
DebugMessageSource::Slang,
(char*)diagnosticsBlob->getBufferPointer());
}
}
SLANG_RETURN_ON_FAIL(result);
}
}
}
}
SLANG_RETURN_ON_FAIL(compileResult);
SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));

SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCodes));
return SLANG_OK;
};

Expand Down Expand Up @@ -1160,10 +1235,10 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)

Result ShaderProgramBase::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
Slang::List<ComPtr<ISlangBlob> > kernelCodes)
{
SLANG_UNUSED(entryPointInfo);
SLANG_UNUSED(kernelCode);
SLANG_UNUSED(kernelCodes);
return SLANG_OK;
}

Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/renderer-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ class ShaderProgramBase : public IShaderProgram, public Slang::ComObject
Slang::Result compileShaders(RendererBase* device);
virtual Slang::Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
Slang::ComPtr<ISlangBlob> kernelCode);
Slang::List<Slang::ComPtr<ISlangBlob> > kernelCodes);

virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL
findTypeByName(const char* name) override
Expand Down
1 change: 1 addition & 0 deletions tools/gfx/vulkan/vk-pipeline-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ Result PipelineStateImpl::createVKComputePipelineState()

VkComputePipelineCreateInfo computePipelineInfo = {
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
SLANG_ASSERT(programImpl->m_stageCreateInfos.getCount() == 1);
computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;

Expand Down
51 changes: 48 additions & 3 deletions tools/gfx/vulkan/vk-shader-program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "vk-device.h"
#include "vk-util.h"

#include "external/spirv-tools/include/spirv-tools/linker.hpp"

namespace gfx
{

Expand Down Expand Up @@ -69,17 +71,60 @@ VkPipelineShaderStageCreateInfo ShaderProgramImpl::compileEntryPoint(
return shaderStageCreateInfo;
}

static ComPtr<ISlangBlob> LinkWithSPIRVTools(List<ComPtr<ISlangBlob> > kernelCodes)
{
spvtools::Context context(SPV_ENV_UNIVERSAL_1_5);
spvtools::LinkerOptions options;
spvtools::MessageConsumer consumer = [](spv_message_level_t level,
const char* source,
const spv_position_t& position,
const char* message)
{
printf("SPIRV-TOOLS: %s\n", message);
printf("SPIRV-TOOLS: %s\n", source);
printf("SPIRV-TOOLS: %zu:%zu\n", position.index, position.column);
};
context.SetMessageConsumer(consumer);
std::vector<uint32_t*> binaries;
std::vector<size_t> binary_sizes;
for (auto kernelCode : kernelCodes)
{
binaries.push_back((uint32_t*)kernelCode->getBufferPointer());
binary_sizes.push_back(kernelCode->getBufferSize() / sizeof(uint32_t));
}

std::vector<uint32_t> linked_binary;

spvtools::Link(
context,
binaries.data(),
binary_sizes.data(),
binaries.size(),
&linked_binary,
options);

// Create a blob to hold the linked binary
ComPtr<ISlangBlob> linkedKernelCode;

// Replace kernel code with linked binary
// Creates a new blob with the linked binary
linkedKernelCode = RawBlob::create(linked_binary.data(), linked_binary.size() * sizeof(uint32_t));

return linkedKernelCode;
}
Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode)
List<ComPtr<ISlangBlob>> kernelCodes)
{
m_codeBlobs.add(kernelCode);
ComPtr<ISlangBlob> linkedKernel = LinkWithSPIRVTools(kernelCodes);
m_codeBlobs.add(linkedKernel);

VkShaderModule shaderModule;
auto realEntryPointName = entryPointInfo->getNameOverride();
const char* spirvBinaryEntryPointName = "main";
m_stageCreateInfos.add(compileEntryPoint(
spirvBinaryEntryPointName,
kernelCode,
linkedKernel,
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
shaderModule));
m_entryPointNames.add(realEntryPointName);
Expand Down
2 changes: 1 addition & 1 deletion tools/gfx/vulkan/vk-shader-program.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ShaderProgramImpl : public ShaderProgramBase

virtual Result createShaderModule(
slang::EntryPointReflection* entryPointInfo,
ComPtr<ISlangBlob> kernelCode) override;
List<ComPtr<ISlangBlob> > kernelCodes) override;
};


Expand Down
Loading