Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
alelenv committed Jan 31, 2025
1 parent 1e6c84e commit 14e7ec7
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 46 deletions.
22 changes: 20 additions & 2 deletions source/val/validate_ray_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ namespace spvtools {
namespace val {
namespace {

uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
assert(array_type->opcode() == spv::Op::OpTypeArray);
uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
Instruction* array_length_inst = _.FindDef(const_int_id);
assert(array_length_inst->opcode() == spv::Op::OpConstant);
uint32_t array_length = array_length_inst->GetOperandAs<uint32_t>(2);
return array_length;
}

spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
const Instruction* inst,
uint32_t ray_query_index) {
Expand Down Expand Up @@ -265,6 +274,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {

case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;

if (!_.IsFloatVectorType(result_type) ||
_.GetDimension(result_type) != 3 ||
Expand All @@ -279,21 +289,27 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {

case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;

if ((_.FindDef(result_type)->opcode() != spv::Op::OpTypeArray) ||
auto result_id = _.FindDef(result_type);
if ((result_id->opcode() != spv::Op::OpTypeArray) ||
(GetArrayLength(_, result_id) != 2) ||
!_.IsFloatVectorType(_.GetComponentType(result_type)) ||
_.GetDimension(_.GetComponentType(result_type)) != 3) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 2 element array of 32-bit 3 component float point vector as Result Type: "
<< "Expected 2 element array of 32-bit 3 component float point "
"vector as Result Type: "
<< spvOpcodeString(opcode);
}
break;
}

case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;

if (!_.IsFloatArrayType(result_type) ||
(GetArrayLength(_, _.FindDef(result_type)) != 2) ||
!_.IsFloatScalarType(_.GetComponentType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 32-bit floating point scalar as Result Type: "
Expand All @@ -305,6 +321,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV:
case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;

if (!_.IsFloatScalarType(result_type) ||
_.GetBitWidth(result_type) != 32) {
Expand All @@ -319,6 +336,7 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpRayQueryIsSphereHitNV:
case spv::Op::OpRayQueryIsLSSHitNV: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;

if (!_.IsBoolScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
Expand Down
26 changes: 21 additions & 5 deletions source/val/validate_ray_tracing_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ namespace val {

static const uint32_t KRayParamInvalidId = std::numeric_limits<uint32_t>::max();

uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
assert(array_type->opcode() == spv::Op::OpTypeArray);
uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
Instruction* array_length_inst = _.FindDef(const_int_id);
assert(array_length_inst->opcode() == spv::Op::OpConstant);
uint32_t array_length = array_length_inst->GetOperandAs<uint32_t>(2);
return array_length;
}

spv_result_t ValidateHitObjectPointer(ValidationState_t& _,
const Instruction* inst,
uint32_t hit_object_index) {
Expand Down Expand Up @@ -625,7 +634,8 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
_.GetDimension(result_type) != 3 ||
_.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 32-bit floating point 2 component vector type as Result Type: "
<< "Expected 32-bit floating point 2 component vector type as "
"Result Type: "
<< spvOpcodeString(opcode);
}
break;
Expand All @@ -635,7 +645,8 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
RegisterOpcodeForValidModel(_, inst);
if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;

if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
if (!_.IsFloatScalarType(result_type) ||
_.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 32-bit floating point scalar as Result Type: "
<< spvOpcodeString(opcode);
Expand All @@ -647,11 +658,14 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
RegisterOpcodeForValidModel(_, inst);
if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;

if ((_.FindDef(result_type)->opcode() != spv::Op::OpTypeArray) ||
auto result_id = _.FindDef(result_type);
if ((result_id->opcode() != spv::Op::OpTypeArray) ||
(GetArrayLength(_, result_id) != 2) ||
!_.IsFloatVectorType(_.GetComponentType(result_type)) ||
_.GetDimension(_.GetComponentType(result_type)) != 3) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 2 element array of 32-bit 3 component float point vector as Result Type: "
<< "Expected 2 element array of 32-bit 3 component float point "
"vector as Result Type: "
<< spvOpcodeString(opcode);
}
break;
Expand All @@ -662,9 +676,11 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;

if (!_.IsFloatArrayType(result_type) ||
(GetArrayLength(_, _.FindDef(result_type)) != 2) ||
!_.IsFloatScalarType(_.GetComponentType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 2 element array of 32-bit floating point scalar as Result Type: "
<< "Expected 2 element array of 32-bit floating point scalar as "
"Result Type: "
<< spvOpcodeString(opcode);
}
break;
Expand Down
40 changes: 19 additions & 21 deletions test/val/val_ray_query_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ using ::testing::Values;

using ValidateRayQuery = spvtest::ValidateBase<bool>;

std::string GenerateShaderCode(
const std::string& body,
const std::string& capabilities = "",
const std::string& extensions = "",
const std::string& declarations = "") {
std::string GenerateShaderCode(const std::string& body,
const std::string& capabilities = "",
const std::string& extensions = "",
const std::string& declarations = "") {
std::ostringstream ss;
ss << R"(
OpCapability Shader
Expand All @@ -43,7 +42,7 @@ OpCapability Float64
OpCapability RayQueryKHR
)";
ss << capabilities;
ss << R"(
ss << R"(
OpExtension "SPV_KHR_ray_query"
)";

Expand Down Expand Up @@ -95,8 +94,7 @@ OpDecorate %top_level_as Binding 0
%f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0
%f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0
%ptr_rq = OpTypePointer Private %type_rq
%ray_query = OpVariable %ptr_rq Private
%ptr_rq = OpTypePointer Function %type_rq
%ptr_as = OpTypePointer UniformConstant %type_as
%top_level_as = OpVariable %ptr_as UniformConstant
Expand All @@ -111,6 +109,7 @@ OpDecorate %top_level_as Binding 0
ss << R"(
%main = OpFunction %void None %func
%main_entry = OpLabel
%ray_query = OpVariable %ptr_rq Function
)";

ss << body;
Expand Down Expand Up @@ -406,7 +405,7 @@ OpFunctionEnd
OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
)";

CompileSuccessfully(GenerateShaderCode(body,"", "", declaration).c_str());
CompileSuccessfully(GenerateShaderCode(body, "", "", declaration).c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

Expand Down Expand Up @@ -637,7 +636,6 @@ TEST_F(ValidateRayQuery, RayQueryArraySuccess) {
using RayQueryLSSNVCommon = spvtest::ValidateBase<std::string>;

std::string RayQueryLSSNVResultType(std::string opcode, bool valid) {

if (opcode.compare("OpRayQueryGetIntersectionLSSPositionsNV") == 0)
return valid ? "%arr2v3" : "%f64";

Expand Down Expand Up @@ -676,19 +674,19 @@ TEST_P(RayQueryLSSNVCommon, Success) {
ss << RayQueryLSSNVResultType(opcode, true);
ss << " %ray_query ";
ss << " %s32_0 ";
CompileSuccessfully(GenerateShaderCode(ss.str(), cap, ext).c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
CompileSuccessfully(GenerateShaderCode(ss.str(), cap, ext).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

INSTANTIATE_TEST_SUITE_P(
ValidateRayQueryLSSNVCommon, RayQueryLSSNVCommon,
Values("OpRayQueryGetIntersectionSpherePositionNV",
"OpRayQueryGetIntersectionLSSPositionsNV",
"OpRayQueryGetIntersectionSphereRadiusNV",
"OpRayQueryGetIntersectionLSSRadiiNV",
"OpRayQueryGetIntersectionLSSHitValueNV",
"OpRayQueryIsSphereHitNV",
"OpRayQueryIsLSSHitNV"));
INSTANTIATE_TEST_SUITE_P(ValidateRayQueryLSSNVCommon, RayQueryLSSNVCommon,
Values("OpRayQueryGetIntersectionSpherePositionNV",
"OpRayQueryGetIntersectionLSSPositionsNV",
"OpRayQueryGetIntersectionSphereRadiusNV",
"OpRayQueryGetIntersectionLSSRadiiNV",
"OpRayQueryGetIntersectionLSSHitValueNV",
"OpRayQueryIsSphereHitNV",
"OpRayQueryIsLSSHitNV"));
} // namespace
} // namespace val
} // namespace spvtools
36 changes: 18 additions & 18 deletions test/val/val_ray_tracing_reorder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ TEST_F(ValidateRayTracingReorderNV,
}

TEST_F(ValidateRayTracingReorderNV, LSSGetSpherePositionNV) {

const std::string cap = R"(
OpCapability RayTracingSpheresGeometryNV
)";
Expand All @@ -623,13 +622,13 @@ TEST_F(ValidateRayTracingReorderNV, LSSGetSpherePositionNV) {
OpStore %pos %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateRayTracingReorderNV, LSSGetLSSPositionsNV) {

const std::string cap = R"(
OpCapability RayTracingSpheresGeometryNV
OpCapability RayTracingLinearSweptSpheresGeometryNV
Expand All @@ -654,13 +653,13 @@ TEST_F(ValidateRayTracingReorderNV, LSSGetLSSPositionsNV) {
OpStore %lsspos %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateRayTracingReorderNV, LSSGetSphereRadiusNV) {

const std::string cap = R"(
OpCapability RayTracingSpheresGeometryNV
)";
Expand All @@ -680,13 +679,13 @@ TEST_F(ValidateRayTracingReorderNV, LSSGetSphereRadiusNV) {
OpStore %rad %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateRayTracingReorderNV, LSSGetLSSRadiiNV) {

const std::string cap = R"(
OpCapability RayTracingLinearSweptSpheresGeometryNV
)";
Expand All @@ -709,13 +708,13 @@ TEST_F(ValidateRayTracingReorderNV, LSSGetLSSRadiiNV) {
OpStore %rad %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateRayTracingReorderNV, LSSIsSphereHitNV) {

const std::string cap = R"(
OpCapability RayTracingSpheresGeometryNV
)";
Expand All @@ -735,13 +734,13 @@ TEST_F(ValidateRayTracingReorderNV, LSSIsSphereHitNV) {
OpStore %ishit %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateRayTracingReorderNV, LSSIsLSSHitNV) {

const std::string cap = R"(
OpCapability RayTracingLinearSweptSpheresGeometryNV
)";
Expand All @@ -761,8 +760,9 @@ TEST_F(ValidateRayTracingReorderNV, LSSIsLSSHitNV) {
OpStore %ishit %result
)";

CompileSuccessfully(GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
CompileSuccessfully(
GenerateReorderThreadCode(body, declarations, ext, cap).c_str(),
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}
} // namespace
Expand Down

0 comments on commit 14e7ec7

Please sign in to comment.