From 1b6b4051b2057eb1b3f56401bcddb7bbd9bc670d Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Fri, 15 Nov 2024 09:40:03 +0100 Subject: [PATCH] cleanup Binding (#107) * cleanup Binding * fix Binding constructors with ComPtr --------- Co-authored-by: Simon Kallweit --- include/slang-rhi.h | 41 +++++++++++++++++++++---------- src/cpu/cpu-shader-object.cpp | 10 +++++--- src/cuda/cuda-shader-object.cpp | 6 ++--- src/d3d11/d3d11-shader-object.cpp | 10 +++++--- src/d3d12/d3d12-shader-object.cpp | 22 +++++++++++------ src/metal/metal-shader-object.cpp | 8 +++--- src/vulkan/vk-shader-object.cpp | 24 ++++++++++-------- src/wgpu/wgpu-shader-object.cpp | 8 +++--- 8 files changed, 80 insertions(+), 49 deletions(-) diff --git a/include/slang-rhi.h b/include/slang-rhi.h index a9f36bf..3572ff0 100644 --- a/include/slang-rhi.h +++ b/include/slang-rhi.h @@ -1088,14 +1088,15 @@ enum class BindingType TextureView, Sampler, CombinedTextureSampler, + CombinedTextureViewSampler, AccelerationStructure, }; struct Binding { - BindingType type; - ComPtr resource; - ComPtr resource2; + BindingType type = BindingType::Unknown; + IResource* resource = nullptr; + IResource* resource2 = nullptr; union { BufferRange bufferRange; @@ -1103,16 +1104,30 @@ struct Binding // clang-format off Binding() : type(BindingType::Unknown) {} - Binding(ComPtr buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer), bufferRange(range) {} - Binding(IBuffer* buffer, const BufferRange& range = kEntireBuffer) : Binding(ComPtr(buffer), range) {} - Binding(ComPtr buffer, ComPtr counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer), resource2(counter), bufferRange(range) {} - Binding(ComPtr texture) : type(BindingType::Texture), resource(texture) {} - Binding(ComPtr textureView) : type(BindingType::TextureView), resource(textureView) {} - Binding(ComPtr sampler) : type(BindingType::Sampler) , resource(sampler) {} - Binding(ComPtr textureView, ComPtr sampler) : type(BindingType::CombinedTextureSampler), resource(textureView), resource2(sampler) {} - Binding(ComPtr texture, ComPtr sampler) : type(BindingType::CombinedTextureSampler) , resource(texture), resource2(sampler) {} - Binding(ComPtr as) : type(BindingType::AccelerationStructure) , resource(as) {} - Binding(IAccelerationStructure* as) : Binding(ComPtr(as)) {} + + Binding(IBuffer* buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer), bufferRange(range) {} + Binding(const ComPtr& buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer.get()), bufferRange(range) {} + + Binding(IBuffer* buffer, IBuffer* counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer), resource2(counter), bufferRange(range) {} + Binding(const ComPtr& buffer, const ComPtr& counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer.get()), resource2(counter.get()), bufferRange(range) {} + + Binding(ITexture* texture) : type(BindingType::Texture), resource(texture) {} + Binding(const ComPtr& texture) : type(BindingType::Texture), resource(texture.get()) {} + + Binding(ITextureView* textureView) : type(BindingType::TextureView), resource(textureView) {} + Binding(const ComPtr& textureView) : type(BindingType::TextureView), resource(textureView.get()) {} + + Binding(ISampler* sampler) : type(BindingType::Sampler) , resource(sampler) {} + Binding(const ComPtr& sampler) : type(BindingType::Sampler) , resource(sampler.get()) {} + + Binding(ITexture* texture, ISampler* sampler) : type(BindingType::CombinedTextureSampler), resource(texture), resource2(sampler) {} + Binding(const ComPtr& texture, const ComPtr& sampler) : type(BindingType::CombinedTextureSampler), resource(texture.get()), resource2(sampler.get()) {} + + Binding(ITextureView* textureView, ISampler* sampler) : type(BindingType::CombinedTextureViewSampler) , resource(textureView), resource2(sampler) {} + Binding(const ComPtr& textureView, const ComPtr& sampler) : type(BindingType::CombinedTextureViewSampler) , resource(textureView.get()), resource2(sampler.get()) {} + + Binding(IAccelerationStructure* as) : type(BindingType::AccelerationStructure), resource(as) {} + Binding(const ComPtr& as) : type(BindingType::AccelerationStructure), resource(as.get()) {} // clang-format on }; diff --git a/src/cpu/cpu-shader-object.cpp b/src/cpu/cpu-shader-object.cpp index 4a3ab90..87f530d 100644 --- a/src/cpu/cpu-shader-object.cpp +++ b/src/cpu/cpu-shader-object.cpp @@ -162,7 +162,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { case BindingType::Buffer: { - BufferImpl* buffer = checked_cast(binding.resource.get()); + BufferImpl* buffer = checked_cast(binding.resource); const BufferDesc& desc = buffer->m_desc; BufferRange range = buffer->resolveBufferRange(binding.bufferRange); m_resources[viewIndex] = buffer; @@ -182,12 +182,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: { - auto textureView = checked_cast(binding.resource.get()); + auto textureView = checked_cast(binding.resource); m_resources[viewIndex] = textureView; slang_prelude::IRWTexture* textureObj = textureView; SLANG_RETURN_ON_FAIL(setData(offset, &textureObj, sizeof(textureObj))); @@ -201,6 +201,10 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { break; } + case BindingType::CombinedTextureViewSampler: + { + break; + } case BindingType::AccelerationStructure: { break; diff --git a/src/cuda/cuda-shader-object.cpp b/src/cuda/cuda-shader-object.cpp index ea0b3e7..71fd39c 100644 --- a/src/cuda/cuda-shader-object.cpp +++ b/src/cuda/cuda-shader-object.cpp @@ -205,7 +205,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { case BindingType::Buffer: { - BufferImpl* buffer = checked_cast(binding.resource.get()); + BufferImpl* buffer = checked_cast(binding.resource); const BufferDesc& desc = buffer->m_desc; BufferRange range = buffer->resolveBufferRange(binding.bufferRange); m_resources[viewIndex] = buffer; @@ -222,7 +222,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); m_resources[viewIndex] = texture; switch (bindingRange.bindingType) { @@ -237,7 +237,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::TextureView: { - TextureViewImpl* textureView = checked_cast(binding.resource.get()); + TextureViewImpl* textureView = checked_cast(binding.resource); m_resources[viewIndex] = textureView; TextureImpl* texture = textureView->m_texture; switch (bindingRange.bindingType) diff --git a/src/d3d11/d3d11-shader-object.cpp b/src/d3d11/d3d11-shader-object.cpp index 4b29dd7..f6fde13 100644 --- a/src/d3d11/d3d11-shader-object.cpp +++ b/src/d3d11/d3d11-shader-object.cpp @@ -58,7 +58,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { case BindingType::Buffer: { - BufferImpl* buffer = checked_cast(binding.resource.get()); + BufferImpl* buffer = checked_cast(binding.resource); BufferRange bufferRange = buffer->resolveBufferRange(binding.bufferRange); m_resources.emplace(buffer); if (D3DUtil::isUAVBinding(bindingRange.bindingType)) @@ -73,12 +73,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: { - TextureViewImpl* textureView = checked_cast(binding.resource.get()); + TextureViewImpl* textureView = checked_cast(binding.resource); m_resources.emplace(textureView); if (D3DUtil::isUAVBinding(bindingRange.bindingType)) { @@ -91,10 +91,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) break; } case BindingType::Sampler: - m_samplers[bindingIndex] = checked_cast(binding.resource.get()); + m_samplers[bindingIndex] = checked_cast(binding.resource); break; case BindingType::CombinedTextureSampler: break; + case BindingType::CombinedTextureViewSampler: + break; case BindingType::AccelerationStructure: break; } diff --git a/src/d3d12/d3d12-shader-object.cpp b/src/d3d12/d3d12-shader-object.cpp index 23eb3e2..c3baea2 100644 --- a/src/d3d12/d3d12-shader-object.cpp +++ b/src/d3d12/d3d12-shader-object.cpp @@ -765,8 +765,8 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) case BindingType::Buffer: case BindingType::BufferWithCounter: { - BufferImpl* buffer = checked_cast(binding.resource.get()); - BufferImpl* counterBuffer = checked_cast(binding.resource2.get()); + BufferImpl* buffer = checked_cast(binding.resource); + BufferImpl* counterBuffer = checked_cast(binding.resource2); BufferRange bufferRange = buffer->resolveBufferRange(binding.bufferRange); boundResource.type = BoundResourceType::Buffer; boundResource.resource = buffer; @@ -811,12 +811,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: { - TextureViewImpl* textureView = checked_cast(binding.resource.get()); + TextureViewImpl* textureView = checked_cast(binding.resource); boundResource.type = BoundResourceType::TextureView; boundResource.resource = textureView; D3D12Descriptor descriptor; @@ -843,7 +843,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Sampler: { - SamplerImpl* sampler = checked_cast(binding.resource.get()); + SamplerImpl* sampler = checked_cast(binding.resource); d3dDevice->CopyDescriptorsSimple( 1, m_descriptorSet.samplerTable.getCpuHandle(bindingIndex), @@ -854,8 +854,14 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::CombinedTextureSampler: { - TextureViewImpl* textureView = checked_cast(binding.resource.get()); - SamplerImpl* sampler = checked_cast(binding.resource2.get()); + TextureImpl* texture = checked_cast(binding.resource); + SamplerImpl* sampler = checked_cast(binding.resource2); + return setBinding(offset, Binding(m_device->createTextureView(texture, {}), sampler)); + } + case BindingType::CombinedTextureViewSampler: + { + TextureViewImpl* textureView = checked_cast(binding.resource); + SamplerImpl* sampler = checked_cast(binding.resource2); boundResource.type = BoundResourceType::TextureView; boundResource.resource = textureView; boundResource.requiredState = ResourceState::ShaderResource; @@ -875,7 +881,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::AccelerationStructure: { - AccelerationStructureImpl* as = checked_cast(binding.resource.get()); + AccelerationStructureImpl* as = checked_cast(binding.resource); boundResource.type = BoundResourceType::AccelerationStructure; boundResource.resource = as; if (bindingRange.isRootParameter) diff --git a/src/metal/metal-shader-object.cpp b/src/metal/metal-shader-object.cpp index ddd36d0..81e5fdd 100644 --- a/src/metal/metal-shader-object.cpp +++ b/src/metal/metal-shader-object.cpp @@ -64,19 +64,19 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) switch (binding.type) { case BindingType::Buffer: - m_buffers[bindingIndex] = checked_cast(binding.resource.get()); + m_buffers[bindingIndex] = checked_cast(binding.resource); m_bufferOffsets[bindingIndex] = binding.bufferRange.offset; break; case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: - m_textureViews[bindingIndex] = checked_cast(binding.resource.get()); + m_textureViews[bindingIndex] = checked_cast(binding.resource); break; case BindingType::Sampler: - m_samplers[bindingIndex] = checked_cast(binding.resource.get()); + m_samplers[bindingIndex] = checked_cast(binding.resource); break; } diff --git a/src/vulkan/vk-shader-object.cpp b/src/vulkan/vk-shader-object.cpp index eacdd93..3ffa884 100644 --- a/src/vulkan/vk-shader-object.cpp +++ b/src/vulkan/vk-shader-object.cpp @@ -92,7 +92,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { case BindingType::Buffer: { - BufferImpl* buffer = checked_cast(binding.resource.get()); + BufferImpl* buffer = checked_cast(binding.resource); slot.type = BindingType::Buffer; slot.resource = buffer; slot.format = slot.format != Format::Unknown ? slot.format : buffer->m_desc.format; @@ -112,13 +112,13 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: { slot.type = BindingType::TextureView; - slot.resource = checked_cast(binding.resource.get()); + slot.resource = checked_cast(binding.resource); switch (bindingRange.bindingType) { case slang::BindingType::Texture: @@ -131,20 +131,24 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) break; } case BindingType::Sampler: - m_samplers[bindingIndex] = checked_cast(binding.resource.get()); + m_samplers[bindingIndex] = checked_cast(binding.resource); break; case BindingType::CombinedTextureSampler: { - TextureImpl* texture = checked_cast(binding.resource.get()); - m_combinedTextureSamplers[bindingIndex] = CombinedTextureSamplerSlot{ - checked_cast(m_device->createTextureView(texture, {}).get()), - checked_cast(binding.resource2.get()) - }; + TextureImpl* texture = checked_cast(binding.resource); + SamplerImpl* sampler = checked_cast(binding.resource2); + return setBinding(offset, Binding(m_device->createTextureView(texture, {}), sampler)); + } + case BindingType::CombinedTextureViewSampler: + { + TextureViewImpl* textureView = checked_cast(binding.resource); + SamplerImpl* sampler = checked_cast(binding.resource2); + m_combinedTextureSamplers[bindingIndex] = CombinedTextureSamplerSlot{textureView, sampler}; break; } case BindingType::AccelerationStructure: slot.type = BindingType::AccelerationStructure; - slot.resource = checked_cast(binding.resource.get()); + slot.resource = checked_cast(binding.resource); break; } diff --git a/src/wgpu/wgpu-shader-object.cpp b/src/wgpu/wgpu-shader-object.cpp index 0eebd31..64fe60f 100644 --- a/src/wgpu/wgpu-shader-object.cpp +++ b/src/wgpu/wgpu-shader-object.cpp @@ -88,7 +88,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) { case BindingType::Buffer: { - BufferImpl* buffer = checked_cast(binding.resource.get()); + BufferImpl* buffer = checked_cast(binding.resource); ResourceSlot slot; slot.type = BindingType::Buffer; slot.resource = buffer; @@ -99,19 +99,19 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding) } case BindingType::Texture: { - TextureImpl* texture = checked_cast(binding.resource.get()); + TextureImpl* texture = checked_cast(binding.resource); return setBinding(offset, m_device->createTextureView(texture, {})); } case BindingType::TextureView: { ResourceSlot slot; slot.type = BindingType::TextureView; - slot.resource = checked_cast(binding.resource.get()); + slot.resource = checked_cast(binding.resource); m_resources[bindingIndex] = slot; break; } case BindingType::Sampler: - m_samplers[bindingIndex] = checked_cast(binding.resource.get()); + m_samplers[bindingIndex] = checked_cast(binding.resource); break; }