Skip to content

Commit

Permalink
move acceleration structure utils (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
skallweitNV authored Oct 21, 2024
1 parent 801865f commit 6386188
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 350 deletions.
120 changes: 0 additions & 120 deletions src/d3d/d3d-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1098,124 +1098,4 @@ Result D3DUtil::findAdapters(
return SLANG_OK;
}

#if SLANG_RHI_DXR
Result D3DAccelerationStructureInputsBuilder::build(
const AccelerationStructureBuildDesc& buildDesc,
IDebugCallback* callback
)
{
if (buildDesc.inputCount < 1)
{
return SLANG_E_INVALID_ARG;
}

AccelerationStructureBuildInputType type = (AccelerationStructureBuildInputType&)buildDesc.inputs[0];
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
if ((AccelerationStructureBuildInputType&)buildDesc.inputs[i] != type)
{
return SLANG_E_INVALID_ARG;
}
}

desc.Flags = translateBuildFlags(buildDesc.flags);
switch (buildDesc.mode)
{
case AccelerationStructureBuildMode::Build:
desc.Flags |= D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_UPDATE;
break;
case AccelerationStructureBuildMode::Update:
desc.Flags |= D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PERFORM_UPDATE;
break;
default:
return SLANG_E_INVALID_ARG;
}

switch (type)
{
case AccelerationStructureBuildInputType::Instances:
{
if (buildDesc.inputCount > 1)
{
return SLANG_E_INVALID_ARG;
}
const AccelerationStructureBuildInputInstances& instances =
(const AccelerationStructureBuildInputInstances&)buildDesc.inputs[0];
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL;
desc.NumDescs = 1;
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.InstanceDescs = instances.instanceBuffer.getDeviceAddress();
break;
}
case AccelerationStructureBuildInputType::Triangles:
{
geomDescs.resize(buildDesc.inputCount);
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
const AccelerationStructureBuildInputTriangles& triangles =
(const AccelerationStructureBuildInputTriangles&)buildDesc.inputs[i];
if (triangles.vertexBufferCount != 1)
{
return SLANG_E_INVALID_ARG;
}
D3D12_RAYTRACING_GEOMETRY_DESC& geomDesc = geomDescs[i];
geomDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
geomDesc.Flags = translateGeometryFlags(triangles.flags);
geomDesc.Triangles.VertexBuffer.StartAddress = triangles.vertexBuffers[0].getDeviceAddress();
geomDesc.Triangles.VertexBuffer.StrideInBytes = triangles.vertexStride;
geomDesc.Triangles.VertexCount = triangles.vertexCount;
geomDesc.Triangles.VertexFormat = D3DUtil::getMapFormat(triangles.vertexFormat);
if (triangles.indexBuffer)
{
geomDesc.Triangles.IndexBuffer = triangles.indexBuffer.getDeviceAddress();
geomDesc.Triangles.IndexCount = triangles.indexCount;
geomDesc.Triangles.IndexFormat = D3DUtil::getIndexFormat(triangles.indexFormat);
}
else
{
geomDesc.Triangles.IndexBuffer = 0;
geomDesc.Triangles.IndexCount = 0;
geomDesc.Triangles.IndexFormat = DXGI_FORMAT_UNKNOWN;
}
geomDesc.Triangles.Transform3x4 =
triangles.preTransformBuffer ? triangles.preTransformBuffer.getDeviceAddress() : 0;
}
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL;
desc.NumDescs = geomDescs.size();
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.pGeometryDescs = geomDescs.data();
break;
}
case AccelerationStructureBuildInputType::ProceduralPrimitives:
{
geomDescs.resize(buildDesc.inputCount);
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
const AccelerationStructureBuildInputProceduralPrimitives& proceduralPrimitives =
(const AccelerationStructureBuildInputProceduralPrimitives&)buildDesc.inputs[i];
if (proceduralPrimitives.aabbBufferCount != 1)
{
return SLANG_E_INVALID_ARG;
}
D3D12_RAYTRACING_GEOMETRY_DESC& geomDesc = geomDescs[i];
geomDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS;
geomDesc.Flags = translateGeometryFlags(proceduralPrimitives.flags);
geomDesc.AABBs.AABBCount = proceduralPrimitives.primitiveCount;
geomDesc.AABBs.AABBs.StartAddress = proceduralPrimitives.aabbBuffers[0].getDeviceAddress();
geomDesc.AABBs.AABBs.StrideInBytes = proceduralPrimitives.aabbStride;
}
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL;
desc.NumDescs = geomDescs.size();
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.pGeometryDescs = geomDescs.data();
break;
}
default:
return SLANG_E_INVALID_ARG;
}

return SLANG_OK;
}
#endif

} // namespace rhi
48 changes: 0 additions & 48 deletions src/d3d/d3d-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,52 +149,4 @@ class D3DUtil
static Result waitForCrashDumpCompletion(HRESULT res);
};

#if SLANG_RHI_DXR
struct D3DAccelerationStructureInputsBuilder
{
D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS desc = {};
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {};
std::vector<D3D12_RAYTRACING_GEOMETRY_DESC> geomDescs;
Result build(const AccelerationStructureBuildDesc& buildDesc, IDebugCallback* callback);

private:
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS translateBuildFlags(AccelerationStructureBuildFlags flags)
{
static_assert(
uint32_t(AccelerationStructureBuildFlags::None) == D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_NONE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::AllowUpdate) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_UPDATE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::AllowCompaction) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_COMPACTION
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::PreferFastTrace) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::PreferFastBuild) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::MinimizeMemory) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_MINIMIZE_MEMORY
);
return (D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS)flags;
}
D3D12_RAYTRACING_GEOMETRY_FLAGS translateGeometryFlags(AccelerationStructureGeometryFlags flags)
{
static_assert(uint32_t(AccelerationStructureGeometryFlags::None) == D3D12_RAYTRACING_GEOMETRY_FLAG_NONE);
static_assert(uint32_t(AccelerationStructureGeometryFlags::Opaque) == D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE);
static_assert(
uint32_t(AccelerationStructureGeometryFlags::NoDuplicateAnyHitInvocation) ==
D3D12_RAYTRACING_GEOMETRY_FLAG_NO_DUPLICATE_ANYHIT_INVOCATION
);
return (D3D12_RAYTRACING_GEOMETRY_FLAGS)flags;
}
};
#endif
} // namespace rhi
118 changes: 118 additions & 0 deletions src/d3d12/d3d12-acceleration-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,124 @@ DeviceAddress AccelerationStructureImpl::getDeviceAddress()
return m_buffer->getDeviceAddress();
}

Result AccelerationStructureInputsBuilder::build(
const AccelerationStructureBuildDesc& buildDesc,
IDebugCallback* callback
)
{
if (buildDesc.inputCount < 1)
{
return SLANG_E_INVALID_ARG;
}

AccelerationStructureBuildInputType type = (AccelerationStructureBuildInputType&)buildDesc.inputs[0];
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
if ((AccelerationStructureBuildInputType&)buildDesc.inputs[i] != type)
{
return SLANG_E_INVALID_ARG;
}
}

desc.Flags = translateBuildFlags(buildDesc.flags);
switch (buildDesc.mode)
{
case AccelerationStructureBuildMode::Build:
desc.Flags |= D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_UPDATE;
break;
case AccelerationStructureBuildMode::Update:
desc.Flags |= D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PERFORM_UPDATE;
break;
default:
return SLANG_E_INVALID_ARG;
}

switch (type)
{
case AccelerationStructureBuildInputType::Instances:
{
if (buildDesc.inputCount > 1)
{
return SLANG_E_INVALID_ARG;
}
const AccelerationStructureBuildInputInstances& instances =
(const AccelerationStructureBuildInputInstances&)buildDesc.inputs[0];
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL;
desc.NumDescs = 1;
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.InstanceDescs = instances.instanceBuffer.getDeviceAddress();
break;
}
case AccelerationStructureBuildInputType::Triangles:
{
geomDescs.resize(buildDesc.inputCount);
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
const AccelerationStructureBuildInputTriangles& triangles =
(const AccelerationStructureBuildInputTriangles&)buildDesc.inputs[i];
if (triangles.vertexBufferCount != 1)
{
return SLANG_E_INVALID_ARG;
}
D3D12_RAYTRACING_GEOMETRY_DESC& geomDesc = geomDescs[i];
geomDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
geomDesc.Flags = translateGeometryFlags(triangles.flags);
geomDesc.Triangles.VertexBuffer.StartAddress = triangles.vertexBuffers[0].getDeviceAddress();
geomDesc.Triangles.VertexBuffer.StrideInBytes = triangles.vertexStride;
geomDesc.Triangles.VertexCount = triangles.vertexCount;
geomDesc.Triangles.VertexFormat = D3DUtil::getMapFormat(triangles.vertexFormat);
if (triangles.indexBuffer)
{
geomDesc.Triangles.IndexBuffer = triangles.indexBuffer.getDeviceAddress();
geomDesc.Triangles.IndexCount = triangles.indexCount;
geomDesc.Triangles.IndexFormat = D3DUtil::getIndexFormat(triangles.indexFormat);
}
else
{
geomDesc.Triangles.IndexBuffer = 0;
geomDesc.Triangles.IndexCount = 0;
geomDesc.Triangles.IndexFormat = DXGI_FORMAT_UNKNOWN;
}
geomDesc.Triangles.Transform3x4 =
triangles.preTransformBuffer ? triangles.preTransformBuffer.getDeviceAddress() : 0;
}
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL;
desc.NumDescs = geomDescs.size();
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.pGeometryDescs = geomDescs.data();
break;
}
case AccelerationStructureBuildInputType::ProceduralPrimitives:
{
geomDescs.resize(buildDesc.inputCount);
for (GfxIndex i = 0; i < buildDesc.inputCount; ++i)
{
const AccelerationStructureBuildInputProceduralPrimitives& proceduralPrimitives =
(const AccelerationStructureBuildInputProceduralPrimitives&)buildDesc.inputs[i];
if (proceduralPrimitives.aabbBufferCount != 1)
{
return SLANG_E_INVALID_ARG;
}
D3D12_RAYTRACING_GEOMETRY_DESC& geomDesc = geomDescs[i];
geomDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS;
geomDesc.Flags = translateGeometryFlags(proceduralPrimitives.flags);
geomDesc.AABBs.AABBCount = proceduralPrimitives.primitiveCount;
geomDesc.AABBs.AABBs.StartAddress = proceduralPrimitives.aabbBuffers[0].getDeviceAddress();
geomDesc.AABBs.AABBs.StrideInBytes = proceduralPrimitives.aabbStride;
}
desc.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL;
desc.NumDescs = geomDescs.size();
desc.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
desc.pGeometryDescs = geomDescs.data();
break;
}
default:
return SLANG_E_INVALID_ARG;
}

return SLANG_OK;
}

#endif // SLANG_RHI_DXR

} // namespace rhi::d3d12
47 changes: 47 additions & 0 deletions src/d3d12/d3d12-acceleration-structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,53 @@ class AccelerationStructureImpl : public AccelerationStructure
virtual SLANG_NO_THROW DeviceAddress SLANG_MCALL getDeviceAddress() override;
};

struct AccelerationStructureInputsBuilder
{
D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS desc = {};
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo = {};
std::vector<D3D12_RAYTRACING_GEOMETRY_DESC> geomDescs;
Result build(const AccelerationStructureBuildDesc& buildDesc, IDebugCallback* callback);

private:
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS translateBuildFlags(AccelerationStructureBuildFlags flags)
{
static_assert(
uint32_t(AccelerationStructureBuildFlags::None) == D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_NONE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::AllowUpdate) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_UPDATE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::AllowCompaction) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_COMPACTION
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::PreferFastTrace) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::PreferFastBuild) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD
);
static_assert(
uint32_t(AccelerationStructureBuildFlags::MinimizeMemory) ==
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_MINIMIZE_MEMORY
);
return (D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS)flags;
}
D3D12_RAYTRACING_GEOMETRY_FLAGS translateGeometryFlags(AccelerationStructureGeometryFlags flags)
{
static_assert(uint32_t(AccelerationStructureGeometryFlags::None) == D3D12_RAYTRACING_GEOMETRY_FLAG_NONE);
static_assert(uint32_t(AccelerationStructureGeometryFlags::Opaque) == D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE);
static_assert(
uint32_t(AccelerationStructureGeometryFlags::NoDuplicateAnyHitInvocation) ==
D3D12_RAYTRACING_GEOMETRY_FLAG_NO_DUPLICATE_ANYHIT_INVOCATION
);
return (D3D12_RAYTRACING_GEOMETRY_FLAGS)flags;
}
};

#endif // SLANG_RHI_DXR

} // namespace rhi::d3d12
2 changes: 1 addition & 1 deletion src/d3d12/d3d12-command-encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ void RayTracingPassEncoderImpl::buildAccelerationStructure(
buildDesc.DestAccelerationStructureData = dstImpl->getDeviceAddress();
buildDesc.SourceAccelerationStructureData = srcImpl ? srcImpl->getDeviceAddress() : 0;
buildDesc.ScratchAccelerationStructureData = scratchBuffer.buffer->getDeviceAddress() + scratchBuffer.offset;
D3DAccelerationStructureInputsBuilder builder;
AccelerationStructureInputsBuilder builder;
builder.build(desc, m_device->m_debugCallback);
buildDesc.Inputs = builder.desc;

Expand Down
2 changes: 1 addition & 1 deletion src/d3d12/d3d12-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ Result DeviceImpl::getAccelerationStructureSizes(
if (!m_device5)
return SLANG_E_NOT_AVAILABLE;

D3DAccelerationStructureInputsBuilder inputsBuilder;
AccelerationStructureInputsBuilder inputsBuilder;
SLANG_RETURN_ON_FAIL(inputsBuilder.build(desc, m_debugCallback));

D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO prebuildInfo;
Expand Down
Loading

0 comments on commit 6386188

Please sign in to comment.