From 3364b982713a0440d1d342dd5eec65b122a61b71 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 29 Jan 2025 17:19:56 -0600 Subject: [PATCH] Validate SPV_NV_cooperative_vector (#5972) --- DEPS | 2 +- include/spirv-tools/libspirv.h | 3 + source/opcode.cpp | 2 + source/operand.cpp | 6 + source/opt/eliminate_dead_members_pass.cpp | 5 + source/opt/folding_rules.cpp | 5 +- source/opt/type_manager.cpp | 26 ++ source/opt/types.cpp | 32 ++ source/opt/types.h | 27 ++ source/val/validate_arithmetics.cpp | 61 ++- source/val/validate_bitwise.cpp | 15 +- source/val/validate_composites.cpp | 30 +- source/val/validate_constants.cpp | 16 +- source/val/validate_conversion.cpp | 109 +++-- source/val/validate_extensions.cpp | 33 +- source/val/validate_memory.cpp | 478 +++++++++++++++++++-- source/val/validate_type.cpp | 54 +++ source/val/validation_state.cpp | 54 +++ source/val/validation_state.h | 7 + test/opt/type_manager_test.cpp | 1 + test/val/val_arithmetics_test.cpp | 161 +++++++ test/val/val_composites_test.cpp | 190 ++++++++ test/val/val_conversion_test.cpp | 127 ++++++ test/val/val_memory_test.cpp | 461 ++++++++++++++++++++ 24 files changed, 1810 insertions(+), 95 deletions(-) diff --git a/DEPS b/DEPS index 8d51a29796..5fcfe0c8a9 100644 --- a/DEPS +++ b/DEPS @@ -14,7 +14,7 @@ vars = { 're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca', - 'spirv_headers_revision': '2b2e05e088841c63c0b6fd4c9fb380d8688738d3', + 'spirv_headers_revision': '767e901c986e9755a17e7939b3046fc2911a4bbd', } deps = { diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index ecf763f338..2ac36a6a1f 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -324,6 +324,9 @@ typedef enum spv_operand_type_t { SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS, SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS, + SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT, + SPV_OPERAND_TYPE_COMPONENT_TYPE, + // This is a sentinel value, and does not represent an operand type. // It should come last. SPV_OPERAND_TYPE_NUM_OPERAND_TYPES, diff --git a/source/opcode.cpp b/source/opcode.cpp index f515c3dea3..985c91c365 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -303,6 +303,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) { case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: return true; default: return false; @@ -381,6 +382,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) { case spv::Op::OpTypeAccelerationStructureNV: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: // case spv::Op::OpTypeAccelerationStructureKHR: covered by // spv::Op::OpTypeAccelerationStructureNV case spv::Op::OpTypeRayQueryKHR: diff --git a/source/operand.cpp b/source/operand.cpp index 853a5672c5..869a7ca137 100644 --- a/source/operand.cpp +++ b/source/operand.cpp @@ -300,6 +300,10 @@ const char* spvOperandTypeStr(spv_operand_type_t type) { return "quantization mode"; case SPV_OPERAND_TYPE_OVERFLOW_MODES: return "overflow mode"; + case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT: + return "cooperative vector matrix layout"; + case SPV_OPERAND_TYPE_COMPONENT_TYPE: + return "component type"; case SPV_OPERAND_TYPE_NONE: return "NONE"; @@ -399,6 +403,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) { case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS: case SPV_OPERAND_TYPE_FPENCODING: case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE: + case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT: + case SPV_OPERAND_TYPE_COMPONENT_TYPE: return true; default: break; diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp index e440296ffd..170f27068b 100644 --- a/source/opt/eliminate_dead_members_pass.cpp +++ b/source/opt/eliminate_dead_members_pass.cpp @@ -207,6 +207,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract( case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -255,6 +256,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain( case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -516,6 +518,7 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) { case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: new_operands.emplace_back(inst->GetInOperand(i)); type_id = type_inst->GetSingleWordInOperand(0); break; @@ -591,6 +594,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: type_id = type_inst->GetSingleWordInOperand(0); break; default: @@ -654,6 +658,7 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) { case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: type_id = type_inst->GetSingleWordInOperand(0); break; default: diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 5748f97188..e5ac2a1cb2 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -77,7 +77,10 @@ int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) { // Returns the element width of |type|. uint32_t ElementWidth(const analysis::Type* type) { - if (const analysis::Vector* vec_type = type->AsVector()) { + if (const analysis::CooperativeVectorNV* coopvec_type = + type->AsCooperativeVectorNV()) { + return ElementWidth(coopvec_type->component_type()); + } else if (const analysis::Vector* vec_type = type->AsVector()) { return ElementWidth(vec_type->element_type()); } else if (const analysis::Float* float_type = type->AsFloat()) { return float_type->width(); diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index c3850f7fbb..45d14af531 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -476,6 +476,20 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { 0, id, operands); break; } + case Type::kCooperativeVectorNV: { + auto coop_vec = type->AsCooperativeVectorNV(); + uint32_t const component_type = + GetTypeInstruction(coop_vec->component_type()); + if (component_type == 0) { + return 0; + } + typeInst = MakeUnique( + context(), spv::Op::OpTypeCooperativeVectorNV, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {component_type}}, + {SPV_OPERAND_TYPE_ID, {coop_vec->components()}}}); + break; + } default: assert(false && "Unexpected type"); break; @@ -721,6 +735,14 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) { tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm()); break; } + case Type::kCooperativeVectorNV: { + const CooperativeVectorNV* cv_type = type.AsCooperativeVectorNV(); + const Type* component_type = cv_type->component_type(); + rebuilt_ty = MakeUnique( + RebuildType(GetId(component_type), *component_type), + cv_type->components()); + break; + } default: assert(false && "Unhandled type"); return nullptr; @@ -970,6 +992,10 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4)); break; + case spv::Op::OpTypeCooperativeVectorNV: + type = new CooperativeVectorNV(GetType(inst.GetSingleWordInOperand(0)), + inst.GetSingleWordInOperand(1)); + break; case spv::Op::OpTypeRayQueryKHR: type = new RayQueryKHR(); break; diff --git a/source/opt/types.cpp b/source/opt/types.cpp index 39f9bd9e51..2023719c86 100644 --- a/source/opt/types.cpp +++ b/source/opt/types.cpp @@ -132,6 +132,7 @@ std::unique_ptr Type::Clone() const { DeclareKindCase(AccelerationStructureNV); DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixKHR); + DeclareKindCase(CooperativeVectorNV); DeclareKindCase(RayQueryKHR); DeclareKindCase(HitObjectNV); #undef DeclareKindCase @@ -181,6 +182,7 @@ bool Type::operator==(const Type& other) const { DeclareKindCase(AccelerationStructureNV); DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixKHR); + DeclareKindCase(CooperativeVectorNV); DeclareKindCase(RayQueryKHR); DeclareKindCase(HitObjectNV); DeclareKindCase(TensorLayoutNV); @@ -240,6 +242,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const { DeclareKindCase(AccelerationStructureNV); DeclareKindCase(CooperativeMatrixNV); DeclareKindCase(CooperativeMatrixKHR); + DeclareKindCase(CooperativeVectorNV); DeclareKindCase(RayQueryKHR); DeclareKindCase(HitObjectNV); DeclareKindCase(TensorLayoutNV); @@ -835,6 +838,35 @@ bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const { has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_; } +CooperativeVectorNV::CooperativeVectorNV(const Type* type, + const uint32_t components) + : Type(kCooperativeVectorNV), + component_type_(type), + components_(components) { + assert(type != nullptr); + assert(components != 0); +} + +std::string CooperativeVectorNV::str() const { + std::ostringstream oss; + oss << "<" << component_type_->str() << ", " << components_ << ">"; + return oss.str(); +} + +size_t CooperativeVectorNV::ComputeExtraStateHash(size_t hash, + SeenTypes* seen) const { + hash = hash_combine(hash, components_); + return component_type_->ComputeHashValue(hash, seen); +} + +bool CooperativeVectorNV::IsSameImpl(const Type* that, + IsSameCache* seen) const { + const CooperativeVectorNV* mt = that->AsCooperativeVectorNV(); + if (!mt) return false; + return component_type_->IsSameImpl(mt->component_type_, seen) && + components_ == mt->components_ && HasSameDecorations(that); +} + } // namespace analysis } // namespace opt } // namespace spvtools diff --git a/source/opt/types.h b/source/opt/types.h index f9e3bb1e9e..1418331369 100644 --- a/source/opt/types.h +++ b/source/opt/types.h @@ -64,6 +64,7 @@ class NamedBarrier; class AccelerationStructureNV; class CooperativeMatrixNV; class CooperativeMatrixKHR; +class CooperativeVectorNV; class RayQueryKHR; class HitObjectNV; class TensorLayoutNV; @@ -108,6 +109,7 @@ class Type { kAccelerationStructureNV, kCooperativeMatrixNV, kCooperativeMatrixKHR, + kCooperativeVectorNV, kRayQueryKHR, kHitObjectNV, kTensorLayoutNV, @@ -213,6 +215,7 @@ class Type { DeclareCastMethod(AccelerationStructureNV) DeclareCastMethod(CooperativeMatrixNV) DeclareCastMethod(CooperativeMatrixKHR) + DeclareCastMethod(CooperativeVectorNV) DeclareCastMethod(RayQueryKHR) DeclareCastMethod(HitObjectNV) DeclareCastMethod(TensorLayoutNV) @@ -742,6 +745,30 @@ class TensorViewNV : public Type { std::vector perm_; }; +class CooperativeVectorNV : public Type { + public: + CooperativeVectorNV(const Type* type, const uint32_t components); + CooperativeVectorNV(const CooperativeVectorNV&) = default; + + std::string str() const override; + + CooperativeVectorNV* AsCooperativeVectorNV() override { return this; } + const CooperativeVectorNV* AsCooperativeVectorNV() const override { + return this; + } + + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; + + const Type* component_type() const { return component_type_; } + uint32_t components() const { return components_; } + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* component_type_; + const uint32_t components_; +}; + #define DefineParameterlessType(type, name) \ class type : public Type { \ public: \ diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp index 8b0049c5b4..d252ec92b4 100644 --- a/source/val/validate_arithmetics.cpp +++ b/source/val/validate_arithmetics.cpp @@ -40,19 +40,34 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { bool supportsCoopMat = (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem && opcode != spv::Op::OpFMod); + bool supportsCoopVec = + (opcode != spv::Op::OpFRem && opcode != spv::Op::OpFMod); if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type) && !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) && !(opcode == spv::Op::OpFMul && _.IsCooperativeMatrixKHRType(result_type) && - _.IsFloatCooperativeMatrixType(result_type))) + _.IsFloatCooperativeMatrixType(result_type)) && + !(supportsCoopVec && _.IsFloatCooperativeVectorNVType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected floating scalar or vector type as Result Type: " << spvOpcodeString(opcode); for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { + if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); + if (!_.IsCooperativeVectorNVType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected arithmetic operands to be of Result Type: " + << spvOpcodeString(opcode) << " operand index " + << operand_index; + } + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); + if (ret != SPV_SUCCESS) return ret; + } else if (supportsCoopMat && + _.IsCooperativeMatrixKHRType(result_type)) { const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!_.IsCooperativeMatrixKHRType(type_id) || !_.IsFloatCooperativeMatrixType(type_id)) { @@ -76,17 +91,32 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpUDiv: case spv::Op::OpUMod: { bool supportsCoopMat = (opcode == spv::Op::OpUDiv); + bool supportsCoopVec = (opcode == spv::Op::OpUDiv); if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type) && !(supportsCoopMat && - _.IsUnsignedIntCooperativeMatrixType(result_type))) + _.IsUnsignedIntCooperativeMatrixType(result_type)) && + !(supportsCoopVec && + _.IsUnsignedIntCooperativeVectorNVType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); for (size_t operand_index = 2; operand_index < inst->operands().size(); ++operand_index) { - if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { + if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); + if (!_.IsCooperativeVectorNVType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected arithmetic operands to be of Result Type: " + << spvOpcodeString(opcode) << " operand index " + << operand_index; + } + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); + if (ret != SPV_SUCCESS) return ret; + } else if (supportsCoopMat && + _.IsCooperativeMatrixKHRType(result_type)) { const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!_.IsCooperativeMatrixKHRType(type_id) || !_.IsUnsignedIntCooperativeMatrixType(type_id)) { @@ -117,11 +147,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { bool supportsCoopMat = (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem && opcode != spv::Op::OpSMod); + bool supportsCoopVec = + (opcode != spv::Op::OpSRem && opcode != spv::Op::OpSMod); if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) && !(opcode == spv::Op::OpIMul && _.IsCooperativeMatrixKHRType(result_type) && - _.IsIntCooperativeMatrixType(result_type))) + _.IsIntCooperativeMatrixType(result_type)) && + !(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -133,6 +166,18 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { ++operand_index) { const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); + if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { + if (!_.IsCooperativeVectorNVType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected arithmetic operands to be of Result Type: " + << spvOpcodeString(opcode) << " operand index " + << operand_index; + } + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); + if (ret != SPV_SUCCESS) return ret; + } + if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { if (!_.IsCooperativeMatrixKHRType(type_id) || !_.IsIntCooperativeMatrixType(type_id)) { @@ -151,7 +196,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) && !(opcode == spv::Op::OpIMul && _.IsCooperativeMatrixKHRType(result_type) && - _.IsIntCooperativeMatrixType(result_type)))) + _.IsIntCooperativeMatrixType(result_type)) && + !(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type)))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as operand: " << spvOpcodeString(opcode) << " operand index " @@ -210,7 +256,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { } case spv::Op::OpVectorTimesScalar: { - if (!_.IsFloatVectorType(result_type)) + if (!_.IsFloatVectorType(result_type) && + !_.IsFloatCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float vector type as Result Type: " << spvOpcodeString(opcode); diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp index d8d995814d..bb0588a09f 100644 --- a/source/val/validate_bitwise.cpp +++ b/source/val/validate_bitwise.cpp @@ -64,7 +64,8 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpShiftRightLogical: case spv::Op::OpShiftRightArithmetic: case spv::Op::OpShiftLeftLogical: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && + !_.IsIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -74,7 +75,8 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { const uint32_t shift_type = _.GetOperandTypeId(inst, 3); if (!base_type || - (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) + (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type) && + !_.IsIntCooperativeVectorNVType(base_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base to be int scalar or vector: " << spvOpcodeString(opcode); @@ -90,7 +92,8 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { << "as Result Type: " << spvOpcodeString(opcode); if (!shift_type || - (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type))) + (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type) && + !_.IsIntCooperativeVectorNVType(shift_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Shift to be int scalar or vector: " << spvOpcodeString(opcode); @@ -106,7 +109,8 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpBitwiseXor: case spv::Op::OpBitwiseAnd: case spv::Op::OpNot: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && + !_.IsIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -118,7 +122,8 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { ++operand_index) { const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); if (!type_id || - (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id))) + (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) && + !_.IsIntCooperativeVectorNVType(type_id))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector as operand: " << spvOpcodeString(opcode) << " operand index " diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp index c08eb2d2e8..2afeae78ad 100644 --- a/source/val/validate_composites.cpp +++ b/source/val/validate_composites.cpp @@ -125,6 +125,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _, *member_type = type_inst->word(component_index + 2); break; } + case spv::Op::OpTypeCooperativeVectorNV: case spv::Op::OpTypeCooperativeMatrixKHR: case spv::Op::OpTypeCooperativeMatrixNV: { *member_type = type_inst->word(2); @@ -151,7 +152,8 @@ spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _, const uint32_t vector_type = _.GetOperandTypeId(inst, 2); const spv::Op vector_opcode = _.GetIdOpcode(vector_type); - if (vector_opcode != spv::Op::OpTypeVector) { + if (vector_opcode != spv::Op::OpTypeVector && + vector_opcode != spv::Op::OpTypeCooperativeVectorNV) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Vector type to be OpTypeVector"; } @@ -179,7 +181,8 @@ spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _, const Instruction* inst) { const uint32_t result_type = inst->type_id(); const spv::Op result_opcode = _.GetIdOpcode(result_type); - if (result_opcode != spv::Op::OpTypeVector) { + if (result_opcode != spv::Op::OpTypeVector && + result_opcode != spv::Op::OpTypeCooperativeVectorNV) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Result Type to be OpTypeVector"; } @@ -217,14 +220,24 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _, const uint32_t result_type = inst->type_id(); const spv::Op result_opcode = _.GetIdOpcode(result_type); switch (result_opcode) { - case spv::Op::OpTypeVector: { - const uint32_t num_result_components = _.GetDimension(result_type); + case spv::Op::OpTypeVector: + case spv::Op::OpTypeCooperativeVectorNV: { + uint32_t num_result_components = _.GetDimension(result_type); const uint32_t result_component_type = _.GetComponentType(result_type); uint32_t given_component_count = 0; - if (num_operands <= 3) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected number of constituents to be at least 2"; + bool comp_is_int32 = true, comp_is_const_int32 = true; + + if (result_opcode == spv::Op::OpTypeVector) { + if (num_operands <= 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of constituents to be at least 2"; + } + } else { + uint32_t comp_count_id = + _.FindDef(result_type)->GetOperandAs(2); + std::tie(comp_is_int32, comp_is_const_int32, num_result_components) = + _.EvalInt32IfConst(comp_count_id); } for (uint32_t operand_index = 2; operand_index < num_operands; @@ -244,7 +257,8 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _, } } - if (num_result_components != given_component_count) { + if (comp_is_const_int32 && + num_result_components != given_component_count) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected total number of given components to be equal " << "to the size of Result Type vector"; diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp index 1d40eedf91..9c689c5393 100644 --- a/source/val/validate_constants.cpp +++ b/source/val/validate_constants.cpp @@ -46,9 +46,18 @@ spv_result_t ValidateConstantComposite(ValidationState_t& _, const auto constituent_count = inst->words().size() - 3; switch (result_type->opcode()) { - case spv::Op::OpTypeVector: { - const auto component_count = result_type->GetOperandAs(2); - if (component_count != constituent_count) { + case spv::Op::OpTypeVector: + case spv::Op::OpTypeCooperativeVectorNV: { + uint32_t num_result_components = _.GetDimension(result_type->id()); + bool comp_is_int32 = true, comp_is_const_int32 = true; + + if (result_type->opcode() == spv::Op::OpTypeCooperativeVectorNV) { + uint32_t comp_count_id = result_type->GetOperandAs(2); + std::tie(comp_is_int32, comp_is_const_int32, num_result_components) = + _.EvalInt32IfConst(comp_count_id); + } + + if (comp_is_const_int32 && num_result_components != constituent_count) { // TODO: Output ID's on diagnostic return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name @@ -312,6 +321,7 @@ bool IsTypeNullable(const std::vector& instruction, case spv::Op::OpTypeMatrix: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: case spv::Op::OpTypeVector: { auto base_type = _.FindDef(instruction[2]); return base_type && IsTypeNullable(base_type->words(), _); diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp index 770b8e2e33..c459ec384c 100644 --- a/source/val/validate_conversion.cpp +++ b/source/val/validate_conversion.cpp @@ -33,7 +33,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpConvertFToU: { if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type) && - !_.IsUnsignedIntCooperativeMatrixType(result_type)) + !_.IsUnsignedIntCooperativeMatrixType(result_type) && + !_.IsUnsignedIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -41,13 +42,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type) && - !_.IsFloatCooperativeMatrixType(input_type))) + !_.IsFloatCooperativeMatrixType(input_type) && + !_.IsFloatCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -63,7 +70,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpConvertFToS: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && - !_.IsIntCooperativeMatrixType(result_type)) + !_.IsIntCooperativeMatrixType(result_type) && + !_.IsIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -71,13 +79,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type) && - !_.IsFloatCooperativeMatrixType(input_type))) + !_.IsFloatCooperativeMatrixType(input_type) && + !_.IsFloatCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -95,7 +109,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpConvertUToF: { if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type) && - !_.IsFloatCooperativeMatrixType(result_type)) + !_.IsFloatCooperativeMatrixType(result_type) && + !_.IsFloatCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -103,13 +118,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && - !_.IsIntCooperativeMatrixType(input_type))) + !_.IsIntCooperativeMatrixType(input_type) && + !_.IsIntCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -126,7 +147,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpUConvert: { if (!_.IsUnsignedIntScalarType(result_type) && !_.IsUnsignedIntVectorType(result_type) && - !_.IsUnsignedIntCooperativeMatrixType(result_type)) + !_.IsUnsignedIntCooperativeMatrixType(result_type) && + !_.IsUnsignedIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected unsigned int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -134,13 +156,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && - !_.IsIntCooperativeMatrixType(input_type))) + !_.IsIntCooperativeMatrixType(input_type) && + !_.IsIntCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -161,7 +189,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpSConvert: { if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && - !_.IsIntCooperativeMatrixType(result_type)) + !_.IsIntCooperativeMatrixType(result_type) && + !_.IsIntCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected int scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -169,13 +198,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) && - !_.IsIntCooperativeMatrixType(input_type))) + !_.IsIntCooperativeMatrixType(input_type) && + !_.IsIntCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be int scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -197,7 +232,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpFConvert: { if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type) && - !_.IsFloatCooperativeMatrixType(result_type)) + !_.IsFloatCooperativeMatrixType(result_type) && + !_.IsFloatCooperativeVectorNVType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected float scalar or vector type as Result Type: " << spvOpcodeString(opcode); @@ -205,13 +241,19 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const uint32_t input_type = _.GetOperandTypeId(inst, 2); if (!input_type || (!_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type) && - !_.IsFloatCooperativeMatrixType(input_type))) + !_.IsFloatCooperativeMatrixType(input_type) && + !_.IsFloatCooperativeVectorNVType(input_type))) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be float scalar or vector: " << spvOpcodeString(opcode); - if (_.IsCooperativeMatrixType(result_type) || - _.IsCooperativeMatrixType(input_type)) { + if (_.IsCooperativeVectorNVType(result_type) || + _.IsCooperativeVectorNVType(input_type)) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } else if (_.IsCooperativeMatrixType(result_type) || + _.IsCooperativeMatrixType(input_type)) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, true); if (ret != SPV_SUCCESS) return ret; @@ -475,9 +517,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type); const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type); + const bool result_is_coopvec = _.IsCooperativeVectorNVType(result_type); + const bool input_is_coopvec = _.IsCooperativeVectorNVType(input_type); if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat && - !_.IsIntVectorType(result_type) && + !result_is_coopvec && !_.IsIntVectorType(result_type) && !_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -485,17 +529,28 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { << "or scalar type: " << spvOpcodeString(opcode); if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat && - !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) && - !_.IsFloatVectorType(input_type)) + !input_is_coopvec && !_.IsIntVectorType(input_type) && + !_.IsFloatScalarType(input_type) && !_.IsFloatVectorType(input_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected input to be a pointer or int or float vector " << "or scalar: " << spvOpcodeString(opcode); + if (result_is_coopvec != input_is_coopvec) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cooperative vector can only be cast to another cooperative " + << "vector: " << spvOpcodeString(opcode); + if (result_is_coopmat != input_is_coopmat) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Cooperative matrix can only be cast to another cooperative " << "matrix: " << spvOpcodeString(opcode); + if (result_is_coopvec) { + spv_result_t ret = + _.CooperativeVectorDimensionsMatch(inst, result_type, input_type); + if (ret != SPV_SUCCESS) return ret; + } + if (result_is_coopmat) { spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type, input_type, false); diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index c2d8d59358..283ebc41a0 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp @@ -1136,7 +1136,16 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { case GLSLstd450NMin: case GLSLstd450NMax: case GLSLstd450NClamp: { - if (!_.IsFloatScalarOrVectorType(result_type)) { + bool supportsCoopVec = + (ext_inst_key == GLSLstd450FMin || ext_inst_key == GLSLstd450FMax || + ext_inst_key == GLSLstd450FClamp || + ext_inst_key == GLSLstd450NMin || ext_inst_key == GLSLstd450NMax || + ext_inst_key == GLSLstd450NClamp || + ext_inst_key == GLSLstd450Step || ext_inst_key == GLSLstd450Fma); + + if (!_.IsFloatScalarOrVectorType(result_type) && + !(supportsCoopVec && + _.IsFloatCooperativeVectorNVType(result_type))) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a float scalar or vector type"; @@ -1166,7 +1175,14 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { case GLSLstd450FindILsb: case GLSLstd450FindUMsb: case GLSLstd450FindSMsb: { - if (!_.IsIntScalarOrVectorType(result_type)) { + bool supportsCoopVec = + (ext_inst_key == GLSLstd450UMin || ext_inst_key == GLSLstd450UMax || + ext_inst_key == GLSLstd450UClamp || + ext_inst_key == GLSLstd450SMin || ext_inst_key == GLSLstd450SMax || + ext_inst_key == GLSLstd450SClamp); + + if (!_.IsIntScalarOrVectorType(result_type) && + !(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type))) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be an int scalar or vector type"; @@ -1178,7 +1194,10 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { for (uint32_t operand_index = 4; operand_index < num_operands; ++operand_index) { const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); - if (!operand_type || !_.IsIntScalarOrVectorType(operand_type)) { + if (!operand_type || + (!_.IsIntScalarOrVectorType(operand_type) && + !(supportsCoopVec && + _.IsIntCooperativeVectorNVType(operand_type)))) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected all operands to be int scalars or vectors"; @@ -1231,7 +1250,13 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { case GLSLstd450Log2: case GLSLstd450Atan2: case GLSLstd450Pow: { - if (!_.IsFloatScalarOrVectorType(result_type)) { + bool supportsCoopVec = + (ext_inst_key == GLSLstd450Atan || ext_inst_key == GLSLstd450Tanh || + ext_inst_key == GLSLstd450Exp || ext_inst_key == GLSLstd450Log); + + if (!_.IsFloatScalarOrVectorType(result_type) && + !(supportsCoopVec && + _.IsFloatCooperativeVectorNVType(result_type))) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Result Type to be a 16 or 32-bit scalar or " diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index c589528013..8996a067e7 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -196,37 +196,6 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage, return false; } -bool ContainsCooperativeMatrix(ValidationState_t& _, - const Instruction* storage) { - const size_t elem_type_index = 1; - uint32_t elem_type_id; - Instruction* elem_type; - - switch (storage->opcode()) { - case spv::Op::OpTypeCooperativeMatrixNV: - case spv::Op::OpTypeCooperativeMatrixKHR: - return true; - case spv::Op::OpTypeArray: - case spv::Op::OpTypeRuntimeArray: - elem_type_id = storage->GetOperandAs(elem_type_index); - elem_type = _.FindDef(elem_type_id); - return ContainsCooperativeMatrix(_, elem_type); - case spv::Op::OpTypeStruct: - for (size_t member_type_index = 1; - member_type_index < storage->operands().size(); - ++member_type_index) { - auto member_type_id = - storage->GetOperandAs(member_type_index); - auto member_type = _.FindDef(member_type_id); - if (ContainsCooperativeMatrix(_, member_type)) return true; - } - break; - default: - break; - } - return false; -} - std::pair GetStorageClass( ValidationState_t& _, const Instruction* inst) { spv::StorageClass dst_sc = spv::StorageClass::Max; @@ -235,6 +204,7 @@ std::pair GetStorageClass( case spv::Op::OpCooperativeMatrixLoadNV: case spv::Op::OpCooperativeMatrixLoadTensorNV: case spv::Op::OpCooperativeMatrixLoadKHR: + case spv::Op::OpCooperativeVectorLoadNV: case spv::Op::OpLoad: { auto load_pointer = _.FindDef(inst->GetOperandAs(2)); auto load_pointer_type = _.FindDef(load_pointer->type_id()); @@ -244,6 +214,7 @@ std::pair GetStorageClass( case spv::Op::OpCooperativeMatrixStoreNV: case spv::Op::OpCooperativeMatrixStoreTensorNV: case spv::Op::OpCooperativeMatrixStoreKHR: + case spv::Op::OpCooperativeVectorStoreNV: case spv::Op::OpStore: { auto store_pointer = _.FindDef(inst->GetOperandAs(0)); auto store_pointer_type = _.FindDef(store_pointer->type_id()); @@ -280,8 +251,9 @@ int MemoryAccessNumWords(uint32_t mask) { // Returns the scope ID operand for MakeAvailable memory access with mask // at the given operand index. // This function is only called for OpLoad, OpStore, OpCopyMemory and -// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and -// OpCooperativeMatrixStoreNV. +// OpCopyMemorySized, OpCooperativeMatrixLoadNV, +// OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV, +// OpCooperativeVectorStoreNV. uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask, uint32_t mask_index) { assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)); @@ -292,8 +264,9 @@ uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask, } // This function is only called for OpLoad, OpStore, OpCopyMemory, -// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and -// OpCooperativeMatrixStoreNV. +// OpCopyMemorySized, OpCooperativeMatrixLoadNV, +// OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV, +// OpCooperativeVectorStoreNV. uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask, uint32_t mask_index) { assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)); @@ -333,7 +306,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, if (inst->opcode() == spv::Op::OpLoad || inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV || inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV || - inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) { + inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR || + inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerAvailableKHR cannot be used with OpLoad."; } @@ -354,7 +328,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, if (inst->opcode() == spv::Op::OpStore || inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV || inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR || - inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) { + inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV || + inst->opcode() == spv::Op::OpCooperativeVectorStoreNV) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerVisibleKHR cannot be used with OpStore."; } @@ -821,7 +796,12 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { // Cooperative matrix types can only be allocated in Function or Private if ((storage_class != spv::StorageClass::Function && storage_class != spv::StorageClass::Private) && - pointee && ContainsCooperativeMatrix(_, pointee)) { + pointee && + _.ContainsType(pointee->id(), [](const Instruction* type_inst) { + auto opcode = type_inst->opcode(); + return opcode == spv::Op::OpTypeCooperativeMatrixNV || + opcode == spv::Op::OpTypeCooperativeMatrixKHR; + })) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Cooperative matrix types (or types containing them) can only be " "allocated " @@ -829,6 +809,20 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { "parameters"; } + if ((storage_class != spv::StorageClass::Function && + storage_class != spv::StorageClass::Private) && + pointee && + _.ContainsType(pointee->id(), [](const Instruction* type_inst) { + auto opcode = type_inst->opcode(); + return opcode == spv::Op::OpTypeCooperativeVectorNV; + })) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Cooperative vector types (or types containing them) can only be " + "allocated " + << "in Function or Private storage classes or as function " + "parameters"; + } + if (_.HasCapability(spv::Capability::Shader)) { // Don't allow variables containing 16-bit elements without the appropriate // capabilities. @@ -1597,13 +1591,15 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, switch (type_pointee->opcode()) { case spv::Op::OpTypeMatrix: case spv::Op::OpTypeVector: + case spv::Op::OpTypeCooperativeVectorNV: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: case spv::Op::OpTypeArray: case spv::Op::OpTypeRuntimeArray: case spv::Op::OpTypeNodePayloadArrayAMDX: { // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV, - // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type. + // OpTypeCooperativeVectorNV, OpTypeArray, and OpTypeRuntimeArray, word + // 2 is the Element Type. type_pointee = _.FindDef(type_pointee->word(2)); break; } @@ -2401,6 +2397,392 @@ spv_result_t ValidateCooperativeMatrixLoadStoreTensorNV( return SPV_SUCCESS; } +spv_result_t ValidateInt32Operand(ValidationState_t& _, const Instruction* inst, + uint32_t operand_index, + const char* opcode_name, + const char* operand_name) { + const auto type_id = + _.FindDef(inst->GetOperandAs(operand_index))->type_id(); + if (!_.IsIntScalarType(type_id) || _.GetBitWidth(type_id) != 32) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " " << operand_name << " type " + << _.getIdName(type_id) << " is not a 32 bit integer."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCooperativeVectorPointer(ValidationState_t& _, + const Instruction* inst, + const char* opname, + uint32_t pointer_index) { + const auto pointer_id = inst->GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + ((_.addressing_model() == spv::AddressingModel::Logical) && + ((!_.features().variable_pointers && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (_.features().variable_pointers && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " Pointer " << _.getIdName(pointer_id) + << " is not a logical pointer."; + } + + const auto pointer_type_id = pointer->type_id(); + const auto pointer_type = _.FindDef(pointer_type_id); + if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " type for pointer " << _.getIdName(pointer_id) + << " is not a pointer type."; + } + + const auto storage_class_index = 1u; + const auto storage_class = + pointer_type->GetOperandAs(storage_class_index); + + if (storage_class != spv::StorageClass::Workgroup && + storage_class != spv::StorageClass::StorageBuffer && + storage_class != spv::StorageClass::PhysicalStorageBuffer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " storage class for pointer type " + << _.getIdName(pointer_type_id) + << " is not Workgroup or StorageBuffer."; + } + + const auto pointee_id = pointer_type->GetOperandAs(2); + const auto pointee_type = _.FindDef(pointee_id); + if (!pointee_type || + (pointee_type->opcode() != spv::Op::OpTypeArray && + pointee_type->opcode() != spv::Op::OpTypeRuntimeArray)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " Pointer " << _.getIdName(pointer->id()) + << "s Type must be an array type."; + } + + const auto array_elem_type_id = pointee_type->GetOperandAs(1); + auto array_elem_type = _.FindDef(array_elem_type_id); + if (!array_elem_type || !(_.IsIntScalarOrVectorType(array_elem_type_id) || + _.IsFloatScalarOrVectorType(array_elem_type_id))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opname << " Pointer " << _.getIdName(pointer->id()) + << "s Type must be an array of scalar or vector type."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateCooperativeVectorLoadStoreNV(ValidationState_t& _, + const Instruction* inst) { + uint32_t type_id; + const char* opname; + if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) { + type_id = inst->type_id(); + opname = "spv::Op::OpCooperativeVectorLoadNV"; + } else { + // get Object operand's type + type_id = _.FindDef(inst->GetOperandAs(2))->type_id(); + opname = "spv::Op::OpCooperativeVectorStoreNV"; + } + + auto vector_type = _.FindDef(type_id); + + if (vector_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) { + if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "spv::Op::OpCooperativeVectorLoadNV Result Type " + << _.getIdName(type_id) << " is not a cooperative vector type."; + } else { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "spv::Op::OpCooperativeVectorStoreNV Object type " + << _.getIdName(type_id) << " is not a cooperative vector type."; + } + } + + const auto pointer_index = + (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 2u : 0u; + + if (auto error = + ValidateCooperativeVectorPointer(_, inst, opname, pointer_index)) { + return error; + } + + const auto memory_access_index = + (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 4u : 3u; + if (inst->operands().size() > memory_access_index) { + if (auto error = CheckMemoryAccess(_, inst, memory_access_index)) + return error; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateCooperativeVectorOuterProductNV(ValidationState_t& _, + const Instruction* inst) { + const auto pointer_index = 0u; + const auto opcode_name = + "spv::Op::OpCooperativeVectorOuterProductAccumulateNV"; + + if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name, + pointer_index)) { + return error; + } + + auto type_id = _.FindDef(inst->GetOperandAs(2))->type_id(); + auto a_type = _.FindDef(type_id); + + if (a_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " A type " << _.getIdName(type_id) + << " is not a cooperative vector type."; + } + + type_id = _.FindDef(inst->GetOperandAs(3))->type_id(); + auto b_type = _.FindDef(type_id); + + if (b_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " B type " << _.getIdName(type_id) + << " is not a cooperative vector type."; + } + + const auto a_component_type_id = a_type->GetOperandAs(1); + const auto b_component_type_id = b_type->GetOperandAs(1); + + if (a_component_type_id != b_component_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " A and B component types " + << _.getIdName(a_component_type_id) << " and " + << _.getIdName(b_component_type_id) << " do not match."; + } + + if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) { + return error; + } + + if (auto error = + ValidateInt32Operand(_, inst, 4, opcode_name, "MemoryLayout")) { + return error; + } + + if (auto error = ValidateInt32Operand(_, inst, 5, opcode_name, + "MatrixInterpretation")) { + return error; + } + + if (inst->operands().size() > 6) { + if (auto error = + ValidateInt32Operand(_, inst, 6, opcode_name, "MatrixStride")) { + return error; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateCooperativeVectorReduceSumNV(ValidationState_t& _, + const Instruction* inst) { + const auto opcode_name = "spv::Op::OpCooperativeVectorReduceSumAccumulateNV"; + const auto pointer_index = 0u; + + if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name, + pointer_index)) { + return error; + } + + auto type_id = _.FindDef(inst->GetOperandAs(2))->type_id(); + auto v_type = _.FindDef(type_id); + + if (v_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " V type " << _.getIdName(type_id) + << " is not a cooperative vector type."; + } + + if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) { + return error; + } + + return SPV_SUCCESS; +} + +bool InterpretationIsPacked(spv::ComponentType interp) { + switch (interp) { + case spv::ComponentType::SignedInt8PackedNV: + case spv::ComponentType::UnsignedInt8PackedNV: + return true; + default: + return false; + } +} + +using std::get; + +spv_result_t ValidateCooperativeVectorMatrixMulNV(ValidationState_t& _, + const Instruction* inst) { + const bool has_bias = + inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV; + const auto opcode_name = has_bias + ? "spv::Op::OpCooperativeVectorMatrixMulAddNV" + : "spv::Op::OpCooperativeVectorMatrixMulNV"; + + const auto bias_offset = has_bias ? 3 : 0; + + const auto result_type_index = 0u; + const auto input_index = 2u; + const auto input_interpretation_index = 3u; + const auto matrix_index = 4u; + const auto matrix_interpretation_index = 6u; + const auto bias_index = 7u; + const auto bias_interpretation_index = 9u; + const auto m_index = 7u + bias_offset; + const auto k_index = 8u + bias_offset; + const auto memory_layout_index = 9u + bias_offset; + const auto transpose_index = 10u + bias_offset; + + const auto result_type_id = inst->GetOperandAs(result_type_index); + const auto input_id = inst->GetOperandAs(input_index); + const auto input_interpretation_id = + inst->GetOperandAs(input_interpretation_index); + const auto matrix_interpretation_id = + inst->GetOperandAs(matrix_interpretation_index); + const auto bias_interpretation_id = + inst->GetOperandAs(bias_interpretation_index); + const auto m_id = inst->GetOperandAs(m_index); + const auto k_id = inst->GetOperandAs(k_index); + const auto memory_layout_id = + inst->GetOperandAs(memory_layout_index); + const auto transpose_id = inst->GetOperandAs(transpose_index); + + if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name, + matrix_index)) { + return error; + } + + if (inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV) { + if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name, + bias_index)) { + return error; + } + } + + const auto result_type = _.FindDef(result_type_id); + + if (result_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " result type " << _.getIdName(result_type_id) + << " is not a cooperative vector type."; + } + + const auto result_component_type_id = result_type->GetOperandAs(1u); + if (!(_.IsIntScalarType(result_component_type_id) && + _.GetBitWidth(result_component_type_id) == 32) && + !(_.IsFloatScalarType(result_component_type_id) && + (_.GetBitWidth(result_component_type_id) == 32 || + _.GetBitWidth(result_component_type_id) == 16))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " result component type " + << _.getIdName(result_component_type_id) + << " is not a 32 bit int or 16/32 bit float."; + } + + const auto m_eval = _.EvalInt32IfConst(m_id); + const auto rc_eval = + _.EvalInt32IfConst(result_type->GetOperandAs(2u)); + if (get<1>(m_eval) && get<1>(rc_eval) && get<2>(m_eval) != get<2>(rc_eval)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " result type number of components " + << get<2>(rc_eval) << " does not match M " << get<2>(m_eval); + } + + const auto k_eval = _.EvalInt32IfConst(k_id); + + const auto input = _.FindDef(input_id); + const auto input_type = _.FindDef(input->type_id()); + const auto input_num_components_id = input_type->GetOperandAs(2u); + + auto input_interp_eval = _.EvalInt32IfConst(input_interpretation_id); + if (get<1>(input_interp_eval) && + !InterpretationIsPacked(spv::ComponentType{get<2>(input_interp_eval)})) { + const auto inc_eval = _.EvalInt32IfConst(input_num_components_id); + if (get<1>(inc_eval) && get<1>(k_eval) && + get<2>(inc_eval) != get<2>(k_eval)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " input number of components " + << get<2>(inc_eval) << " does not match K " << get<2>(k_eval); + } + } + + if (!_.IsBoolScalarType(_.FindDef(transpose_id)->type_id())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Transpose " << _.getIdName(transpose_id) + << " is not a scalar boolean."; + } + + const auto check_constant = [&](uint32_t id, + const char* operand_name) -> spv_result_t { + if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " " << operand_name << " " + << _.getIdName(id) << " is not a constant instruction."; + } + return SPV_SUCCESS; + }; + + if (auto error = + check_constant(input_interpretation_id, "InputInterpretation")) { + return error; + } + if (auto error = + check_constant(matrix_interpretation_id, "MatrixInterpretation")) { + return error; + } + if (has_bias) { + if (auto error = + check_constant(bias_interpretation_id, "BiasInterpretation")) { + return error; + } + } + if (auto error = check_constant(m_id, "M")) { + return error; + } + if (auto error = check_constant(k_id, "K")) { + return error; + } + if (auto error = check_constant(memory_layout_id, "MemoryLayout")) { + return error; + } + if (auto error = check_constant(transpose_id, "Transpose")) { + return error; + } + + if (auto error = ValidateInt32Operand(_, inst, input_interpretation_index, + opcode_name, "InputInterpretation")) { + return error; + } + if (auto error = ValidateInt32Operand(_, inst, matrix_interpretation_index, + opcode_name, "MatrixInterpretation")) { + return error; + } + if (has_bias) { + if (auto error = ValidateInt32Operand(_, inst, bias_interpretation_index, + opcode_name, "BiasInterpretation")) { + return error; + } + } + if (auto error = ValidateInt32Operand(_, inst, m_index, opcode_name, "M")) { + return error; + } + if (auto error = ValidateInt32Operand(_, inst, k_index, opcode_name, "K")) { + return error; + } + if (auto error = ValidateInt32Operand(_, inst, memory_layout_index, + opcode_name, "MemoryLayout")) { + return error; + } + + return SPV_SUCCESS; +} + spv_result_t ValidatePtrComparison(ValidationState_t& _, const Instruction* inst) { if (_.addressing_model() == spv::AddressingModel::Logical && @@ -2514,6 +2896,24 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) { if (auto error = ValidateCooperativeMatrixLoadStoreTensorNV(_, inst)) return error; break; + case spv::Op::OpCooperativeVectorLoadNV: + case spv::Op::OpCooperativeVectorStoreNV: + if (auto error = ValidateCooperativeVectorLoadStoreNV(_, inst)) + return error; + break; + case spv::Op::OpCooperativeVectorOuterProductAccumulateNV: + if (auto error = ValidateCooperativeVectorOuterProductNV(_, inst)) + return error; + break; + case spv::Op::OpCooperativeVectorReduceSumAccumulateNV: + if (auto error = ValidateCooperativeVectorReduceSumNV(_, inst)) + return error; + break; + case spv::Op::OpCooperativeVectorMatrixMulNV: + case spv::Op::OpCooperativeVectorMatrixMulAddNV: + if (auto error = ValidateCooperativeVectorMatrixMulNV(_, inst)) + return error; + break; case spv::Op::OpPtrEqual: case spv::Op::OpPtrNotEqual: case spv::Op::OpPtrDiff: diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index 4cd7684604..c2aca99768 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp @@ -169,6 +169,57 @@ spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) { return SPV_SUCCESS; } +spv_result_t ValidateTypeCooperativeVectorNV(ValidationState_t& _, + const Instruction* inst) { + const auto component_index = 1; + const auto component_type_id = inst->GetOperandAs(component_index); + const auto component_type = _.FindDef(component_type_id); + if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() && + spv::Op::OpTypeInt != component_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeVectorNV Component Type " + << _.getIdName(component_type_id) + << " is not a scalar numerical type."; + } + + const auto num_components_index = 2; + const auto num_components_id = + inst->GetOperandAs(num_components_index); + const auto num_components = _.FindDef(num_components_id); + if (!num_components || !spvOpcodeIsConstant(num_components->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeVectorNV component count " + << _.getIdName(num_components_id) + << " is not a scalar constant type."; + } + + // NOTE: Check the initialiser value of the constant + const auto const_inst = num_components->words(); + const auto const_result_type_index = 1; + const auto const_result_type = _.FindDef(const_inst[const_result_type_index]); + if (!const_result_type || spv::Op::OpTypeInt != const_result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeVectorNV component count " + << _.getIdName(num_components_id) + << " is not a constant integer type."; + } + + int64_t num_components_value; + if (_.EvalConstantValInt64(num_components_id, &num_components_value)) { + auto& type_words = const_result_type->words(); + const bool is_signed = type_words[3] > 0; + if (num_components_value == 0 || (num_components_value < 0 && is_signed)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeCooperativeVectorNV component count " + << _.getIdName(num_components_id) + << " default value must be at least 1: found " + << num_components_value; + } + } + + return SPV_SUCCESS; +} + spv_result_t ValidateTypeMatrix(ValidationState_t& _, const Instruction* inst) { const auto column_type_index = 1; const auto column_type_id = inst->GetOperandAs(column_type_index); @@ -831,6 +882,9 @@ spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpTypeCooperativeMatrixKHR: if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error; break; + case spv::Op::OpTypeCooperativeVectorNV: + if (auto error = ValidateTypeCooperativeVectorNV(_, inst)) return error; + break; case spv::Op::OpTypeUntypedPointerKHR: if (auto error = ValidateTypeUntypedPointerKHR(_, inst)) return error; break; diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 2da95ebaea..398f9b5a38 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -883,6 +883,7 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const { case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: return inst->word(2); default: @@ -911,6 +912,7 @@ uint32_t ValidationState_t::GetDimension(uint32_t id) const { case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: // Actual dimension isn't known, return 0 return 0; @@ -1292,6 +1294,27 @@ bool ValidationState_t::IsUnsigned64BitHandle(uint32_t id) const { GetBitWidth(id) == 32)); } +bool ValidationState_t::IsCooperativeVectorNVType(uint32_t id) const { + const Instruction* inst = FindDef(id); + return inst && inst->opcode() == spv::Op::OpTypeCooperativeVectorNV; +} + +bool ValidationState_t::IsFloatCooperativeVectorNVType(uint32_t id) const { + if (!IsCooperativeVectorNVType(id)) return false; + return IsFloatScalarType(FindDef(id)->word(2)); +} + +bool ValidationState_t::IsIntCooperativeVectorNVType(uint32_t id) const { + if (!IsCooperativeVectorNVType(id)) return false; + return IsIntScalarType(FindDef(id)->word(2)); +} + +bool ValidationState_t::IsUnsignedIntCooperativeVectorNVType( + uint32_t id) const { + if (!IsCooperativeVectorNVType(id)) return false; + return IsUnsignedIntScalarType(FindDef(id)->word(2)); +} + spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( const Instruction* inst, uint32_t result_type_id, uint32_t m2, bool is_conversion, bool swap_row_col) { @@ -1375,6 +1398,36 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch( return SPV_SUCCESS; } +spv_result_t ValidationState_t::CooperativeVectorDimensionsMatch( + const Instruction* inst, uint32_t v1, uint32_t v2) { + const auto v1_type = FindDef(v1); + const auto v2_type = FindDef(v2); + + if (v1_type->opcode() != v2_type->opcode()) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected cooperative vector types"; + } + + uint32_t v1_components_id = v1_type->GetOperandAs(2); + uint32_t v2_components_id = v2_type->GetOperandAs(2); + + bool v1_is_int32 = false, v1_is_const_int32 = false, v2_is_int32 = false, + v2_is_const_int32 = false; + uint32_t v1_value = 0, v2_value = 0; + + std::tie(v1_is_int32, v1_is_const_int32, v1_value) = + EvalInt32IfConst(v1_components_id); + std::tie(v2_is_int32, v2_is_const_int32, v2_value) = + EvalInt32IfConst(v2_components_id); + + if (v1_is_const_int32 && v2_is_const_int32 && v1_value != v2_value) { + return diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of components to be identical"; + } + + return SPV_SUCCESS; +} + uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst, size_t operand_index) const { return GetTypeId(inst->GetOperandAs(operand_index)); @@ -1667,6 +1720,7 @@ bool ValidationState_t::ContainsType( case spv::Op::OpTypeSampledImage: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: + case spv::Op::OpTypeCooperativeVectorNV: return ContainsType(inst->GetOperandAs(1u), f, traverse_all_types); case spv::Op::OpTypePointer: diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 77e7f43538..cee3d9b2ce 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -662,6 +662,10 @@ class ValidationState_t { bool IsIntCooperativeMatrixType(uint32_t id) const; bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const; bool IsUnsigned64BitHandle(uint32_t id) const; + bool IsCooperativeVectorNVType(uint32_t id) const; + bool IsFloatCooperativeVectorNVType(uint32_t id) const; + bool IsIntCooperativeVectorNVType(uint32_t id) const; + bool IsUnsignedIntCooperativeVectorNVType(uint32_t id) const; // Returns true if |id| is a type id that contains |type| (or integer or // floating point type) of |width| bits. @@ -801,6 +805,9 @@ class ValidationState_t { uint32_t m2, bool is_conversion, bool swap_row_col = false); + spv_result_t CooperativeVectorDimensionsMatch(const Instruction* inst, + uint32_t v1, uint32_t v2); + // Returns true if |lhs| and |rhs| logically match and, if the decorations of // |rhs| are a subset of |lhs|. // diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index 08566fe7f8..0a8e0c7d8e 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -176,6 +176,7 @@ std::vector> GenerateAllTypes() { types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002)); types.emplace_back(new RayQueryKHR()); types.emplace_back(new HitObjectNV()); + types.emplace_back(new CooperativeVectorNV(f32, 16)); // SPV_AMDX_shader_enqueue types.emplace_back(new NodePayloadArrayAMDX(sts32f32)); diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp index 8b2a8d0b78..42f4ce7d4a 100644 --- a/test/val/val_arithmetics_test.cpp +++ b/test/val/val_arithmetics_test.cpp @@ -1851,6 +1851,167 @@ OpFunctionEnd EXPECT_THAT(getDiagnosticString(), HasSubstr("must be a 32-bit integer")); } +std::string GenerateCoopVecCode(const std::string& extra_types, + const std::string& main_body) { + const std::string prefix = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_EXT_replicated_composites" +%ext_inst = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_16 = OpConstant %u32 16 +%u32_4 = OpConstant %u32 4 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_8 +%f16vec4 = OpTypeCooperativeVectorNV %f16 %u32_4 +%u32vec = OpTypeCooperativeVectorNV %u32 %u32_8 +%s32vec = OpTypeCooperativeVectorNV %s32 %u32_8 + +%f16_1 = OpConstant %f16 1 +%f32_1 = OpConstant %f32 1 +%u32_1 = OpConstant %u32 1 +%s32_1 = OpConstant %s32 1 + +%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1 +%f16vec_1 = OpConstantComposite %f16vec %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 +%u32vec_1 = OpConstantComposite %u32vec %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 +%s32vec_1 = OpConstantComposite %s32vec %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 + +%u32_c1 = OpSpecConstant %u32 1 +%u32_c2 = OpSpecConstant %u32 2 + +%f16vecc = OpTypeCooperativeVectorNV %f16 %u32_c1 +%f16vecc_1 = OpConstantCompositeReplicateEXT %f16vecc %f16_1 +)"; + + const std::string func_begin = + R"( +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + return prefix + extra_types + func_begin + main_body + suffix; +} + +TEST_F(ValidateArithmetics, CoopVecSuccess) { + const std::string body = R"( +%val1 = OpFAdd %f16vec %f16vec_1 %f16vec_1 +%val2 = OpFSub %f16vec %f16vec_1 %f16vec_1 +%val3 = OpFDiv %f16vec %f16vec_1 %f16vec_1 +%val4 = OpFNegate %f16vec %f16vec_1 +%val5 = OpIAdd %u32vec %u32vec_1 %u32vec_1 +%val6 = OpISub %u32vec %u32vec_1 %u32vec_1 +%val7 = OpUDiv %u32vec %u32vec_1 %u32vec_1 +%val8 = OpIAdd %s32vec %s32vec_1 %s32vec_1 +%val9 = OpISub %s32vec %s32vec_1 %s32vec_1 +%val10 = OpSDiv %s32vec %s32vec_1 %s32vec_1 +%val11 = OpSNegate %s32vec %s32vec_1 +%val12 = OpVectorTimesScalar %f16vec %f16vec_1 %f16_1 +%val13 = OpExtInst %f16vec %ext_inst FMin %f16vec_1 %f16vec_1 +%val14 = OpExtInst %f16vec %ext_inst FMax %f16vec_1 %f16vec_1 +%val15 = OpExtInst %f16vec %ext_inst FClamp %f16vec_1 %f16vec_1 %f16vec_1 +%val16 = OpExtInst %f16vec %ext_inst NClamp %f16vec_1 %f16vec_1 %f16vec_1 +%val17 = OpExtInst %f16vec %ext_inst Step %f16vec_1 %f16vec_1 +%val18 = OpExtInst %f16vec %ext_inst Exp %f16vec_1 +%val19 = OpExtInst %f16vec %ext_inst Log %f16vec_1 +%val20 = OpExtInst %f16vec %ext_inst Tanh %f16vec_1 +%val21 = OpExtInst %f16vec %ext_inst Atan %f16vec_1 +%val22 = OpExtInst %f16vec %ext_inst Fma %f16vec_1 %f16vec_1 %f16vec_1 +%val23 = OpExtInst %u32vec %ext_inst UMin %u32vec_1 %u32vec_1 +%val24 = OpExtInst %u32vec %ext_inst UMax %u32vec_1 %u32vec_1 +%val25 = OpExtInst %u32vec %ext_inst UClamp %u32vec_1 %u32vec_1 %u32vec_1 +%val26 = OpExtInst %s32vec %ext_inst SMin %s32vec_1 %s32vec_1 +%val27 = OpExtInst %s32vec %ext_inst SMax %s32vec_1 %s32vec_1 +%val28 = OpExtInst %s32vec %ext_inst SClamp %s32vec_1 %s32vec_1 %s32vec_1 +%val29 = OpShiftRightLogical %u32vec %u32vec_1 %u32vec_1 +%val30 = OpShiftRightArithmetic %u32vec %u32vec_1 %u32vec_1 +%val31 = OpShiftLeftLogical %u32vec %u32vec_1 %u32vec_1 +%val32 = OpBitwiseOr %u32vec %u32vec_1 %u32vec_1 +%val33 = OpBitwiseXor %u32vec %u32vec_1 %u32vec_1 +%val34 = OpBitwiseAnd %u32vec %u32vec_1 %u32vec_1 +%val35 = OpNot %u32vec %u32vec_1 +)"; + + CompileSuccessfully(GenerateCoopVecCode("", body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, CoopVecFMulPass) { + const std::string body = R"( +%val1 = OpFMul %f16vec %f16vec_1 %f16vec_1 +)"; + + CompileSuccessfully(GenerateCoopVecCode("", body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, CoopVecVectorTimesScalarMismatchFail) { + const std::string body = R"( +%val1 = OpVectorTimesScalar %f16vec %f16vec_1 %f32_1 +)"; + + CompileSuccessfully(GenerateCoopVecCode("", body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected scalar operand type to be equal to the component " + "type of the vector operand: VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, CoopVecDimFail) { + const std::string body = R"( +%val1 = OpFMul %f16vec %f16vec_1 %f16vec4_1 +)"; + + CompileSuccessfully(GenerateCoopVecCode("", body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of components to be identical")); +} + +TEST_F(ValidateArithmetics, CoopVecComponentTypeNotScalarNumeric) { + const std::string types = R"( +%bad = OpTypeCooperativeVectorNV %bool %u32_8 +)"; + + CompileSuccessfully(GenerateCoopVecCode(types, "").c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeCooperativeVectorNV Component Type " + "'5[%bool]' is not a scalar numerical type.")); +} + +TEST_F(ValidateArithmetics, CoopVecDimNotConstantInt) { + const std::string types = R"( +%bad = OpTypeCooperativeVectorNV %f16 %f32_1 +)"; + + CompileSuccessfully(GenerateCoopVecCode(types, "").c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeCooperativeVectorNV component count " + "'19[%float_1]' is not a constant integer type")); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp index 6e0d7c03c7..460ae5873a 100644 --- a/test/val/val_composites_test.cpp +++ b/test/val/val_composites_test.cpp @@ -2102,6 +2102,196 @@ TEST_F(ValidateComposites, CopyObjectVoid) { HasSubstr("OpCopyObject cannot have void result type")); } +TEST_F(ValidateComposites, CoopVecConstantCompositePass) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%useA = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_16 + +%f16_1 = OpConstant %f16 1 + +%f16vec_1 = OpConstantComposite %f16vec %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CoopVecConstantCompositeMismatchFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%useA = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_16 + +%f32_1 = OpConstant %f32 1 + +%f16vec_1 = OpConstantComposite %f16vec %f32_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpConstantComposite Constituent count does not match " + "Result Type '11[%11]'s vector component count")); +} + +TEST_F(ValidateComposites, CoopVecCompositeConstructPass) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%useA = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_16 + +%f16_1 = OpConstant %f16 1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%f16vec_1 = OpCompositeConstruct %f16vec %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CoopVecCompositeConstructMismatchFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_16 = OpConstant %u32 16 +%useA = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_16 + +%f32_1 = OpConstant %f32 1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%f16vec_1 = OpCompositeConstruct %f16vec %f32_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Constituents to be scalars or vectors of the " + "same type as Result Type components")); +} + +TEST_F(ValidateComposites, CoopVecInsertExtractDynamicPass) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%u32_1 = OpConstant %u32 1 +%u32_16 = OpConstant %u32 16 +%useA = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_16 + +%f16_1 = OpConstant %f16 1 +%f16vec_1 = OpConstantComposite %f16vec %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%insert = OpVectorInsertDynamic %f16vec %f16vec_1 %f16_1 %u32_1 +%extract = OpVectorExtractDynamic %f16 %insert %u32_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp index 3869626ef0..69d4045341 100644 --- a/test/val/val_conversion_test.cpp +++ b/test/val/val_conversion_test.cpp @@ -2326,6 +2326,133 @@ OpFunctionEnd)"; "swapped with columns")); } +TEST_F(ValidateConversion, CoopVecConversionSuccess) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeVectorNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_EXT_replicated_composites" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%use_A = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_8 +%f32vec = OpTypeCooperativeVectorNV %f32 %u32_8 +%u16vec = OpTypeCooperativeVectorNV %u16 %u32_8 +%u32vec = OpTypeCooperativeVectorNV %u32 %u32_8 +%s16vec = OpTypeCooperativeVectorNV %s16 %u32_8 +%s32vec = OpTypeCooperativeVectorNV %s32 %u32_8 + +%f16_1 = OpConstant %f16 1 +%f32_1 = OpConstant %f32 1 +%u16_1 = OpConstant %u16 1 +%u32_1 = OpConstant %u32 1 +%s16_1 = OpConstant %s16 1 +%s32_1 = OpConstant %s32 1 + +%f16vec_1 = OpConstantCompositeReplicateEXT %f16vec %f16_1 +%f32vec_1 = OpConstantCompositeReplicateEXT %f32vec %f32_1 +%u16vec_1 = OpConstantCompositeReplicateEXT %u16vec %u16_1 +%u32vec_1 = OpConstantCompositeReplicateEXT %u32vec %u32_1 +%s16vec_1 = OpConstantCompositeReplicateEXT %s16vec %s16_1 +%s32vec_1 = OpConstantCompositeReplicateEXT %s32vec %s32_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val11 = OpConvertFToU %u16vec %f16vec_1 +%val12 = OpConvertFToU %u32vec %f16vec_1 +%val13 = OpConvertFToS %s16vec %f16vec_1 +%val14 = OpConvertFToS %s32vec %f16vec_1 +%val15 = OpFConvert %f32vec %f16vec_1 + +%val21 = OpConvertFToU %u16vec %f32vec_1 +%val22 = OpConvertFToU %u32vec %f32vec_1 +%val23 = OpConvertFToS %s16vec %f32vec_1 +%val24 = OpConvertFToS %s32vec %f32vec_1 +%val25 = OpFConvert %f16vec %f32vec_1 + +%val31 = OpConvertUToF %f16vec %u16vec_1 +%val32 = OpConvertUToF %f32vec %u16vec_1 +%val33 = OpUConvert %u32vec %u16vec_1 +%val34 = OpSConvert %s32vec %u16vec_1 + +%val41 = OpConvertSToF %f16vec %s16vec_1 +%val42 = OpConvertSToF %f32vec %s16vec_1 +%val43 = OpUConvert %u32vec %s16vec_1 +%val44 = OpSConvert %s32vec %s16vec_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, CoopVecConversionDimMismatchFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int16 +OpCapability CooperativeVectorNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_EXT_replicated_composites" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f16 = OpTypeFloat 16 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%s16 = OpTypeInt 16 1 +%s32 = OpTypeInt 32 1 + +%u32_8 = OpConstant %u32 8 +%u32_4 = OpConstant %u32 4 +%subgroup = OpConstant %u32 3 +%use_A = OpConstant %u32 0 +%use_B = OpConstant %u32 1 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_8 +%f32vec = OpTypeCooperativeVectorNV %f32 %u32_4 + +%f16_1 = OpConstant %f16 1 + +%f16vec_1 = OpConstantCompositeReplicateEXT %f16vec %f16_1 + +%main = OpFunction %void None %func +%main_entry = OpLabel + +%val1 = OpFConvert %f32vec %f16vec_1 + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of components to be identical")); +} } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index 06eaff0951..e4cd470e2b 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -7835,6 +7835,467 @@ OpFunctionEnd EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env)); } +std::string GenCoopVecLoadStoreShader(const std::string& storeMemoryAccess, + const std::string& loadMemoryAccess) { + std::string s = R"( +OpCapability Shader +OpCapability Float16 +OpCapability StorageBuffer16BitAccess +OpCapability VulkanMemoryModel +OpCapability CooperativeVectorNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_EXT_replicated_composites" +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_NV_cooperative_vector" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %4 "main" %48 %73 +OpExecutionMode %4 LocalSize 1 1 1 + +OpDecorate %45 ArrayStride 2 +OpDecorate %46 Block +OpMemberDecorate %46 0 Offset 0 +OpDecorate %48 Binding 0 +OpDecorate %48 DescriptorSet 0 + +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 0 +%49 = OpTypeInt 32 1 +%41 = OpTypeFloat 16 + +%14 = OpConstant %6 1 +%50 = OpConstant %49 0 +%82 = OpConstant %6 5 + +%42 = OpTypeCooperativeVectorNV %41 %14 +%43 = OpTypePointer Function %42 + +%45 = OpTypeRuntimeArray %41 +%46 = OpTypeStruct %45 +%47 = OpTypePointer StorageBuffer %46 +%48 = OpVariable %47 StorageBuffer +%51 = OpTypePointer StorageBuffer %45 + +%57 = OpTypePointer Private %42 +%73 = OpVariable %57 Private + +%4 = OpFunction %2 None %3 +%5 = OpLabel +%52 = OpAccessChain %51 %48 %50 +%56 = OpCooperativeVectorLoadNV %42 %52 %50 )" + + loadMemoryAccess + R"( %82 +%77 = OpLoad %42 %73 +OpCooperativeVectorStoreNV %52 %50 %77 )" + storeMemoryAccess + R"( %82 +OpReturn +OpFunctionEnd +)"; + + return s; +} + +TEST_F(ValidateMemory, CoopVecLoadStoreSuccess) { + std::string spirv = + GenCoopVecLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); +} + +TEST_F(ValidateMemory, CoopVecStoreMemoryAccessFail) { + std::string spirv = + GenCoopVecLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerVisibleKHR cannot be used with OpStore")); +} + +TEST_F(ValidateMemory, CoopVecLoadMemoryAccessFail) { + std::string spirv = + GenCoopVecLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerAvailableKHR|NonPrivatePointerKHR"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad")); +} + +TEST_F(ValidateMemory, CoopVecInvalidStorageClassFail) { + const std::string body = R"( +OpCapability Shader +OpCapability Float16 +OpCapability CooperativeVectorNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_NV_cooperative_vector" +OpExtension "SPV_EXT_replicated_composites" +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%f16 = OpTypeFloat 16 +%u32 = OpTypeInt 32 0 + +%u32_8 = OpConstant %u32 8 +%use_A = OpConstant %u32 0 +%subgroup = OpConstant %u32 3 + +%f16vec = OpTypeCooperativeVectorNV %f16 %u32_8 + +%str = OpTypeStruct %f16vec +%str_ptr = OpTypePointer Workgroup %str +%sh = OpVariable %str_ptr Workgroup + +%main = OpFunction %void None %func +%main_entry = OpLabel + +OpReturn +OpFunctionEnd)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Cooperative vector types (or types containing them) can only be " + "allocated in Function or Private storage classes or as function " + "parameters")); +} + +std::string GenCoopVecShader(const std::string& extra_types, + const std::string& main_body) { + const std::string prefix = + R"( +OpCapability Shader +OpCapability Float16 +OpCapability Int64 +OpCapability StorageBuffer16BitAccess +OpCapability VulkanMemoryModel +OpCapability CooperativeVectorNV +OpCapability CooperativeVectorTrainingNV +OpCapability ReplicatedCompositesEXT +OpExtension "SPV_EXT_replicated_composites" +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_NV_cooperative_vector" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %main "main" %48 %73 +OpExecutionMode %main LocalSize 1 1 1 + +OpDecorate %f16_arr ArrayStride 2 +OpDecorate %46 Block +OpMemberDecorate %46 0 Offset 0 +OpDecorate %48 Binding 0 +OpDecorate %48 DescriptorSet 0 + +%void = OpTypeVoid +%func = OpTypeFunction %void +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f16 = OpTypeFloat 16 +%bool = OpTypeBool + +%false = OpConstantFalse %bool +%u32_4 = OpConstant %u32 4 +%u32_8 = OpConstant %u32 8 +%s32_0 = OpConstant %s32 0 +%f16_0 = OpConstant %f16 0 + +%f16vec4 = OpTypeCooperativeVectorNV %f16 %u32_4 +%f16vec8 = OpTypeCooperativeVectorNV %f16 %u32_8 + +%f16_arr = OpTypeRuntimeArray %f16 +%46 = OpTypeStruct %f16_arr +%47 = OpTypePointer StorageBuffer %46 +%48 = OpVariable %47 StorageBuffer +%51 = OpTypePointer StorageBuffer %f16_arr + +%57 = OpTypePointer Private %f16vec4 +%73 = OpVariable %57 Private +%u32ptr = OpTypePointer Function %u32 + +%input4 = OpConstantCompositeReplicateEXT %f16vec4 %f16_0 +%input8 = OpConstantCompositeReplicateEXT %f16vec8 %f16_0 +%interp = OpConstant %u32 0 +%offset = OpConstant %u32 0 + +)"; + + const std::string func_begin = + R"( +%main = OpFunction %void None %func +%main_entry = OpLabel +%u32var = OpVariable %u32ptr Function +%array_ptr = OpAccessChain %51 %48 %s32_0 +)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + return prefix + extra_types + func_begin + main_body + suffix; +} + +TEST_F(ValidateMemory, CoopVecMatMulSuccess) { + std::string spirv = GenCoopVecShader("", + R"( +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false +%result1 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input8 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_8 %s32_0 %false +%result2 = OpCooperativeVectorMatrixMulAddNV %f16vec8 %input4 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_8 %u32_4 %s32_0 %false +%result3 = OpCooperativeVectorMatrixMulNV %f16vec4 %input4 %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false +%result4 = OpCooperativeVectorMatrixMulNV %f16vec4 %input8 %interp %array_ptr %offset %interp %u32_4 %u32_8 %s32_0 %false +%result5 = OpCooperativeVectorMatrixMulNV %f16vec8 %input4 %interp %array_ptr %offset %interp %u32_8 %u32_4 %s32_0 %false + +OpCooperativeVectorReduceSumAccumulateNV %array_ptr %offset %input4 +OpCooperativeVectorOuterProductAccumulateNV %array_ptr %offset %input4 %input8 %interp %interp + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); +} + +TEST_F(ValidateMemory, CoopVecMatMulKMismatchFail) { + std::string spirv = GenCoopVecShader(R"()", + R"( +%result1 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input8 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV input number of " + "components 8 does not match K 4")); +} + +TEST_F(ValidateMemory, CoopVecMatMulPackedKMismatchPass) { + std::string spirv = GenCoopVecShader( + R"( +%packed = OpConstant %u32 1000491001 + )", + R"( +%result1 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input8 %packed %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); +} + +TEST_F(ValidateMemory, CoopVecMatMulMMismatchFail) { + std::string spirv = GenCoopVecShader(R"()", + R"( +%result1 = OpCooperativeVectorMatrixMulAddNV %f16vec8 %input8 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_8 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV result type number " + "of components 8 does not match M 4")); +} + +TEST_F(ValidateMemory, CoopVecMatMulTransposeTypeFail) { + std::string spirv = GenCoopVecShader(R"()", + R"( +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %interp %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %s32_0 + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV Transpose " + "'16[%int_0]' is not a scalar boolean")); +} + +TEST_F(ValidateMemory, CoopVecMatMulInputInterpretationNotConstantFail) { + std::string spirv = GenCoopVecShader( + R"( + )", + R"( +%u32val = OpLoad %u32 %u32var +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %u32val %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV InputInterpretation " + " '31[%31]' is not a constant instruction")); +} + +TEST_F(ValidateMemory, CoopVecMatMulMatrixInterpretationNotConstantFail) { + std::string spirv = GenCoopVecShader( + R"( + )", + R"( +%u32val = OpLoad %u32 %u32var +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %interp %array_ptr %offset %u32val %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV MatrixInterpretation " + "'31[%31]' is not a constant instruction")); +} + +TEST_F(ValidateMemory, CoopVecMatMulBiasInterpretationNotConstantFail) { + std::string spirv = GenCoopVecShader( + R"( + )", + R"( +%u32val = OpLoad %u32 %u32var +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %interp %array_ptr %offset %interp %array_ptr %offset %u32val %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV BiasInterpretation " + " '31[%31]' is not a constant instruction")); +} + +TEST_F(ValidateMemory, CoopVecMatMulInputInterpretationNotInt32Fail) { + std::string spirv = GenCoopVecShader( + R"( + )", + R"( +%result0 = OpCooperativeVectorMatrixMulAddNV %f16vec4 %input4 %false %array_ptr %offset %interp %array_ptr %offset %interp %u32_4 %u32_4 %s32_0 %false + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorMatrixMulAddNV InputInterpretation " + "type '12[%bool]' is not a 32 bit integer")); +} + +TEST_F(ValidateMemory, CoopVecOuterProductABMismatchFail) { + std::string spirv = GenCoopVecShader( + R"( +%f32 = OpTypeFloat 32 +%f32vec8 = OpTypeCooperativeVectorNV %f32 %u32_8 +%f32_0 = OpConstant %f32 0 +%input8f32 = OpConstantCompositeReplicateEXT %f32vec8 %f32_0 + )", + R"( +OpCooperativeVectorOuterProductAccumulateNV %array_ptr %offset %input4 %input8f32 %interp %interp + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpCooperativeVectorOuterProductAccumulateNV A and B component " + "types '11[%half]' and '28[%float]' do not match")); +} + +TEST_F(ValidateMemory, CoopVecOuterProductInt32OffsetFail) { + std::string spirv = GenCoopVecShader( + R"( +%u64 = OpTypeInt 64 0 +%u64_0 = OpConstant %u64 0 + )", + R"( +OpCooperativeVectorOuterProductAccumulateNV %array_ptr %u64_0 %input4 %input8 %interp %interp + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorOuterProductAccumulateNV Offset " + "type '28[%ulong]' is not a 32 bit integer")); +} + +TEST_F(ValidateMemory, CoopVecOuterProductInt32MatrixStrideFail) { + std::string spirv = GenCoopVecShader( + R"( +%u64 = OpTypeInt 64 0 +%u64_0 = OpConstant %u64 0 + )", + R"( +OpCooperativeVectorOuterProductAccumulateNV %array_ptr %offset %input4 %input8 %interp %interp %u64_0 + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpCooperativeVectorOuterProductAccumulateNV MatrixStride type " + " '28[%ulong]' is not a 32 bit integer")); +} + +TEST_F(ValidateMemory, CoopVecOuterProductVectorTypeFail) { + std::string spirv = GenCoopVecShader( + R"( +%f16v4 = OpTypeVector %f16 4 +%f16c = OpConstantCompositeReplicateEXT %f16v4 %f16_0 + )", + R"( +OpCooperativeVectorOuterProductAccumulateNV %array_ptr %offset %f16c %input8 %interp %interp + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorOuterProductAccumulateNV A type " + " '28[%v4half]' is not a cooperative vector type")); +} + +TEST_F(ValidateMemory, CoopVecReduceSumInt32OffsetFail) { + std::string spirv = GenCoopVecShader( + R"( +%u64 = OpTypeInt 64 0 +%u64_0 = OpConstant %u64 0 + )", + R"( +OpCooperativeVectorReduceSumAccumulateNV %array_ptr %u64_0 %input4 + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorReduceSumAccumulateNV Offset type " + " '28[%ulong]' is not a 32 bit integer")); +} + +TEST_F(ValidateMemory, CoopVecReduceSumVectorTypeFail) { + std::string spirv = GenCoopVecShader( + R"( +%f16v4 = OpTypeVector %f16 4 +%f16c = OpConstantCompositeReplicateEXT %f16v4 %f16_0 + )", + R"( +OpCooperativeVectorReduceSumAccumulateNV %array_ptr %offset %f16c + )"); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1_SPIRV_1_4); + ASSERT_EQ(SPV_ERROR_INVALID_ID, + ValidateInstructions(SPV_ENV_VULKAN_1_1_SPIRV_1_4)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpCooperativeVectorReduceSumAccumulateNV V type " + "'28[%v4half]' is not a cooperative vector type.")); +} } // namespace } // namespace val } // namespace spvtools