Skip to content

Commit

Permalink
Implement callable shader
Browse files Browse the repository at this point in the history
Signed-off-by: kevyuu <[email protected]>
  • Loading branch information
kevyuu committed Jan 22, 2025
1 parent ebd31ef commit bcbd729
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 41 deletions.
12 changes: 6 additions & 6 deletions 71_RayTracingPipeline/app_resources/common.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,24 @@ struct SPushConstants

struct RayLight
{
float32_t3 inHitPosition;
float32_t3 inHitPosition;
float32_t outLightDistance;
float32_t3 outLightDir;
float32_t3 outLightDir;
float32_t outIntensity;
};

#ifdef __HLSL_VERSION

struct [raypayload] ColorPayload
{
float32_t3 hitValue;
uint32_t seed;
float32_t3 hitValue : read(caller) : write(closesthit,miss);
uint32_t seed : read(closesthit,anyhit) : write(caller);
};

struct [raypayload] ShadowPayload
{
bool isShadowed;
uint32_t seed;
bool isShadowed : read(caller) : write(caller,miss);
uint32_t seed : read(anyhit) : write(caller);
};

enum ObjectType : uint32_t // matches c++
Expand Down
16 changes: 16 additions & 0 deletions 71_RayTracingPipeline/app_resources/lgiht_spot.rcall.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "common.hlsl"

[[vk::push_constant]] SPushConstants pc;

[shader("callable")]
void main(inout RayLight cLight)
{
float32_t3 lDir = pc.light.position - cLight.inHitPosition;
cLight.outLightDistance = length(lDir);
cLight.outIntensity = pc.light.intensity / (cLight.outLightDistance * cLight.outLightDistance);
cLight.outLightDir = normalize(lDir);
float theta = dot(cLight.outLightDir, normalize(-pc.light.direction));
float epsilon = pc.light.innerCutoff - pc.light.outerCutoff;
float spotIntensity = clamp((theta - pc.light.outerCutoff) / epsilon, 0.0, 1.0);
cLight.outIntensity *= spotIntensity;
}
11 changes: 11 additions & 0 deletions 71_RayTracingPipeline/app_resources/light_directional.rcall.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "common.hlsl"

[[vk::push_constant]] SPushConstants pc;

[shader("callable")]
void main(inout RayLight cLight)
{
cLight.outLightDir = normalize(-pc.light.direction);
cLight.outIntensity = 1.0;
cLight.outLightDistance = 10000000;
}
13 changes: 13 additions & 0 deletions 71_RayTracingPipeline/app_resources/light_point.rcall.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "common.hlsl"

[[vk::push_constant]] SPushConstants pc;

[shader("callable")]
void main(inout RayLight cLight)
{
float32_t3 lDir = pc.light.position - cLight.inHitPosition;
float lightDistance = length(lDir);
cLight.outIntensity = pc.light.intensity / (lightDistance * lightDistance);
cLight.outLightDir = normalize(lDir);
cLight.outLightDistance = lightDistance;
}
3 changes: 2 additions & 1 deletion 71_RayTracingPipeline/app_resources/raytrace.rahit.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ void main(inout AnyHitPayload p, in BuiltInTriangleIntersectionAttributes attrib
if (geom.material.illum != 4)
return;

uint32_t seed = p.seed;
if (geom.material.dissolve == 0.0)
IgnoreHit();
else if (rnd(p.seed) > geom.material.dissolve)
else if (rnd(seed) > geom.material.dissolve)
IgnoreHit();
}
30 changes: 3 additions & 27 deletions 71_RayTracingPipeline/app_resources/raytrace.rchit.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -123,31 +123,7 @@ void main(inout ColorPayload p, in BuiltInTriangleIntersectionAttributes attribs

RayLight cLight;
cLight.inHitPosition = worldPosition;
if (pc.light.type == 0)
{
cLight.outLightDir = normalize(-pc.light.direction);
cLight.outIntensity = 1.0;
cLight.outLightDistance = 10000000;
}
if (pc.light.type == 1)
{
float32_t3 lDir = pc.light.position - cLight.inHitPosition;
float lightDistance = length(lDir);
cLight.outIntensity = pc.light.intensity / (lightDistance * lightDistance);
cLight.outLightDir = normalize(lDir);
cLight.outLightDistance = lightDistance;
}
else if (pc.light.type == 2)
{
float32_t3 lDir = pc.light.position - cLight.inHitPosition;
cLight.outLightDistance = length(lDir);
cLight.outIntensity = pc.light.intensity / (cLight.outLightDistance * cLight.outLightDistance);
cLight.outLightDir = normalize(lDir);
float theta = dot(cLight.outLightDir, normalize(-pc.light.direction));
float epsilon = pc.light.innerCutoff - pc.light.outerCutoff;
float spotIntensity = clamp((theta - pc.light.outerCutoff) / epsilon, 0.0, 1.0);
cLight.outIntensity *= spotIntensity;
}
CallShader(pc.light.type, cLight);

float32_t3 diffuse = computeDiffuse(geom.material, cLight.outLightDir, worldNormal);
float32_t3 specular = float32_t3(0, 0, 0);
Expand All @@ -166,9 +142,9 @@ void main(inout ColorPayload p, in BuiltInTriangleIntersectionAttributes attribs
shadowPayload.isShadowed = true;
shadowPayload.seed = p.seed;
TraceRay(topLevelAS, flags, 0xFF, 1, 0, 1, rayDesc, shadowPayload);
p.seed = shadowPayload.seed;

if (shadowPayload.isShadowed)
bool isShadowed = shadowPayload.isShadowed;
if (isShadowed)
{
attenuation = 0.3;
}
Expand Down
1 change: 0 additions & 1 deletion 71_RayTracingPipeline/app_resources/raytrace.rgen.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ void main()

ColorPayload payload;
payload.seed = seed;
payload.hitValue = float32_t3(0, 0, 0);
TraceRay(topLevelAS, RAY_FLAG_NONE, 0xff, 0, 0, 0, rayDesc, payload);

hitValues += payload.hitValue;
Expand Down
21 changes: 15 additions & 6 deletions 71_RayTracingPipeline/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class RaytracingPipelineApp final : public examples::SimpleWindowedApplication,
const auto anyHitShaderShadowPayload = compileShader("app_resources/raytrace.rahit.hlsl", "#define USE_SHADOW_PAYLOAD\n");
const auto missShader = compileShader("app_resources/raytrace.rmiss.hlsl");
const auto shadowMissShader = compileShader("app_resources/raytraceShadow.rmiss.hlsl");
const auto directionalLightCallShader = compileShader("app_resources/light_directional.rcall.hlsl");
const auto pointLightCallShader = compileShader("app_resources/light_point.rcall.hlsl");
const auto spotLightCallShader = compileShader("app_resources/light_spot.rcall.hlsl");

m_semaphore = m_device->createSemaphore(m_realFrameIx);
if (!m_semaphore)
Expand Down Expand Up @@ -275,22 +278,28 @@ class RaytracingPipelineApp final : public examples::SimpleWindowedApplication,

const IGPUShader::SSpecInfo shaders[] = {
{.shader = raygenShader.get()},
{.shader = missShader.get()},
{.shader = shadowMissShader.get()},
{.shader = closestHitShader.get()},
{.shader = anyHitShaderColorPayload.get()},
{.shader = anyHitShaderShadowPayload.get()},
{.shader = missShader.get()},
{.shader = shadowMissShader.get()},
{.shader = directionalLightCallShader.get()},
{.shader = pointLightCallShader.get()},
{.shader = spotLightCallShader.get()},
};

params.layout = pipelineLayout.get();
params.shaders = std::span(shaders, std::size(shaders));
params.cached.shaderGroups.raygenGroup = {
.shaderIndex = 0,
};
params.cached.shaderGroups.hitGroups.push_back({ .closestHitShaderIndex = 1, .anyHitShaderIndex = 2 });
params.cached.shaderGroups.hitGroups.push_back({ .closestHitShaderIndex = 1, .anyHitShaderIndex = 3 });
params.cached.shaderGroups.missGroups.push_back({ .shaderIndex = 4 });
params.cached.shaderGroups.missGroups.push_back({ .shaderIndex = 5 });
params.cached.shaderGroups.missGroups.push_back({ .shaderIndex = 1 });
params.cached.shaderGroups.missGroups.push_back({ .shaderIndex = 2 });
params.cached.shaderGroups.hitGroups.push_back({ .closestHitShaderIndex = 3, .anyHitShaderIndex = 4 });
params.cached.shaderGroups.hitGroups.push_back({ .closestHitShaderIndex = 3, .anyHitShaderIndex = 5 });
params.cached.shaderGroups.callableGroups.push_back({.shaderIndex = 6});
params.cached.shaderGroups.callableGroups.push_back({.shaderIndex = 7});
params.cached.shaderGroups.callableGroups.push_back({.shaderIndex = 8});
params.cached.maxRecursionDepth = 2;
if (!m_device->createRayTracingPipelines(nullptr, { &params, 1 }, &m_rayTracingPipeline))
return logFail("Failed to create ray tracing pipeline");
Expand Down

0 comments on commit bcbd729

Please sign in to comment.