Skip to content

Commit

Permalink
metal fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
skallweitNV committed Sep 2, 2024
1 parent 1aaa02b commit 5fe9c33
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 148 deletions.
10 changes: 10 additions & 0 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@
}
],
"buildPresets": [
{
"name": "debug",
"configurePreset": "default",
"configuration": "Debug"
},
{
"name": "release",
"configurePreset": "default",
"configuration": "Release"
},
{
"name": "msvc-debug",
"configurePreset": "msvc",
Expand Down
3 changes: 2 additions & 1 deletion src/metal/metal-base.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class AccelerationStructureImpl;
class FramebufferLayoutImpl;
class RenderPassLayoutImpl;
class FramebufferImpl;
class PipelineImpl;
class RenderPipelineImpl;
class ComputePipelineImpl;
class RayTracingPipelineImpl;
class ShaderObjectLayoutImpl;
class EntryPointLayout;
Expand Down
49 changes: 16 additions & 33 deletions src/metal/metal-command-encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ void PipelineCommandEncoder::endEncodingImpl()
m_commandBuffer->endMetalCommandEncoder();
}

Result PipelineCommandEncoder::setPipelineImpl(IPipeline* pipeline, IShaderObject** outRootObject)
Result PipelineCommandEncoder::setPipelineImpl(PipelineBase* pipeline, IShaderObject** outRootObject)
{
m_currentPipeline = static_cast<PipelineImpl*>(pipeline);
m_currentPipeline = pipeline;
// m_commandBuffer->m_mutableRootShaderObject = nullptr;
SLANG_RETURN_ON_FAIL(m_commandBuffer->m_rootObject.init(
m_commandBuffer->m_device,
Expand Down Expand Up @@ -277,7 +277,7 @@ void RenderCommandEncoder::endEncoding()

Result RenderCommandEncoder::bindPipeline(IRenderPipeline* pipeline, IShaderObject** outRootObject)
{
return setPipelineImpl(pipeline, outRootObject);
return setPipelineImpl(static_cast<RenderPipelineImpl*>(pipeline), outRootObject);
}

Result RenderCommandEncoder::bindPipelineWithRootObject(IRenderPipeline* pipeline, IShaderObject* rootObject)
Expand Down Expand Up @@ -373,8 +373,7 @@ Result RenderCommandEncoder::setSamplePositions(

Result RenderCommandEncoder::prepareDraw(MTL::RenderCommandEncoder*& encoder)
{
auto pipeline = static_cast<PipelineImpl*>(m_currentPipeline.Ptr());
pipeline->ensureAPIPipelineCreated();
auto pipeline = static_cast<RenderPipelineImpl*>(m_currentPipeline.get());

encoder = m_commandBuffer->getMetalRenderCommandEncoder(m_renderPassDesc.get());
encoder->setRenderPipelineState(pipeline->m_renderPipelineState.get());
Expand All @@ -386,25 +385,21 @@ Result RenderCommandEncoder::prepareDraw(MTL::RenderCommandEncoder*& encoder)

for (Index i = 0; i < m_vertexBuffers.size(); ++i)
{
encoder->setVertexBuffer(
m_vertexBuffers[i],
m_vertexBufferOffsets[i],
m_currentPipeline->m_vertexBufferOffset + i
);
encoder->setVertexBuffer(m_vertexBuffers[i], m_vertexBufferOffsets[i], pipeline->m_vertexBufferOffset + i);
}

encoder->setViewports(m_viewports.data(), m_viewports.size());
encoder->setScissorRects(m_scissorRects.data(), m_scissorRects.size());

const RasterizerDesc& rasterDesc = pipeline->desc.graphics.rasterizer;
const DepthStencilDesc& depthStencilDesc = pipeline->desc.graphics.depthStencil;
encoder->setFrontFacingWinding(MetalUtil::translateWinding(rasterDesc.frontFace));
encoder->setCullMode(MetalUtil::translateCullMode(rasterDesc.cullMode));
const RasterizerDesc& rasterizerDesc = pipeline->m_rasterizerDesc;
const DepthStencilDesc& depthStencilDesc = pipeline->m_depthStencilDesc;
encoder->setFrontFacingWinding(MetalUtil::translateWinding(rasterizerDesc.frontFace));
encoder->setCullMode(MetalUtil::translateCullMode(rasterizerDesc.cullMode));
encoder->setDepthClipMode(
rasterDesc.depthClipEnable ? MTL::DepthClipModeClip : MTL::DepthClipModeClamp
rasterizerDesc.depthClipEnable ? MTL::DepthClipModeClip : MTL::DepthClipModeClamp
); // TODO correct?
encoder->setDepthBias(rasterDesc.depthBias, rasterDesc.slopeScaledDepthBias, rasterDesc.depthBiasClamp);
encoder->setTriangleFillMode(MetalUtil::translateTriangleFillMode(rasterDesc.fillMode));
encoder->setDepthBias(rasterizerDesc.depthBias, rasterizerDesc.slopeScaledDepthBias, rasterizerDesc.depthBiasClamp);
encoder->setTriangleFillMode(MetalUtil::translateTriangleFillMode(rasterizerDesc.fillMode));
// encoder->setBlendColor(); // not supported by rhi
if (m_framebuffer->m_depthStencilView)
{
Expand Down Expand Up @@ -502,7 +497,7 @@ void ComputeCommandEncoder::endEncoding()

Result ComputeCommandEncoder::bindPipeline(IComputePipeline* pipeline, IShaderObject** outRootObject)
{
return setPipelineImpl(pipeline, outRootObject);
return setPipelineImpl(static_cast<ComputePipelineImpl*>(pipeline), outRootObject);
}

Result ComputeCommandEncoder::bindPipelineWithRootObject(IComputePipeline* pipeline, IShaderObject* rootObject)
Expand All @@ -519,21 +514,9 @@ Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z)
auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get());
m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout);

auto pipeline = static_cast<PipelineImpl*>(m_currentPipeline.Ptr());
RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject;
RefPtr<PipelineBase> newPipeline;
SLANG_RETURN_ON_FAIL(
m_commandBuffer->m_device->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline)
);
PipelineImpl* newPipelineImpl = static_cast<PipelineImpl*>(newPipeline.Ptr());

SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineCreated());
m_currentPipeline = newPipelineImpl;

m_currentPipeline->ensureAPIPipelineCreated();
encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get());

encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize);
auto pipeline = static_cast<ComputePipelineImpl*>(m_currentPipeline.get());
encoder->setComputePipelineState(pipeline->m_computePipelineState.get());
encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize);

return SLANG_OK;
}
Expand Down
4 changes: 2 additions & 2 deletions src/metal/metal-command-encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class PipelineCommandEncoder : public ComObject
public:
CommandBufferImpl* m_commandBuffer;
MTL::CommandBuffer* m_metalCommandBuffer;
RefPtr<PipelineImpl> m_currentPipeline;
RefPtr<PipelineBase> m_currentPipeline;

void init(CommandBufferImpl* commandBuffer);
void endEncodingImpl();

Result setPipelineImpl(IPipeline* pipeline, IShaderObject** outRootObject);
Result setPipelineImpl(PipelineBase* pipeline, IShaderObject** outRootObject);
};

class ResourceCommandEncoder : public IResourceCommandEncoder, public PipelineCommandEncoder
Expand Down
9 changes: 4 additions & 5 deletions src/metal/metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,8 @@ Result DeviceImpl::createRenderPipeline(const RenderPipelineDesc& desc, IRenderP
{
AUTORELEASEPOOL

RefPtr<PipelineImpl> pipelineImpl = new PipelineImpl(this);
pipelineImpl->init(desc);
RefPtr<RenderPipelineImpl> pipelineImpl = new RenderPipelineImpl(this);
SLANG_RETURN_ON_FAIL(pipelineImpl->init(desc));
returnComPtr(outPipeline, pipelineImpl);
return SLANG_OK;
}
Expand All @@ -773,9 +773,8 @@ Result DeviceImpl::createComputePipeline(const ComputePipelineDesc& desc, ICompu
{
AUTORELEASEPOOL

RefPtr<PipelineImpl> pipelineImpl = new PipelineImpl(this);
pipelineImpl->init(desc);
m_deviceObjectsWithPotentialBackReferences.push_back(pipelineImpl);
RefPtr<ComputePipelineImpl> pipelineImpl = new ComputePipelineImpl(this);
SLANG_RETURN_ON_FAIL(pipelineImpl->init(desc));
returnComPtr(outPipeline, pipelineImpl);
return SLANG_OK;
}
Expand Down
122 changes: 41 additions & 81 deletions src/metal/metal-pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,25 @@

namespace rhi::metal {

PipelineImpl::PipelineImpl(DeviceImpl* device)
RenderPipelineImpl::RenderPipelineImpl(DeviceImpl* device)
: m_device(device)
{
}

PipelineImpl::~PipelineImpl() {}
RenderPipelineImpl::~RenderPipelineImpl() {}

void PipelineImpl::init(const RenderPipelineDesc& desc)
Result RenderPipelineImpl::init(const RenderPipelineDesc& desc)
{
PipelineStateDesc pipelineDesc;
pipelineDesc.type = PipelineType::Graphics;
pipelineDesc.graphics = desc;
initializeBase(pipelineDesc);
}

void PipelineImpl::init(const ComputePipelineDesc& desc)
{
PipelineStateDesc pipelineDesc;
pipelineDesc.type = PipelineType::Compute;
pipelineDesc.compute = desc;
initializeBase(pipelineDesc);
}
SLANG_RETURN_ON_FAIL(RenderPipelineBase::init(desc));

void PipelineImpl::init(const RayTracingPipelineDesc& desc)
{
PipelineStateDesc pipelineDesc;
pipelineDesc.type = PipelineType::RayTracing;
pipelineDesc.rayTracing.set(desc);
initializeBase(pipelineDesc);
}
m_rasterizerDesc = desc.rasterizer;
m_depthStencilDesc = desc.depthStencil;

Result PipelineImpl::createMetalRenderPipelineState()
{
auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
if (!programImpl)
return SLANG_FAIL;
auto programImpl = static_cast<ShaderProgramImpl*>(m_program.get());
if (programImpl->m_modules.empty())
{
SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device));
}

NS::SharedPtr<MTL::RenderPipelineDescriptor> pd = NS::TransferPtr(MTL::RenderPipelineDescriptor::alloc()->init());

Expand Down Expand Up @@ -70,15 +53,15 @@ Result PipelineImpl::createMetalRenderPipelineState()
// They need to be in a range not used by any buffers in the root object layout.
// The +1 is to account for a potential constant buffer at index 0.
m_vertexBufferOffset = programImpl->m_rootObjectLayout->getBufferCount() + 1;
auto inputLayoutImpl = static_cast<InputLayoutImpl*>(desc.graphics.inputLayout);
auto inputLayoutImpl = static_cast<InputLayoutImpl*>(desc.inputLayout);
NS::SharedPtr<MTL::VertexDescriptor> vertexDescriptor =
inputLayoutImpl->createVertexDescriptor(m_vertexBufferOffset);
pd->setVertexDescriptor(vertexDescriptor.get());
pd->setInputPrimitiveTopology(MetalUtil::translatePrimitiveTopologyClass(desc.graphics.primitiveType));
pd->setInputPrimitiveTopology(MetalUtil::translatePrimitiveTopologyClass(desc.primitiveType));

// Set rasterization state
auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.graphics.framebufferLayout);
const auto& blend = desc.graphics.blend;
auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout);
const auto& blend = desc.blend;
GfxCount sampleCount = 1;

pd->setAlphaToCoverageEnabled(blend.alphaToCoverageEnable);
Expand Down Expand Up @@ -147,7 +130,7 @@ Result PipelineImpl::createMetalRenderPipelineState()
return stencilDesc;
};

const auto& depthStencil = desc.graphics.depthStencil;
const auto& depthStencil = desc.depthStencil;
NS::SharedPtr<MTL::DepthStencilDescriptor> depthStencilDesc =
NS::TransferPtr(MTL::DepthStencilDescriptor::alloc()->init());
m_depthStencilState = NS::TransferPtr(m_device->m_device->newDepthStencilState(depthStencilDesc.get()));
Expand All @@ -173,11 +156,29 @@ Result PipelineImpl::createMetalRenderPipelineState()
return SLANG_OK;
}

Result PipelineImpl::createMetalComputePipelineState()
Result RenderPipelineImpl::getNativeHandle(NativeHandle* outHandle)
{
auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
if (!programImpl)
return SLANG_FAIL;
outHandle->type = NativeHandleType::MTLRenderPipelineState;
outHandle->value = (uint64_t)m_renderPipelineState.get();
return SLANG_OK;
}

/// Creates an uninitialized compute pipeline bound to `device`.
/// Only records the device in `m_device`; nothing else is set up here.
ComputePipelineImpl::ComputePipelineImpl(DeviceImpl* device) : m_device(device) {}

/// Nothing to release by hand: the destructor body is empty, so members clean up through their own destructors.
ComputePipelineImpl::~ComputePipelineImpl() = default;

Result ComputePipelineImpl::init(const ComputePipelineDesc& desc)
{
SLANG_RETURN_ON_FAIL(ComputePipelineBase::init(desc));

auto programImpl = static_cast<ShaderProgramImpl*>(m_program.get());
if (programImpl->m_modules.empty())
{
SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device));
}

const ShaderProgramImpl::Module& module = programImpl->m_modules[0];
auto functionName = MetalUtil::createString(module.entryPointName.data());
Expand All @@ -196,52 +197,11 @@ Result PipelineImpl::createMetalComputePipelineState()
return m_computePipelineState ? SLANG_OK : SLANG_FAIL;
}

Result PipelineImpl::ensureAPIPipelineCreated()
Result ComputePipelineImpl::getNativeHandle(NativeHandle* outHandle)
{
AUTORELEASEPOOL

switch (desc.type)
{
case PipelineType::Compute:
return m_computePipelineState ? SLANG_OK : createMetalComputePipelineState();
case PipelineType::Graphics:
return m_renderPipelineState ? SLANG_OK : createMetalRenderPipelineState();
default:
SLANG_RHI_UNREACHABLE("Unknown pipeline type.");
return SLANG_FAIL;
}
outHandle->type = NativeHandleType::MTLComputePipelineState;
outHandle->value = (uint64_t)m_computePipelineState.get();
return SLANG_OK;
}

/// Reports the Metal pipeline state object backing this pipeline.
///
/// @param outHandle Receives the handle type and the raw state-object pointer
///                  stored as a 64-bit integer. It is dereferenced without a
///                  null check, so callers must pass a valid pointer.
/// @return SLANG_OK for compute and graphics pipelines; SLANG_E_NOT_AVAILABLE
///         for any other pipeline type, in which case `outHandle` is not written.
SLANG_NO_THROW Result SLANG_MCALL PipelineImpl::getNativeHandle(NativeHandle* outHandle)
{
    const auto pipelineType = desc.type;

    if (pipelineType == PipelineType::Compute)
    {
        outHandle->type = NativeHandleType::MTLComputePipelineState;
        outHandle->value = reinterpret_cast<uint64_t>(m_computePipelineState.get());
        return SLANG_OK;
    }

    if (pipelineType == PipelineType::Graphics)
    {
        outHandle->type = NativeHandleType::MTLRenderPipelineState;
        outHandle->value = reinterpret_cast<uint64_t>(m_renderPipelineState.get());
        return SLANG_OK;
    }

    // No native handle is exposed for the remaining pipeline types.
    return SLANG_E_NOT_AVAILABLE;
}

/// Forwards `device` to the PipelineImpl base constructor; adds no state of its own here.
RayTracingPipelineImpl::RayTracingPipelineImpl(DeviceImpl* device) : PipelineImpl(device) {}

/// Stub: no API pipeline object is created for ray tracing pipelines.
/// @return Always SLANG_E_NOT_IMPLEMENTED.
Result RayTracingPipelineImpl::ensureAPIPipelineCreated()
{
    const Result notImplemented = SLANG_E_NOT_IMPLEMENTED;
    return notImplemented;
}

/// Stub: ray tracing pipelines expose no native handle.
/// @param outHandle Ignored; never written.
/// @return Always SLANG_E_NOT_IMPLEMENTED.
Result RayTracingPipelineImpl::getNativeHandle(NativeHandle* outHandle)
{
    static_cast<void>(outHandle); // intentionally unused
    return SLANG_E_NOT_IMPLEMENTED;
}

} // namespace rhi::metal
43 changes: 17 additions & 26 deletions src/metal/metal-pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,33 @@

namespace rhi::metal {

class PipelineImpl : public PipelineBase
class RenderPipelineImpl : public RenderPipelineBase
{
public:
DeviceImpl* m_device;
RenderPipelineImpl(DeviceImpl* device);
virtual ~RenderPipelineImpl() override;
Result init(const RenderPipelineDesc& desc);
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(NativeHandle* outHandle) override;

RefPtr<DeviceImpl> m_device;
RasterizerDesc m_rasterizerDesc;
DepthStencilDesc m_depthStencilDesc;
NS::SharedPtr<MTL::RenderPipelineState> m_renderPipelineState;
NS::SharedPtr<MTL::DepthStencilState> m_depthStencilState;
NS::SharedPtr<MTL::ComputePipelineState> m_computePipelineState;
MTL::Size m_threadGroupSize;
NS::UInteger m_vertexBufferOffset;

PipelineImpl(DeviceImpl* device);
~PipelineImpl();

void init(const RenderPipelineDesc& desc);
void init(const ComputePipelineDesc& desc);
void init(const RayTracingPipelineDesc& desc);

Result createMetalComputePipelineState();
Result createMetalRenderPipelineState();

virtual Result ensureAPIPipelineCreated() override;

virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(NativeHandle* outHandle) override;
};

class RayTracingPipelineImpl : public PipelineImpl
class ComputePipelineImpl : public ComputePipelineBase
{
public:
std::map<std::string, Index> shaderGroupNameToIndex;
Int shaderGroupCount;

RayTracingPipelineImpl(DeviceImpl* device);

virtual Result ensureAPIPipelineCreated() override;

ComputePipelineImpl(DeviceImpl* device);
virtual ~ComputePipelineImpl() override;
Result init(const ComputePipelineDesc& desc);
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(NativeHandle* outHandle) override;

RefPtr<DeviceImpl> m_device;
NS::SharedPtr<MTL::ComputePipelineState> m_computePipelineState;
MTL::Size m_threadGroupSize;
};

} // namespace rhi::metal

0 comments on commit 5fe9c33

Please sign in to comment.