From 389149fb32bfbd8640bda49605520cd0da80a952 Mon Sep 17 00:00:00 2001 From: Jeff Noyle Date: Thu, 15 Dec 2022 16:36:09 -0800 Subject: [PATCH 1/7] PIX: Modify root sigs in place (plus fix root sig memory leak) (#4876) PIX is unique in that it needs to deserialize, modify and then reserialize root sigs. The focus of this checkin is adding such modifications to PixPassHelpers.cpp. (This closes a gap in PIX support: PIX can now support shaders that use dxil-defined and attribute-style root signatures.) But this required some work outside of the purely PIX-focused areas. Deserialized root sigs are described by a C-like structure with embedded arrays that are new/delete-ed by routines in DxilRootSignature.cpp. Since modifying these structs requires more new/delete calls, I chose to add a new entry point in DxilRootSignature.cpp to do the bare minimum that PIX needs: extend root params by one descriptor. This approach keeps all those raw new/deletes in one file. I found a leak in DxilRootSignatureSerialzier.cpp, which I fixed. There appear to be no unit tests that exercise this path. I welcome feedback on adding one. There were other leaks in class CVersionedRootSignatureDeserializer, but this class is unused so I deleted it. Oh, and there are bazillions of commits because I was cherry-picking from a recent change (#4845) as it eveolved, since I needed that change and this to test PIX. (cherry picked from commit 20bb3d0228abb576423f70d1d870f9c48342f668) --- .../dxc/DxilRootSignature/DxilRootSignature.h | 37 +++ lib/DxilPIXPasses/PixPassHelpers.cpp | 99 +++++- .../DxilRootSignatureSerializer.cpp | 87 ----- lib/HLSL/DxilValidation.cpp | 9 +- tools/clang/unittests/HLSL/PixTest.cpp | 311 ++++++++++++++---- 5 files changed, 384 insertions(+), 159 deletions(-) diff --git a/include/dxc/DxilRootSignature/DxilRootSignature.h b/include/dxc/DxilRootSignature/DxilRootSignature.h index 7b0c89ee99..87d0bf9b0f 100644 --- a/include/dxc/DxilRootSignature/DxilRootSignature.h +++ b/include/dxc/DxilRootSignature/DxilRootSignature.h @@ -385,6 +385,43 @@ bool VerifyRootSignature(_In_ const DxilVersionedRootSignatureDesc *pDesc, _In_ llvm::raw_ostream &DiagStream, _In_ bool bAllowReservedRegisterSpace); +class DxilVersionedRootSignature { + DxilVersionedRootSignatureDesc *m_pRootSignature; + +public: + // Non-copyable: + DxilVersionedRootSignature(DxilVersionedRootSignature const &) = delete; + DxilVersionedRootSignature const & + operator=(DxilVersionedRootSignature const &) = delete; + + // but movable: + DxilVersionedRootSignature(DxilVersionedRootSignature &&) = default; + DxilVersionedRootSignature & + operator=(DxilVersionedRootSignature &&) = default; + + DxilVersionedRootSignature() : m_pRootSignature(nullptr) {} + explicit DxilVersionedRootSignature( + const DxilVersionedRootSignatureDesc *pRootSignature) + : m_pRootSignature( + const_cast (pRootSignature)) {} + ~DxilVersionedRootSignature() { + DeleteRootSignature(m_pRootSignature); + } + const DxilVersionedRootSignatureDesc* operator -> () const { + return m_pRootSignature; + } + const DxilVersionedRootSignatureDesc ** get_address_of() { + if (m_pRootSignature != nullptr) + return nullptr; // You're probably about to leak... + return const_cast (&m_pRootSignature); + } + const DxilVersionedRootSignatureDesc* get() const { + return m_pRootSignature; + } + DxilVersionedRootSignatureDesc* get_mutable() const { + return m_pRootSignature; + } +}; } // namespace hlsl #endif // __DXC_ROOTSIGNATURE__ diff --git a/lib/DxilPIXPasses/PixPassHelpers.cpp b/lib/DxilPIXPasses/PixPassHelpers.cpp index 28a9bd19f1..fbc80ddb0b 100644 --- a/lib/DxilPIXPasses/PixPassHelpers.cpp +++ b/lib/DxilPIXPasses/PixPassHelpers.cpp @@ -9,10 +9,12 @@ #include "dxc/DXIL/DxilOperations.h" #include "dxc/DXIL/DxilInstructions.h" +#include "dxc/DXIL/DxilFunctionProps.h" #include "dxc/DXIL/DxilModule.h" #include "dxc/DXIL/DxilResourceBinding.h" #include "dxc/DXIL/DxilResourceProperties.h" #include "dxc/HLSL/DxilSpanAllocator.h" +#include "dxc/DxilRootSignature/DxilRootSignature.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -21,6 +23,10 @@ #include "PixPassHelpers.h" +#include "dxc/Support/Global.h" +#include "dxc/Support/WinIncludes.h" +#include "dxc/dxcapi.h" + #ifdef PIX_DEBUG_DUMP_HELPER #include #include "llvm/IR/DebugInfo.h" @@ -159,7 +165,93 @@ llvm::CallInst *CreateHandleForResource(hlsl::DxilModule &DM, } } -// Set up a UAV with structure of a single int +static std::vector SerializeRootSignatureToVector(DxilVersionedRootSignatureDesc const *rootSignature) { + CComPtr serializedRootSignature; + CComPtr errorBlob; + constexpr bool allowReservedRegisterSpace = true; + SerializeRootSignature(rootSignature, &serializedRootSignature, &errorBlob, + allowReservedRegisterSpace); + std::vector ret; + auto const *serializedData = reinterpret_cast( + serializedRootSignature->GetBufferPointer()); + ret.assign(serializedData, + serializedData + serializedRootSignature->GetBufferSize()); + + return ret; +} + +constexpr uint32_t toolsRegisterSpace = static_cast(-2); +constexpr uint32_t toolsUAVRegister = 0; + +template +void ExtendRootSig(RootSigDesc &rootSigDesc) { + auto *existingParams = rootSigDesc.pParameters; + auto *newParams = new RootParameterDesc[rootSigDesc.NumParameters + 1]; + if (existingParams != nullptr) { + memcpy(newParams, existingParams, + rootSigDesc.NumParameters * sizeof(RootParameterDesc)); + delete[] existingParams; + } + rootSigDesc.pParameters = newParams; + rootSigDesc.pParameters[rootSigDesc.NumParameters].ParameterType = DxilRootParameterType::UAV; + rootSigDesc.pParameters[rootSigDesc.NumParameters].Descriptor.RegisterSpace = toolsRegisterSpace; + rootSigDesc.pParameters[rootSigDesc.NumParameters].Descriptor.ShaderRegister = toolsUAVRegister; + rootSigDesc.pParameters[rootSigDesc.NumParameters].ShaderVisibility = DxilShaderVisibility::All; + rootSigDesc.NumParameters++; +} + +static std::vector AddUAVParamterToRootSignature(const void *Data, + uint32_t Size) { + DxilVersionedRootSignature rootSignature; + DeserializeRootSignature(Data, Size, rootSignature.get_address_of()); + auto *rs = rootSignature.get_mutable(); + switch (rootSignature->Version) { + case DxilRootSignatureVersion::Version_1_0: + ExtendRootSig(rs->Desc_1_0); + break; + case DxilRootSignatureVersion::Version_1_1: + ExtendRootSig(rs->Desc_1_1); + rs->Desc_1_1.pParameters[rs->Desc_1_1.NumParameters - 1].Descriptor.Flags = + hlsl::DxilRootDescriptorFlags::None; + break; + } + return SerializeRootSignatureToVector(rs); +} + +static void AddUAVToShaderAttributeRootSignature(DxilModule &DM) { + auto rs = DM.GetSerializedRootSignature(); + if(!rs.empty()) { + std::vector asVector = AddUAVParamterToRootSignature(rs.data(), static_cast(rs.size())); + DM.ResetSerializedRootSignature(asVector); + } +} + +static void AddUAVToDxilDefinedGlobalRootSignatures(DxilModule& DM) { + auto *subObjects = DM.GetSubobjects(); + if (subObjects != nullptr) { + for (auto const &subObject : subObjects->GetSubobjects()) { + if (subObject.second->GetKind() == + DXIL::SubobjectKind::GlobalRootSignature) { + const void *Data = nullptr; + uint32_t Size = 0; + constexpr bool notALocalRS = false; + if (subObject.second->GetRootSignature(notALocalRS, Data, Size, + nullptr)) { + auto extendedRootSig = AddUAVParamterToRootSignature(Data, Size); + auto rootSignatureSubObjectName = subObject.first; + subObjects->RemoveSubobject(rootSignatureSubObjectName); + subObjects->CreateRootSignature(rootSignatureSubObjectName, + notALocalRS, + extendedRootSig.data(), + static_cast(extendedRootSig.size())); + break; + } + } + } + } +} + + // Set up a UAV with structure of a single int llvm::CallInst *CreateUAV(DxilModule &DM, IRBuilder<> &Builder, unsigned int registerId, const char *name) { LLVMContext &Ctx = DM.GetModule()->getContext(); @@ -170,6 +262,11 @@ llvm::CallInst *CreateUAV(DxilModule &DM, IRBuilder<> &Builder, if (UAVStructTy == nullptr) { SmallVector Elements{Type::getInt32Ty(Ctx)}; UAVStructTy = llvm::StructType::create(Elements, PIXStructTypeName); + + // Since we only have to do this once per module, we can do it now when + // we're adding the singular UAV structure type to the module: + AddUAVToDxilDefinedGlobalRootSignatures(DM); + AddUAVToShaderAttributeRootSignature(DM); } std::unique_ptr pUAV = llvm::make_unique(); diff --git a/lib/DxilRootSignature/DxilRootSignatureSerializer.cpp b/lib/DxilRootSignature/DxilRootSignatureSerializer.cpp index 8100626cb6..a83beb034a 100644 --- a/lib/DxilRootSignature/DxilRootSignatureSerializer.cpp +++ b/lib/DxilRootSignature/DxilRootSignatureSerializer.cpp @@ -331,93 +331,6 @@ void SerializeRootSignature(const DxilVersionedRootSignatureDesc *pRootSignature } } -//============================================================================= -// -// CVersionedRootSignatureDeserializer. -// -//============================================================================= -class CVersionedRootSignatureDeserializer { -protected: - const DxilVersionedRootSignatureDesc *m_pRootSignature; - const DxilVersionedRootSignatureDesc *m_pRootSignature10; - const DxilVersionedRootSignatureDesc *m_pRootSignature11; - -public: - CVersionedRootSignatureDeserializer(); - ~CVersionedRootSignatureDeserializer(); - - void Initialize(_In_reads_bytes_(SrcDataSizeInBytes) const void *pSrcData, - _In_ uint32_t SrcDataSizeInBytes); - - const DxilVersionedRootSignatureDesc *GetRootSignatureDescAtVersion(DxilRootSignatureVersion convertToVersion); - - const DxilVersionedRootSignatureDesc *GetUnconvertedRootSignatureDesc(); -}; - -CVersionedRootSignatureDeserializer::CVersionedRootSignatureDeserializer() - : m_pRootSignature(nullptr) - , m_pRootSignature10(nullptr) - , m_pRootSignature11(nullptr) { -} - -CVersionedRootSignatureDeserializer::~CVersionedRootSignatureDeserializer() { - DeleteRootSignature(m_pRootSignature10); - DeleteRootSignature(m_pRootSignature11); -} - -void CVersionedRootSignatureDeserializer::Initialize(_In_reads_bytes_(SrcDataSizeInBytes) const void *pSrcData, - _In_ uint32_t SrcDataSizeInBytes) { - const DxilVersionedRootSignatureDesc *pRootSignature = nullptr; - DeserializeRootSignature(pSrcData, SrcDataSizeInBytes, &pRootSignature); - - switch (pRootSignature->Version) { - case DxilRootSignatureVersion::Version_1_0: - m_pRootSignature10 = pRootSignature; - break; - - case DxilRootSignatureVersion::Version_1_1: - m_pRootSignature11 = pRootSignature; - break; - - default: - DeleteRootSignature(pRootSignature); - return; - } - - m_pRootSignature = pRootSignature; -} - -const DxilVersionedRootSignatureDesc * -CVersionedRootSignatureDeserializer::GetUnconvertedRootSignatureDesc() { - return m_pRootSignature; -} - -const DxilVersionedRootSignatureDesc * -CVersionedRootSignatureDeserializer::GetRootSignatureDescAtVersion(DxilRootSignatureVersion ConvertToVersion) { - switch (ConvertToVersion) { - case DxilRootSignatureVersion::Version_1_0: - if (m_pRootSignature10 == nullptr) { - ConvertRootSignature(m_pRootSignature, - ConvertToVersion, - (const DxilVersionedRootSignatureDesc **)&m_pRootSignature10); - } - return m_pRootSignature10; - - case DxilRootSignatureVersion::Version_1_1: - if (m_pRootSignature11 == nullptr) { - ConvertRootSignature(m_pRootSignature, - ConvertToVersion, - (const DxilVersionedRootSignatureDesc **)&m_pRootSignature11); - } - return m_pRootSignature11; - - default: - IFTBOOL(false, E_FAIL); - } - - return nullptr; -} - templateReserve(pWriter->size()); pWriter->write(pOutputStream); + DxilVersionedRootSignature desc; try { - const DxilVersionedRootSignatureDesc* pDesc = nullptr; - DeserializeRootSignature(SerializedRootSig.data(), SerializedRootSig.size(), &pDesc); - if (!pDesc) { + DeserializeRootSignature(SerializedRootSig.data(), + SerializedRootSig.size(), desc.get_address_of()); + if (!desc.get()) { return DXC_E_INCORRECT_ROOT_SIGNATURE; } - IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc, + IFTBOOL(VerifyRootSignatureWithShaderPSV(desc.get(), dxilModule.GetShaderModel()->GetKind(), pOutputStream->GetPtr(), pWriter->size(), DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE); diff --git a/tools/clang/unittests/HLSL/PixTest.cpp b/tools/clang/unittests/HLSL/PixTest.cpp index 10bcdd6c23..a96cec8444 100644 --- a/tools/clang/unittests/HLSL/PixTest.cpp +++ b/tools/clang/unittests/HLSL/PixTest.cpp @@ -25,7 +25,12 @@ #include #include #include + +#include "dxc/Support/WinIncludes.h" + #include "dxc/DxilContainer/DxilContainer.h" +#include "dxc/DxilContainer/DxilRuntimeReflection.h" +#include "dxc/DxilRootSignature/DxilRootSignature.h" #include "dxc/Support/WinIncludes.h" #include "dxc/dxcapi.h" #include "dxc/dxcpix.h" @@ -195,10 +200,6 @@ class PixTest { TEST_METHOD(DiaCompileArgs) TEST_METHOD(PixDebugCompileInfo) - TEST_METHOD(CheckSATPassFor66_NoDynamicAccess) - TEST_METHOD(CheckSATPassFor66_DynamicFromRootSig) - TEST_METHOD(CheckSATPassFor66_DynamicFromHeap) - TEST_METHOD(AddToASPayload) TEST_METHOD(PixStructAnnotation_Lib_DualRaygen) @@ -223,6 +224,9 @@ class PixTest { TEST_METHOD(VirtualRegisters_InstructionCounts) + TEST_METHOD(RootSignatureUpgrade_SubObjects) + TEST_METHOD(RootSignatureUpgrade_Annotation) + dxc::DxcDllSupport m_dllSupport; VersionSupportInfo m_ver; @@ -1003,8 +1007,9 @@ class PixTest { bool validateCoverage = true, const wchar_t *profile = L"as_6_5"); void ValidateAllocaWrite(std::vector const& allocaWrites, size_t index, const char* name); - std::string RunShaderAccessTrackingPassAndReturnOutputMessages(IDxcBlob* blob); - std::string RunDxilPIXAddTidToAmplificationShaderPayloadPass(IDxcBlob* blob); + CComPtr RunShaderAccessTrackingPass(IDxcBlob* blob); + std::string RunDxilPIXAddTidToAmplificationShaderPayloadPass(IDxcBlob * + blob); CComPtr RunDxilPIXMeshShaderOutputPass(IDxcBlob* blob); }; @@ -1665,81 +1670,35 @@ TEST_F(PixTest, PixDebugCompileInfo) { VERIFY_ARE_EQUAL(std::wstring(profile), std::wstring(hlslTarget)); } -std::string PixTest::RunShaderAccessTrackingPassAndReturnOutputMessages(IDxcBlob* blob) -{ - CComPtr dxil = FindModule(DFCC_ShaderDebugInfoDXIL, blob); +CComPtr PixTest::RunShaderAccessTrackingPass(IDxcBlob *blob) { CComPtr pOptimizer; VERIFY_SUCCEEDED( m_dllSupport.CreateInstance(CLSID_DxcOptimizer, &pOptimizer)); std::vector Options; Options.push_back(L"-opt-mod-passes"); - Options.push_back(L"-hlsl-dxil-pix-shader-access-instrumentation,config=,checkForDynamicIndexing=1"); + Options.push_back(L"-hlsl-dxil-pix-shader-access-instrumentation,config="); CComPtr pOptimizedModule; CComPtr pText; VERIFY_SUCCEEDED(pOptimizer->RunOptimizer( - dxil, Options.data(), Options.size(), &pOptimizedModule, &pText)); - - std::string outputText; - if (pText->GetBufferSize() != 0) { - outputText = reinterpret_cast(pText->GetBufferPointer()); - } + blob, Options.data(), Options.size(), &pOptimizedModule, &pText)); - return outputText; -} - -TEST_F(PixTest, CheckSATPassFor66_NoDynamicAccess) { - - const char *noDynamicAccess = R"( - [RootSignature("")] - float main(float pos : A) : SV_Target { - float x = abs(pos); - float y = sin(pos); - float z = x + y; - return z; - } - )"; - - auto compiled = Compile(noDynamicAccess, L"ps_6_6"); - auto satResults = RunShaderAccessTrackingPassAndReturnOutputMessages(compiled); - VERIFY_IS_TRUE(satResults.empty()); -} - -TEST_F(PixTest, CheckSATPassFor66_DynamicFromRootSig) { - - const char *dynamicTextureAccess = R"x( -Texture1D tex[5] : register(t3); -SamplerState SS[3] : register(s2); - -[RootSignature("DescriptorTable(SRV(t3, numDescriptors=5)),\ - DescriptorTable(Sampler(s2, numDescriptors=3))")] -float4 main(int i : A, float j : B) : SV_TARGET -{ - float4 r = tex[i].Sample(SS[i], i); - return r; -} - )x"; + CComPtr pAssembler; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcAssembler, &pAssembler)); - auto compiled = Compile(dynamicTextureAccess, L"ps_6_6"); - auto satResults = RunShaderAccessTrackingPassAndReturnOutputMessages(compiled); - VERIFY_IS_TRUE(satResults.find("FoundDynamicIndexing") != string::npos); -} + CComPtr pAssembleResult; + VERIFY_SUCCEEDED( + pAssembler->AssembleToContainer(pOptimizedModule, &pAssembleResult)); -TEST_F(PixTest, CheckSATPassFor66_DynamicFromHeap) { + HRESULT hr; + VERIFY_SUCCEEDED(pAssembleResult->GetStatus(&hr)); + VERIFY_SUCCEEDED(hr); - const char *dynamicResourceDecriptorHeapAccess = R"( -static sampler sampler0 = SamplerDescriptorHeap[0]; -float4 main(int input : INPUT) : SV_Target -{ - Texture2D texture = ResourceDescriptorHeap[input]; - return texture.Sample(sampler0, float2(0,0)); -} - )"; + CComPtr pNewContainer; + VERIFY_SUCCEEDED(pAssembleResult->GetResult(&pNewContainer)); - auto compiled = Compile(dynamicResourceDecriptorHeapAccess, L"ps_6_6"); - auto satResults = - RunShaderAccessTrackingPassAndReturnOutputMessages(compiled); - VERIFY_IS_TRUE(satResults.find("FoundDynamicIndexing") != string::npos); + return pNewContainer; } CComPtr PixTest::RunDxilPIXMeshShaderOutputPass(IDxcBlob *blob) { @@ -3426,4 +3385,222 @@ void MyMissShader(inout RayPayload payload) } } +static void VerifyOperationSucceeded(IDxcOperationResult *pResult) +{ + HRESULT result; + VERIFY_SUCCEEDED(pResult->GetStatus(&result)); + if (FAILED(result)) { + CComPtr pErrors; + VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErrors)); + CA2W errorsWide(BlobToUtf8(pErrors).c_str(), CP_UTF8); + WEX::Logging::Log::Comment(errorsWide); + } + VERIFY_SUCCEEDED(result); +} + +TEST_F(PixTest, RootSignatureUpgrade_SubObjects) { + + const char *source = R"x( +GlobalRootSignature so_GlobalRootSignature = +{ + "RootConstants(num32BitConstants=1, b8), " +}; + +StateObjectConfig so_StateObjectConfig = +{ + STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS +}; + +LocalRootSignature so_LocalRootSignature1 = +{ + "RootConstants(num32BitConstants=3, b2), " + "UAV(u6),RootFlags(LOCAL_ROOT_SIGNATURE)" +}; + +LocalRootSignature so_LocalRootSignature2 = +{ + "RootConstants(num32BitConstants=3, b2), " + "UAV(u8, flags=DATA_STATIC), " + "RootFlags(LOCAL_ROOT_SIGNATURE)" +}; + +RaytracingShaderConfig so_RaytracingShaderConfig = +{ + 128, // max payload size + 32 // max attribute size +}; + +RaytracingPipelineConfig so_RaytracingPipelineConfig = +{ + 2 // max trace recursion depth +}; + +TriangleHitGroup MyHitGroup = +{ + "MyAnyHit", // AnyHit + "MyClosestHit", // ClosestHit +}; + +SubobjectToExportsAssociation so_Association1 = +{ + "so_LocalRootSignature1", // subobject name + "MyRayGen" // export association +}; + +SubobjectToExportsAssociation so_Association2 = +{ + "so_LocalRootSignature2", // subobject name + "MyAnyHit" // export association +}; + +struct MyPayload +{ + float4 color; +}; + +[shader("raygeneration")] +void MyRayGen() +{ +} + +[shader("closesthit")] +void MyClosestHit(inout MyPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ +} + +[shader("anyhit")] +void MyAnyHit(inout MyPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ +} + +[shader("miss")] +void MyMiss(inout MyPayload payload) +{ +} + +)x"; + + CComPtr pCompiler; + VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler)); + + CComPtr pSource; + Utf8ToBlob(m_dllSupport, source, &pSource); + + CComPtr pResult; + VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", L"", L"lib_6_6", + nullptr, 0, nullptr, 0, nullptr, + &pResult)); + VerifyOperationSucceeded(pResult); + CComPtr compiled; + VERIFY_SUCCEEDED(pResult->GetResult(&compiled)); + + auto optimizedContainer = RunShaderAccessTrackingPass(compiled); + + const char *pBlobContent = + reinterpret_cast(optimizedContainer->GetBufferPointer()); + unsigned blobSize = optimizedContainer->GetBufferSize(); + const hlsl::DxilContainerHeader *pContainerHeader = + hlsl::IsDxilContainerLike(pBlobContent, blobSize); + + const hlsl::DxilPartHeader *pPartHeader = + GetDxilPartByType(pContainerHeader, hlsl::DFCC_RuntimeData); + VERIFY_ARE_NOT_EQUAL(pPartHeader, nullptr); + + hlsl::RDAT::DxilRuntimeData rdat(GetDxilPartData(pPartHeader), + pPartHeader->PartSize); + + auto const subObjectTableReader = rdat.GetSubobjectTable(); + + // There are 9 subobjects in the HLSL above: + VERIFY_ARE_EQUAL(subObjectTableReader.Count(), 9u); + + bool foundGlobalRS = false; + for (uint32_t i = 0; i < subObjectTableReader.Count(); ++i) { + auto subObject = subObjectTableReader[i]; + hlsl::DXIL::SubobjectKind subobjectKind = subObject.getKind(); + switch (subobjectKind) { + case hlsl::DXIL::SubobjectKind::GlobalRootSignature: { + foundGlobalRS = true; + VERIFY_IS_TRUE(0 == + strcmp(subObject.getName(), "so_GlobalRootSignature")); + + auto rootSigReader = subObject.getRootSignature(); + DxilVersionedRootSignatureDesc const *rootSignature = nullptr; + DeserializeRootSignature(rootSigReader.getData(), + rootSigReader.sizeData(), &rootSignature); + VERIFY_ARE_EQUAL(rootSignature->Version, + DxilRootSignatureVersion::Version_1_1); + VERIFY_ARE_EQUAL(rootSignature->Desc_1_1.NumParameters, 2); + VERIFY_ARE_EQUAL(rootSignature->Desc_1_1.pParameters[1].ParameterType, + DxilRootParameterType::UAV); + VERIFY_ARE_EQUAL(rootSignature->Desc_1_1.pParameters[1].ShaderVisibility, + DxilShaderVisibility::All); + VERIFY_ARE_EQUAL( + rootSignature->Desc_1_1.pParameters[1].Descriptor.RegisterSpace, + static_cast(-2)); + VERIFY_ARE_EQUAL( + rootSignature->Desc_1_1.pParameters[1].Descriptor.ShaderRegister, 0u); + DeleteRootSignature(rootSignature); + break; + } + } + } + VERIFY_IS_TRUE(foundGlobalRS); +} + +TEST_F(PixTest, RootSignatureUpgrade_Annotation) +{ + + const char *dynamicTextureAccess = R"x( +Texture1D tex[5] : register(t3); +SamplerState SS[3] : register(s2); + +[RootSignature("DescriptorTable(SRV(t3, numDescriptors=5)),\ + DescriptorTable(Sampler(s2, numDescriptors=3))")] +float4 main(int i : A, float j : B) : SV_TARGET +{ + float4 r = tex[i].Sample(SS[i], i); + return r; +} + )x"; + + auto compiled = Compile(dynamicTextureAccess, L"ps_6_6"); + auto pOptimizedContainer = RunShaderAccessTrackingPass(compiled); + + const char *pBlobContent = + reinterpret_cast(pOptimizedContainer->GetBufferPointer()); + unsigned blobSize = pOptimizedContainer->GetBufferSize(); + const hlsl::DxilContainerHeader *pContainerHeader = + hlsl::IsDxilContainerLike(pBlobContent, blobSize); + + const hlsl::DxilPartHeader *pPartHeader = + GetDxilPartByType(pContainerHeader, hlsl::DFCC_RootSignature); + VERIFY_ARE_NOT_EQUAL(pPartHeader, nullptr); + + hlsl::RootSignatureHandle RSH; + RSH.LoadSerialized((const uint8_t *)GetDxilPartData(pPartHeader), + pPartHeader->PartSize); + + RSH.Deserialize(); + + auto const *desc = RSH.GetDesc(); + + bool foundGlobalRS = false; + + VERIFY_ARE_EQUAL(desc->Version, hlsl::DxilRootSignatureVersion::Version_1_1); + VERIFY_ARE_EQUAL(desc->Desc_1_1.NumParameters, 3u); + for (unsigned int i = 0; i < desc->Desc_1_1.NumParameters; ++i) { + hlsl::DxilRootParameter1 const *param = desc->Desc_1_1.pParameters + i; + switch (param->ParameterType) { + case hlsl::DxilRootParameterType::UAV: + VERIFY_ARE_EQUAL(param->Descriptor.RegisterSpace, static_cast(-2)); + VERIFY_ARE_EQUAL(param->Descriptor.ShaderRegister, 0u); + foundGlobalRS = true; + break; + } + } + + VERIFY_IS_TRUE(foundGlobalRS); +} + #endif From 139576e85c108ee548ad0b5e8007f0aa4408cab2 Mon Sep 17 00:00:00 2001 From: Tex Riddell Date: Mon, 12 Dec 2022 18:54:10 -0800 Subject: [PATCH 2/7] dxcopt: Support full container and restore extra data to module (#4845) This modifies IDxcOptimizer::RunOptimizier to accept full DxilContainer input. When full container input is used, this restores some data that is stripped from the module and placed in various other container parts. Data restored: - Subobjects from RDAT - RootSignature from RTS0 - ViewID and I/O dependency data from PSV0 - Resource names and types/annotations from STAT Serialization of these to metadata in module bitcode output still requires hlsl-dxilemit step. (cherry picked from commit 2c3d965b2fc734c09e76631de65f8ffbe4e3e68b) --- include/dxc/DXIL/DxilModule.h | 3 + .../DxilContainer/DxilContainerAssembler.h | 11 + lib/DXIL/DxilModule.cpp | 56 + lib/DxilContainer/DxilContainerAssembler.cpp | 208 ++- lib/HLSL/DxcOptimizer.cpp | 156 ++- tools/clang/unittests/HLSL/OptimizerTest.cpp | 1231 ++++++++++++++++- 6 files changed, 1597 insertions(+), 68 deletions(-) diff --git a/include/dxc/DXIL/DxilModule.h b/include/dxc/DXIL/DxilModule.h index 8ed10d9a39..17a3ec4776 100644 --- a/include/dxc/DXIL/DxilModule.h +++ b/include/dxc/DXIL/DxilModule.h @@ -204,6 +204,9 @@ class DxilModule { void StripDebugRelatedCode(); void RemoveUnusedTypeAnnotations(); + // Copy resource reflection back to this module's resources. + void RestoreResourceReflection(const DxilModule &SourceDM); + // Helper to remove dx.* metadata with source and compile options. // If the parameter `bReplaceWithDummyData` is true, the named metadata // are replaced with valid empty data that satisfy tools. diff --git a/include/dxc/DxilContainer/DxilContainerAssembler.h b/include/dxc/DxilContainer/DxilContainerAssembler.h index b2ae968323..c51f97dbec 100644 --- a/include/dxc/DxilContainer/DxilContainerAssembler.h +++ b/include/dxc/DxilContainer/DxilContainerAssembler.h @@ -16,6 +16,7 @@ #include "llvm/ADT/StringRef.h" struct IStream; +class DxilPipelineStateValidation; namespace llvm { class Module; @@ -51,6 +52,16 @@ DxilPartWriter *NewFeatureInfoWriter(const DxilModule &M); DxilPartWriter *NewPSVWriter(const DxilModule &M, uint32_t PSVVersion = UINT_MAX); DxilPartWriter *NewRDATWriter(const DxilModule &M); +// Store serialized ViewID data from DxilModule to PipelineStateValidation. +void StoreViewIDStateToPSV(const uint32_t *pInputData, + unsigned InputSizeInUInts, + DxilPipelineStateValidation &PSV); +// Load ViewID state from PSV back to DxilModule view state vector. +// Pass nullptr for pOutputData to compute and return needed OutputSizeInUInts. +unsigned LoadViewIDStateFromPSV(unsigned *pOutputData, + unsigned OutputSizeInUInts, + const DxilPipelineStateValidation &PSV); + // Unaligned is for matching container for validator version < 1.7. DxilContainerWriter *NewDxilContainerWriter(bool bUnaligned = false); diff --git a/lib/DXIL/DxilModule.cpp b/lib/DXIL/DxilModule.cpp index d9f95bf0d8..676a25433d 100644 --- a/lib/DXIL/DxilModule.cpp +++ b/lib/DXIL/DxilModule.cpp @@ -1873,6 +1873,62 @@ void DxilModule::RemoveUnusedTypeAnnotations() { } +template +static void CopyResourceInfo(_T &TargetRes, const _T &SourceRes, + DxilTypeSystem &TargetTypeSys, + const DxilTypeSystem &SourceTypeSys) { + if (TargetRes.GetKind() != SourceRes.GetKind() || + TargetRes.GetLowerBound() != SourceRes.GetLowerBound() || + TargetRes.GetRangeSize() != SourceRes.GetRangeSize() || + TargetRes.GetSpaceID() != SourceRes.GetSpaceID()) { + DXASSERT(false, "otherwise, resource details don't match"); + return; + } + + if (TargetRes.GetGlobalName().empty() && !SourceRes.GetGlobalName().empty()) { + TargetRes.SetGlobalName(SourceRes.GetGlobalName()); + } + + if (TargetRes.GetGlobalSymbol() && SourceRes.GetGlobalSymbol() && + SourceRes.GetGlobalSymbol()->hasName()) { + TargetRes.GetGlobalSymbol()->setName( + SourceRes.GetGlobalSymbol()->getName()); + } + + Type *Ty = SourceRes.GetHLSLType(); + TargetRes.SetHLSLType(Ty); + TargetTypeSys.CopyTypeAnnotation(Ty, SourceTypeSys); +} + +void DxilModule::RestoreResourceReflection(const DxilModule &SourceDM) { + DxilTypeSystem &TargetTypeSys = GetTypeSystem(); + const DxilTypeSystem &SourceTypeSys = SourceDM.GetTypeSystem(); + if (GetCBuffers().size() != SourceDM.GetCBuffers().size() || + GetSRVs().size() != SourceDM.GetSRVs().size() || + GetUAVs().size() != SourceDM.GetUAVs().size() || + GetSamplers().size() != SourceDM.GetSamplers().size()) { + DXASSERT(false, "otherwise, resource lists don't match"); + return; + } + for (unsigned i = 0; i < GetCBuffers().size(); ++i) { + CopyResourceInfo(GetCBuffer(i), SourceDM.GetCBuffer(i), TargetTypeSys, + SourceTypeSys); + } + for (unsigned i = 0; i < GetSRVs().size(); ++i) { + CopyResourceInfo(GetSRV(i), SourceDM.GetSRV(i), TargetTypeSys, + SourceTypeSys); + } + for (unsigned i = 0; i < GetUAVs().size(); ++i) { + CopyResourceInfo(GetUAV(i), SourceDM.GetUAV(i), TargetTypeSys, + SourceTypeSys); + } + for (unsigned i = 0; i < GetSamplers().size(); ++i) { + CopyResourceInfo(GetSampler(i), SourceDM.GetSampler(i), TargetTypeSys, + SourceTypeSys); + } +} + + void DxilModule::LoadDxilResources(const llvm::MDOperand &MDO) { if (MDO.get() == nullptr) return; diff --git a/lib/DxilContainer/DxilContainerAssembler.cpp b/lib/DxilContainer/DxilContainerAssembler.cpp index d9a4ceb4ff..2c8755a6da 100644 --- a/lib/DxilContainer/DxilContainerAssembler.cpp +++ b/lib/DxilContainer/DxilContainerAssembler.cpp @@ -442,6 +442,180 @@ DxilPartWriter *hlsl::NewFeatureInfoWriter(const DxilModule &M) { return new DxilFeatureInfoWriter(M); } + +////////////////////////////////////////////////////////// +// Utility code for serializing/deserializing ViewID state + +// Code for ComputeSeriaizedViewIDStateSizeInUInts copied from +// ComputeViewIdState. It could be moved into some common location if this +// ViewID serialization/deserialization code were moved out of here. +static unsigned RoundUpToUINT(unsigned x) { return (x + 31) / 32; } +static unsigned ComputeSeriaizedViewIDStateSizeInUInts( + const PSVShaderKind SK, const bool bUsesViewID, + const unsigned InputScalars, const unsigned OutputScalars[4], + const unsigned PCScalars) { + // Compute serialized state size in UINTs. + unsigned NumStreams = SK == PSVShaderKind::Geometry ? 4 : 1; + unsigned Size = 0; + Size += 1; // #Inputs. + for (unsigned StreamId = 0; StreamId < NumStreams; StreamId++) { + Size += 1; // #Outputs for stream StreamId. + unsigned NumOutputs = OutputScalars[StreamId]; + unsigned NumOutUINTs = RoundUpToUINT(NumOutputs); + if (bUsesViewID) { + Size += NumOutUINTs; // m_OutputsDependentOnViewId[StreamId] + } + Size += InputScalars * NumOutUINTs; // m_InputsContributingToOutputs[StreamId] + } + if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Domain || SK == PSVShaderKind::Mesh) { + Size += 1; // #PatchConstant. + unsigned NumPCUINTs = RoundUpToUINT(PCScalars); + if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) { + if (bUsesViewID) { + Size += NumPCUINTs; // m_PCOrPrimOutputsDependentOnViewId + } + Size += InputScalars * NumPCUINTs; // m_InputsContributingToPCOrPrimOutputs + } else { + unsigned NumOutputs = OutputScalars[0]; + unsigned NumOutUINTs = RoundUpToUINT(NumOutputs); + Size += PCScalars * NumOutUINTs; // m_PCInputsContributingToOutputs + } + } + return Size; +} + +static const uint32_t *CopyViewIDStateForOutputToPSV( + const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars, + PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) { + unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4); + if (ViewIDMask.IsValid()) { + DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors); + memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords); + pSrc += MaskDwords; + } + if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) { + DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4)); + DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4)); + memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars); + pSrc += MaskDwords * InputScalars; + } + return pSrc; +} + +static uint32_t *CopyViewIDStateForOutputFromPSV(uint32_t *pOutputData, + const unsigned InputScalars, + const unsigned OutputScalars, + PSVComponentMask ViewIDMask, + PSVDependencyTable IOTable) { + unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4); + if (ViewIDMask.IsValid()) { + DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors); + for (unsigned i = 0; i < MaskDwords; i++) + *(pOutputData++) = ViewIDMask.Mask[i]; + } + if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) { + DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4)); + DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4)); + for (unsigned i = 0; i < MaskDwords * InputScalars; i++) + *(pOutputData++) = IOTable.Table[i]; + } + return pOutputData; +} + +void hlsl::StoreViewIDStateToPSV(const uint32_t *pInputData, + unsigned InputSizeInUInts, + DxilPipelineStateValidation &PSV) { + PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1(); + DXASSERT(pInfo1, "otherwise, PSV does not meet version requirement."); + PSVShaderKind SK = static_cast(pInfo1->ShaderStage); + const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1; + const uint32_t *pSrc = pInputData; + const uint32_t InputScalars = *(pSrc++); + uint32_t OutputScalars[4]; + for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) { + OutputScalars[streamIndex] = *(pSrc++); + pSrc = CopyViewIDStateForOutputToPSV( + pSrc, InputScalars, OutputScalars[streamIndex], + PSV.GetViewIDOutputMask(streamIndex), + PSV.GetInputToOutputTable(streamIndex)); + } + if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) { + const uint32_t PCScalars = *(pSrc++); + pSrc = CopyViewIDStateForOutputToPSV(pSrc, InputScalars, PCScalars, + PSV.GetViewIDPCOutputMask(), + PSV.GetInputToPCOutputTable()); + } else if (SK == PSVShaderKind::Domain) { + const uint32_t PCScalars = *(pSrc++); + pSrc = CopyViewIDStateForOutputToPSV(pSrc, PCScalars, OutputScalars[0], + PSVComponentMask(), + PSV.GetPCInputToOutputTable()); + } + DXASSERT(pSrc - pInputData == InputSizeInUInts, + "otherwise, different amout of data written than expected."); +} + +// This function is defined close to the serialization code in DxilPSVWriter to +// reduce the chance of a mismatch. It could be defined elsewhere, but it would +// make sense to move both the serialization and deserialization out of here and +// into a common location. +unsigned hlsl::LoadViewIDStateFromPSV(unsigned *pOutputData, + unsigned OutputSizeInUInts, + const DxilPipelineStateValidation &PSV) { + PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1(); + if (!pInfo1) { + return 0; + } + PSVShaderKind SK = static_cast(pInfo1->ShaderStage); + const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1; + const unsigned InputScalars = pInfo1->SigInputVectors * 4; + unsigned OutputScalars[4]; + for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) { + OutputScalars[streamIndex] = pInfo1->SigOutputVectors[streamIndex] * 4; + } + unsigned PCScalars = 0; + if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh || + SK == PSVShaderKind::Domain) { + PCScalars = pInfo1->SigPatchConstOrPrimVectors * 4; + } + if (pOutputData == nullptr) { + return ComputeSeriaizedViewIDStateSizeInUInts( + SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars, PCScalars); + } + + // Fill in serialized viewid buffer. + DXASSERT(ComputeSeriaizedViewIDStateSizeInUInts( + SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars, + PCScalars) == OutputSizeInUInts, + "otherwise, OutputSize doesn't match computed size."); + unsigned *pStartOutputData = pOutputData; + *(pOutputData++) = InputScalars; + for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) { + *(pOutputData++) = OutputScalars[streamIndex]; + pOutputData = CopyViewIDStateForOutputFromPSV( + pOutputData, InputScalars, OutputScalars[streamIndex], + PSV.GetViewIDOutputMask(streamIndex), + PSV.GetInputToOutputTable(streamIndex)); + } + if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) { + *(pOutputData++) = PCScalars; + pOutputData = CopyViewIDStateForOutputFromPSV( + pOutputData, InputScalars, PCScalars, PSV.GetViewIDPCOutputMask(), + PSV.GetInputToPCOutputTable()); + } else if (SK == PSVShaderKind::Domain) { + *(pOutputData++) = PCScalars; + pOutputData = CopyViewIDStateForOutputFromPSV( + pOutputData, PCScalars, OutputScalars[0], PSVComponentMask(), + PSV.GetPCInputToOutputTable()); + } + DXASSERT(pOutputData - pStartOutputData == OutputSizeInUInts, + "otherwise, OutputSizeInUInts didn't match size written."); + return pOutputData - pStartOutputData; +} + + +////////////////////////////////////////////////////////// +// DxilPSVWriter - Writes PSV0 part + class DxilPSVWriter : public DxilPartWriter { private: const DxilModule &m_Module; @@ -509,22 +683,6 @@ class DxilPSVWriter : public DxilPartWriter { E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF; } - const uint32_t *CopyViewIDState(const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars, PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) { - unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4); - if (ViewIDMask.IsValid()) { - DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors); - memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords); - pSrc += MaskDwords; - } - if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) { - DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4)); - DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4)); - memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars); - pSrc += MaskDwords * InputScalars; - } - return pSrc; - } - public: DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX) : m_Module(mod), @@ -840,23 +998,7 @@ class DxilPSVWriter : public DxilPartWriter { // Gather ViewID dependency information auto &viewState = m_Module.GetSerializedViewIdState(); if (!viewState.empty()) { - const uint32_t *pSrc = viewState.data(); - const uint32_t InputScalars = *(pSrc++); - uint32_t OutputScalars[4]; - for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) { - OutputScalars[streamIndex] = *(pSrc++); - pSrc = CopyViewIDState(pSrc, InputScalars, OutputScalars[streamIndex], m_PSV.GetViewIDOutputMask(streamIndex), m_PSV.GetInputToOutputTable(streamIndex)); - if (!SM->IsGS()) - break; - } - if (SM->IsHS() || SM->IsMS()) { - const uint32_t PCScalars = *(pSrc++); - pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable()); - } else if (SM->IsDS()) { - const uint32_t PCScalars = *(pSrc++); - pSrc = CopyViewIDState(pSrc, PCScalars, OutputScalars[0], PSVComponentMask(), m_PSV.GetPCInputToOutputTable()); - } - DXASSERT_NOMSG(viewState.data() + viewState.size() == pSrc); + StoreViewIDStateToPSV(viewState.data(), (unsigned)viewState.size(), m_PSV); } } diff --git a/lib/HLSL/DxcOptimizer.cpp b/lib/HLSL/DxcOptimizer.cpp index 1b23d05ce9..dbc258a5df 100644 --- a/lib/HLSL/DxcOptimizer.cpp +++ b/lib/HLSL/DxcOptimizer.cpp @@ -14,6 +14,7 @@ #include "dxc/Support/Unicode.h" #include "dxc/Support/microcom.h" #include "dxc/DxilContainer/DxilContainer.h" +#include "dxc/DxilContainer/DxilContainerAssembler.h" #include "dxc/Support/FileIOHelper.h" #include "dxc/DXIL/DxilModule.h" #include "llvm/Analysis/ReducibilityAnalysis.h" @@ -23,6 +24,9 @@ #include "llvm/Analysis/DxilValueCache.h" #include "dxc/DXIL/DxilUtil.h" #include "dxc/Support/dxcapi.impl.h" +#include "dxc/DxilContainer/DxilRuntimeReflection.h" +#include "dxc/DxilContainer/DxilPipelineStateValidation.h" +#include "dxc/DxilContainer/DxilContainerAssembler.h" #include "llvm/Pass.h" #include "llvm/PassInfo.h" @@ -234,43 +238,127 @@ HRESULT STDMETHODCALLTYPE DxcOptimizer::RunOptimizer( DxcThreadMalloc TM(m_pMalloc); - // Setup input buffer. - // - // The ir parsing requires the buffer to be null terminated. We deal with - // both source and bitcode input, so the input buffer may not be null - // terminated; we create a new membuf that copies and appends for this. - // - // If we have the beginning of a DXIL program header, skip to the bitcode. - // - LLVMContext Context; - SMDiagnostic Err; - std::unique_ptr memBuf; - std::unique_ptr M; - const char * pBlobContent = reinterpret_cast(pBlob->GetBufferPointer()); - unsigned blobSize = pBlob->GetBufferSize(); - const DxilProgramHeader *pProgramHeader = - reinterpret_cast(pBlobContent); - if (IsValidDxilProgramHeader(pProgramHeader, blobSize)) { - std::string DiagStr; - GetDxilProgramBitcode(pProgramHeader, &pBlobContent, &blobSize); - M = hlsl::dxilutil::LoadModuleFromBitcode( - llvm::StringRef(pBlobContent, blobSize), Context, DiagStr); - } - else { - StringRef bufStrRef(pBlobContent, blobSize); - memBuf = MemoryBuffer::getMemBufferCopy(bufStrRef); - M = parseIR(memBuf->getMemBufferRef(), Err, Context); - } + try { - if (M == nullptr) { - return DXC_E_IR_VERIFICATION_FAILED; - } + // Setup input buffer. + // + // The ir parsing requires the buffer to be null terminated. We deal with + // both source and bitcode input, so the input buffer may not be null + // terminated; we create a new membuf that copies and appends for this. + // + // If we have the beginning of a DXIL program header, skip to the bitcode. + // - legacy::PassManager ModulePasses; - legacy::FunctionPassManager FunctionPasses(M.get()); - legacy::PassManagerBase *pPassManager = &ModulePasses; + LLVMContext Context; + SMDiagnostic Err; + std::unique_ptr memBuf; + std::unique_ptr M; + const char * pBlobContent = reinterpret_cast(pBlob->GetBufferPointer()); + unsigned blobSize = pBlob->GetBufferSize(); + const DxilProgramHeader *pProgramHeader = + reinterpret_cast(pBlobContent); + const DxilContainerHeader *pContainerHeader = IsDxilContainerLike(pBlobContent, blobSize); + bool bIsFullContainer = IsValidDxilContainer(pContainerHeader, blobSize); + + if (bIsFullContainer) { + // Prefer debug module, if present. + pProgramHeader = GetDxilProgramHeader(pContainerHeader, DFCC_ShaderDebugInfoDXIL); + if (!pProgramHeader) + pProgramHeader = GetDxilProgramHeader(pContainerHeader, DFCC_DXIL); + } + + if (IsValidDxilProgramHeader(pProgramHeader, blobSize)) { + std::string DiagStr; + GetDxilProgramBitcode(pProgramHeader, &pBlobContent, &blobSize); + M = hlsl::dxilutil::LoadModuleFromBitcode( + llvm::StringRef(pBlobContent, blobSize), Context, DiagStr); + } else if (!bIsFullContainer) { + StringRef bufStrRef(pBlobContent, blobSize); + memBuf = MemoryBuffer::getMemBufferCopy(bufStrRef); + M = parseIR(memBuf->getMemBufferRef(), Err, Context); + } else { + return DXC_E_CONTAINER_MISSING_DXIL; + } + + if (M == nullptr) { + return DXC_E_IR_VERIFICATION_FAILED; + } + + if (bIsFullContainer) { + // Restore extra data from certain parts back into the module so that data isn't lost. + // Note: Only GetOrCreateDxilModule if one of these is present. + // - Subobjects from RDAT + // - RootSignature from RTS0 + // - ViewID and I/O dependency data from PSV0 + // - Resource names and types/annotations from STAT + + // RDAT + if (const DxilPartHeader *pPartHeader = + GetDxilPartByType(pContainerHeader, DFCC_RuntimeData)) { + DxilModule &DM = M->GetOrCreateDxilModule(); + RDAT::DxilRuntimeData rdat(GetDxilPartData(pPartHeader), pPartHeader->PartSize); + auto table = rdat.GetSubobjectTable(); + if (table && table.Count() > 0) { + DM.ResetSubobjects(new DxilSubobjects()); + if (!LoadSubobjectsFromRDAT(*DM.GetSubobjects(), rdat)) { + return DXC_E_CONTAINER_INVALID; + } + } + } + + // RST0 + if (const DxilPartHeader *pPartHeader = + GetDxilPartByType(pContainerHeader, DFCC_RootSignature)) { + DxilModule &DM = M->GetOrCreateDxilModule(); + const uint8_t* pPartData = (const uint8_t*)GetDxilPartData(pPartHeader); + std::vector partData(pPartData, pPartData + pPartHeader->PartSize); + DM.ResetSerializedRootSignature(partData); + } + + // PSV0 + if (const DxilPartHeader *pPartHeader = GetDxilPartByType( + pContainerHeader, DFCC_PipelineStateValidation)) { + DxilModule &DM = M->GetOrCreateDxilModule(); + std::vector &viewState = DM.GetSerializedViewIdState(); + if (viewState.empty()) { + DxilPipelineStateValidation PSV; + PSV.InitFromPSV0(GetDxilPartData(pPartHeader), pPartHeader->PartSize); + unsigned OutputSizeInUInts = hlsl::LoadViewIDStateFromPSV(nullptr, 0, PSV); + if (OutputSizeInUInts) { + viewState.assign(OutputSizeInUInts, 0); + hlsl::LoadViewIDStateFromPSV(viewState.data(), (unsigned)viewState.size(), PSV); + } + } + } + + // STAT + if (const DxilPartHeader *pPartHeader = GetDxilPartByType( + pContainerHeader, DFCC_ShaderStatistics)) { + const DxilProgramHeader *pReflProgramHeader = + reinterpret_cast(GetDxilPartData(pPartHeader)); + if (IsValidDxilProgramHeader(pReflProgramHeader, + pPartHeader->PartSize)) { + const char *pReflBitcode; + uint32_t reflBitcodeLength; + GetDxilProgramBitcode((const DxilProgramHeader *)pReflProgramHeader, + &pReflBitcode, &reflBitcodeLength); + std::string DiagStr; + std::unique_ptr ReflM = hlsl::dxilutil::LoadModuleFromBitcode( + llvm::StringRef(pReflBitcode, reflBitcodeLength), Context, + DiagStr); + if (ReflM) { + // Restore resource names from reflection + M->GetOrCreateDxilModule().RestoreResourceReflection( + ReflM->GetOrCreateDxilModule()); + } + } + } + } + + legacy::PassManager ModulePasses; + legacy::FunctionPassManager FunctionPasses(M.get()); + legacy::PassManagerBase *pPassManager = &ModulePasses; - try { CComPtr pOutputStream; CComPtr pOutputBlob; diff --git a/tools/clang/unittests/HLSL/OptimizerTest.cpp b/tools/clang/unittests/HLSL/OptimizerTest.cpp index 71ee1b9f08..1d6fc0a617 100644 --- a/tools/clang/unittests/HLSL/OptimizerTest.cpp +++ b/tools/clang/unittests/HLSL/OptimizerTest.cpp @@ -20,7 +20,17 @@ #include #include #include + +// For DxilRuntimeReflection.h: +#include "dxc/Support/WinIncludes.h" + +#include "dxc/DXIL/DxilModule.h" +#include "dxc/DXIL/DxilUtil.h" + #include "dxc/DxilContainer/DxilContainer.h" +#include "dxc/DxilContainer/DxilPipelineStateValidation.h" +#include "dxc/DxilContainer/DxilRuntimeReflection.h" +#include "dxc/DxilRootSignature/DxilRootSignature.h" #include "dxc/Support/WinIncludes.h" #include "dxc/dxcapi.h" @@ -35,12 +45,16 @@ #include "dxc/Support/HLSLOptions.h" #include "dxc/Support/Unicode.h" +#include "llvm/Bitcode/ReaderWriter.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/MSFileSystem.h" #include "llvm/Support/Path.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringSwitch.h" + using namespace std; using namespace hlsl_test; @@ -67,9 +81,39 @@ class OptimizerTest : public ::testing::Test { TEST_METHOD(OptimizerWhenSlice3ThenOK) TEST_METHOD(OptimizerWhenSliceWithIntermediateOptionsThenOK) + TEST_METHOD(OptimizerWhenPassedContainerPreservesSubobjects) + TEST_METHOD(OptimizerWhenPassedContainerPreservesRootSig) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_HSDependentPCDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_HSDependentPCNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_HSNonDependentPCDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_HSNonDependentPCNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_MSDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_MSNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_DSDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_DSNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_VSDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_VSNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_GSDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesViewId_GSNonDependent) + TEST_METHOD(OptimizerWhenPassedContainerPreservesResourceStats_PSMultiCBTex2D) + void OptimizerWhenSliceNThenOK(int optLevel); void OptimizerWhenSliceNThenOK(int optLevel, LPCSTR pText, LPCWSTR pTarget, llvm::ArrayRef args = {}); - + CComPtr Compile( + char const *source, wchar_t const *entry, wchar_t const *profile); + CComPtr RunHlslEmitAndReturnContainer(IDxcBlob *container); + CComPtr RetrievePartFromContainer(IDxcBlob *container, + UINT32 part); + void ComparePSV0BeforeAndAfterOptimization(const char *source, + const wchar_t *entry, + const wchar_t *profile, + bool usesViewId, + int streamCount = 1); + void CompareSTATBeforeAndAfterOptimization(const char *source, + const wchar_t *entry, + const wchar_t *profile, + bool usesViewId, + int streamCount = 1); dxc::DxcDllSupport m_dllSupport; VersionSupportInfo m_ver; @@ -259,3 +303,1188 @@ void OptimizerTest::OptimizerWhenSliceNThenOK(int optLevel, LPCSTR pText, LPCWST } } } + +CComPtr OptimizerTest::Compile( + char const *source, wchar_t const *entry, wchar_t const *profile) { + + CComPtr pCompiler; + VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler)); + + CComPtr pSource; + Utf8ToBlob(m_dllSupport, source, &pSource); + + CComPtr pResult; + VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", entry, profile, + nullptr, 0, nullptr, 0, nullptr, + &pResult)); + VerifyOperationSucceeded(pResult); + CComPtr pProgram; + VERIFY_SUCCEEDED(pResult->GetResult(&pProgram)); + return pProgram; +} + +CComPtr OptimizerTest::RunHlslEmitAndReturnContainer(IDxcBlob *container) +{ + CComPtr pOptimizer; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcOptimizer, &pOptimizer)); + + std::vector Options; + Options.push_back(L"-hlsl-dxilemit"); + + CComPtr pDxil; + CComPtr pText; + VERIFY_SUCCEEDED(pOptimizer->RunOptimizer(container, Options.data(), + Options.size(), &pDxil, &pText)); + + CComPtr pAssembler; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcAssembler, &pAssembler)); + + CComPtr result; + pAssembler->AssembleToContainer(pDxil, &result); + + CComPtr pOptimizedContainer; + result->GetResult(&pOptimizedContainer); + + return pOptimizedContainer; +} + +CComPtr OptimizerTest::RetrievePartFromContainer(IDxcBlob *container, + UINT32 part) { + CComPtr pReflection; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection)); + VERIFY_SUCCEEDED(pReflection->Load(container)); + UINT32 dxilIndex; + VERIFY_SUCCEEDED(pReflection->FindFirstPartKind(part, &dxilIndex)); + CComPtr blob; + VERIFY_SUCCEEDED(pReflection->GetPartContent(dxilIndex, & blob)); + return blob; +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesSubobjects) { + + const char *source = R"x( +GlobalRootSignature so_GlobalRootSignature = +{ + "RootConstants(num32BitConstants=1, b8), " +}; + +StateObjectConfig so_StateObjectConfig = +{ + STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS +}; + +LocalRootSignature so_LocalRootSignature1 = +{ + "RootConstants(num32BitConstants=3, b2), " + "UAV(u6),RootFlags(LOCAL_ROOT_SIGNATURE)" +}; + +LocalRootSignature so_LocalRootSignature2 = +{ + "RootConstants(num32BitConstants=3, b2), " + "UAV(u8, flags=DATA_STATIC), " + "RootFlags(LOCAL_ROOT_SIGNATURE)" +}; + +RaytracingShaderConfig so_RaytracingShaderConfig = +{ + 128, // max payload size + 32 // max attribute size +}; + +RaytracingPipelineConfig so_RaytracingPipelineConfig = +{ + 2 // max trace recursion depth +}; + +TriangleHitGroup MyHitGroup = +{ + "MyAnyHit", // AnyHit + "MyClosestHit", // ClosestHit +}; + +SubobjectToExportsAssociation so_Association1 = +{ + "so_LocalRootSignature1", // subobject name + "MyRayGen" // export association +}; + +SubobjectToExportsAssociation so_Association2 = +{ + "so_LocalRootSignature2", // subobject name + "MyAnyHit" // export association +}; + +struct MyPayload +{ + float4 color; +}; + +[shader("raygeneration")] +void MyRayGen() +{ +} + +[shader("closesthit")] +void MyClosestHit(inout MyPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ +} + +[shader("anyhit")] +void MyAnyHit(inout MyPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ +} + +[shader("miss")] +void MyMiss(inout MyPayload payload) +{ +} + +)x"; + + auto pOptimizedContainer = + RunHlslEmitAndReturnContainer(Compile(source, L"", L"lib_6_6")); + + auto runtimeDataPart = RetrievePartFromContainer( + pOptimizedContainer, hlsl::DFCC_RuntimeData); + + hlsl::RDAT::DxilRuntimeData rdat(runtimeDataPart->GetBufferPointer(), + static_cast(runtimeDataPart->GetBufferSize())); + + auto const subObjectTableReader = rdat.GetSubobjectTable(); + + // There are 9 subobjects in the HLSL above: + VERIFY_ARE_EQUAL(subObjectTableReader.Count(), 9u); + for (uint32_t i = 0; i < subObjectTableReader.Count(); ++i) { + auto subObject = subObjectTableReader[i]; + hlsl::DXIL::SubobjectKind subobjectKind = subObject.getKind(); + switch (subobjectKind) { + case hlsl::DXIL::SubobjectKind::StateObjectConfig: + VERIFY_ARE_EQUAL_STR(subObject.getName(), "so_StateObjectConfig"); + break; + case hlsl::DXIL::SubobjectKind::GlobalRootSignature: + VERIFY_ARE_EQUAL_STR(subObject.getName(), "so_GlobalRootSignature"); + break; + case hlsl::DXIL::SubobjectKind::LocalRootSignature: + VERIFY_IS_TRUE( + 0 == strcmp(subObject.getName(), "so_LocalRootSignature1") || + 0 == strcmp(subObject.getName(), "so_LocalRootSignature2")); + break; + case hlsl::DXIL::SubobjectKind::SubobjectToExportsAssociation: + VERIFY_IS_TRUE( + 0 == strcmp(subObject.getName(), "so_Association1") || + 0 == strcmp(subObject.getName(), "so_Association2")); + break; + case hlsl::DXIL::SubobjectKind::RaytracingShaderConfig: + VERIFY_ARE_EQUAL_STR(subObject.getName(), "so_RaytracingShaderConfig"); + break; + case hlsl::DXIL::SubobjectKind::RaytracingPipelineConfig: + VERIFY_ARE_EQUAL_STR(subObject.getName(), "so_RaytracingPipelineConfig"); + break; + case hlsl::DXIL::SubobjectKind::HitGroup: + VERIFY_ARE_EQUAL_STR(subObject.getName(), "MyHitGroup"); + break; + break; + } + } +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesRootSig) { + const char * source = +R"( +#define RootSig \ + "UAV(u3, space=12), " \ + "RootConstants(num32BitConstants=1, b7, space = 42) " + [numthreads(1, 1, 1)] + [RootSignature(RootSig)] + void CSMain(uint3 dispatchThreadID : SV_DispatchThreadID) + { + } +)"; + + auto pOptimizedContainer = + RunHlslEmitAndReturnContainer(Compile(source, L"CSMain", L"cs_6_6")); + + auto rootSigPart = RetrievePartFromContainer( + pOptimizedContainer, hlsl::DFCC_RootSignature); + + hlsl::RootSignatureHandle RSH; + RSH.LoadSerialized(reinterpret_cast(rootSigPart->GetBufferPointer()), + static_cast(rootSigPart->GetBufferSize())); + + RSH.Deserialize(); + + auto const * desc = RSH.GetDesc(); + + VERIFY_ARE_EQUAL(desc->Version, hlsl::DxilRootSignatureVersion::Version_1_1); + VERIFY_ARE_EQUAL(desc->Desc_1_1.NumParameters, 2u); + for (unsigned int i = 0; i < desc->Desc_1_1.NumParameters; ++i) + { + hlsl::DxilRootParameter1 const *param = desc->Desc_1_1.pParameters + i; + switch (param->ParameterType) { + case hlsl::DxilRootParameterType::Constants32Bit: + VERIFY_ARE_EQUAL(param->Constants.Num32BitValues, 1u); + VERIFY_ARE_EQUAL(param->Constants.RegisterSpace, 42u); + VERIFY_ARE_EQUAL(param->Constants.ShaderRegister, 7u); + break; + case hlsl::DxilRootParameterType::UAV: + VERIFY_ARE_EQUAL(param->Descriptor.RegisterSpace, 12u); + VERIFY_ARE_EQUAL(param->Descriptor.ShaderRegister, 3u); + break; + default: + VERIFY_FAIL(L"Unexpected root param type"); + break; + } + } +} + +void VerifyIOTablesMatch(PSVDependencyTable original, + PSVDependencyTable optimized) { + VERIFY_ARE_EQUAL(original.IsValid(), optimized.IsValid()); + if (original.IsValid()) { + VERIFY_ARE_EQUAL(original.InputVectors, optimized.InputVectors); + VERIFY_ARE_EQUAL(original.OutputVectors, optimized.OutputVectors); + if (original.InputVectors == optimized.InputVectors && + original.OutputVectors == optimized.OutputVectors && + original.InputVectors * original.OutputVectors > 0) { + VERIFY_ARE_EQUAL( + 0, memcmp(original.Table, optimized.Table, + PSVComputeMaskDwordsFromVectors(original.OutputVectors) * + original.InputVectors * 4 * sizeof(uint32_t))); + } + } +} + +void OptimizerTest::ComparePSV0BeforeAndAfterOptimization(const char *source, const wchar_t *entry, const wchar_t *profile, bool usesViewId, int streamCount /*= 1*/) +{ + auto originalContainer = Compile(source, entry, profile); + + auto originalPsvPart = RetrievePartFromContainer( + originalContainer, hlsl::DFCC_PipelineStateValidation); + + auto optimizedContainer = RunHlslEmitAndReturnContainer(originalContainer); + + auto optimizedPsvPart = RetrievePartFromContainer(optimizedContainer, + hlsl::DFCC_PipelineStateValidation); + + VERIFY_ARE_EQUAL(originalPsvPart->GetBufferSize(), optimizedPsvPart->GetBufferSize()); + + VERIFY_ARE_EQUAL(memcmp(originalPsvPart->GetBufferPointer(), + optimizedPsvPart->GetBufferPointer(), + originalPsvPart->GetBufferSize()), + 0); + + DxilPipelineStateValidation originalPsv; + originalPsv.InitFromPSV0( + reinterpret_cast(optimizedPsvPart->GetBufferPointer()), + static_cast(optimizedPsvPart->GetBufferSize())); + + const PSVRuntimeInfo1 *originalInfo1 = originalPsv.GetPSVRuntimeInfo1(); + if (usesViewId) { + VERIFY_IS_TRUE(originalInfo1->UsesViewID); + } + + DxilPipelineStateValidation optimizedPsv; + optimizedPsv.InitFromPSV0( + reinterpret_cast(originalPsvPart->GetBufferPointer()), + static_cast(originalPsvPart->GetBufferSize())); + + const PSVRuntimeInfo1 *optimizedInfo1 = optimizedPsv.GetPSVRuntimeInfo1(); + + VERIFY_ARE_EQUAL(originalInfo1->ShaderStage, optimizedInfo1->ShaderStage); + VERIFY_ARE_EQUAL(originalInfo1->UsesViewID , optimizedInfo1->UsesViewID); + VERIFY_ARE_EQUAL(originalInfo1->SigInputElements, + optimizedInfo1->SigInputElements); + VERIFY_ARE_EQUAL(originalInfo1->SigOutputElements, + optimizedInfo1->SigOutputElements); + VERIFY_ARE_EQUAL(originalInfo1->SigPatchConstOrPrimElements, + optimizedInfo1->SigPatchConstOrPrimElements); + VERIFY_ARE_EQUAL(originalPsv.GetViewIDPCOutputMask().IsValid(), + optimizedPsv.GetViewIDPCOutputMask().IsValid()); + VERIFY_ARE_EQUAL(originalPsv.GetSigInputElements(), + optimizedPsv.GetSigInputElements()); + VERIFY_ARE_EQUAL(originalPsv.GetSigOutputElements(), + optimizedPsv.GetSigOutputElements()); + if (originalPsv.GetViewIDPCOutputMask().IsValid()) { + VERIFY_ARE_EQUAL(*originalPsv.GetViewIDPCOutputMask().Mask, + *optimizedPsv.GetViewIDPCOutputMask().Mask); + VERIFY_ARE_EQUAL(originalPsv.GetViewIDPCOutputMask().NumVectors, + optimizedPsv.GetViewIDPCOutputMask().NumVectors); + } + for (int stream = 0; stream < streamCount; ++stream) + { + VERIFY_ARE_EQUAL(originalInfo1->SigOutputVectors[stream], + optimizedInfo1->SigOutputVectors[stream]); + + VERIFY_ARE_EQUAL(originalPsv.GetViewIDOutputMask(stream).IsValid(), + optimizedPsv.GetViewIDOutputMask(stream).IsValid()); + + VerifyIOTablesMatch(originalPsv.GetInputToOutputTable(), + optimizedPsv.GetInputToOutputTable()); + VerifyIOTablesMatch(originalPsv.GetInputToPCOutputTable(), + optimizedPsv.GetInputToPCOutputTable()); + VerifyIOTablesMatch(originalPsv.GetPCInputToOutputTable(), + optimizedPsv.GetPCInputToOutputTable()); + + if (originalPsv.GetViewIDOutputMask(stream).IsValid()) + { + VERIFY_ARE_EQUAL(*originalPsv.GetViewIDOutputMask(stream).Mask, + *optimizedPsv.GetViewIDOutputMask(stream).Mask); + VERIFY_ARE_EQUAL(originalPsv.GetViewIDOutputMask(stream).NumVectors, + optimizedPsv.GetViewIDOutputMask(stream).NumVectors); + } + } +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_HSDependentPCDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +#define NumOutPoints 2 + +struct HsCpIn { + int foo : FOO; +}; + +struct HsCpOut { + int bar : BAR; +}; + +struct HsPcfOut { + float tessOuter[4] : SV_TessFactor; + float tessInner[2] : SV_InsideTessFactor; +}; + +// Patch Constant Function +HsPcfOut pcf(uint viewid : SV_ViewID) { + HsPcfOut output; + output = (HsPcfOut)viewid; + return output; +} + +[domain("quad")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[outputcontrolpoints(NumOutPoints)] +[patchconstantfunc("pcf")] +HsCpOut main(InputPatch patch, + uint id : SV_OutputControlPointID, + uint viewid : SV_ViewID) { + HsCpOut output; + output.bar = viewid; + return output; +} + +)", L"main", L"hs_6_6", true); + +} + +TEST_F(OptimizerTest, + OptimizerWhenPassedContainerPreservesViewId_HSNonDependentPCDependent) { + ComparePSV0BeforeAndAfterOptimization( + R"( +#define NumOutPoints 2 + +struct HsCpIn { + int foo : FOO; +}; + +struct HsCpOut { + int bar : BAR; +}; + +struct HsPcfOut { + float tessOuter[4] : SV_TessFactor; + float tessInner[2] : SV_InsideTessFactor; +}; + +// Patch Constant Function +HsPcfOut pcf(uint viewid : SV_ViewID) { + HsPcfOut output; + output = (HsPcfOut)viewid; + return output; +} + +[domain("quad")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[outputcontrolpoints(NumOutPoints)] +[patchconstantfunc("pcf")] +HsCpOut main(InputPatch patch, + uint id : SV_OutputControlPointID, + uint viewid : SV_ViewID) { + HsCpOut output; + output.bar = 0; + return output; +} + +)", + L"main", L"hs_6_6", true); +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_HSDependentPCNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +#define NumOutPoints 2 + +struct HsCpIn { + int foo : FOO; +}; + +struct HsCpOut { + int bar : BAR; +}; + +struct HsPcfOut { + float tessOuter[4] : SV_TessFactor; + float tessInner[2] : SV_InsideTessFactor; +}; + +// Patch Constant Function +HsPcfOut pcf(uint viewid : SV_ViewID) { + HsPcfOut output; + output = (HsPcfOut)0; + return output; +} + +[domain("quad")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[outputcontrolpoints(NumOutPoints)] +[patchconstantfunc("pcf")] +HsCpOut main(InputPatch patch, + uint id : SV_OutputControlPointID, + uint viewid : SV_ViewID) { + HsCpOut output; + output.bar = viewid; + return output; +} + +)", L"main", L"hs_6_6", true); + +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_HSNonDependentPCNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +#define NumOutPoints 2 + +struct HsCpIn { + int foo : FOO; +}; + +struct HsCpOut { + int bar : BAR; +}; + +struct HsPcfOut { + float tessOuter[4] : SV_TessFactor; + float tessInner[2] : SV_InsideTessFactor; +}; + +// Patch Constant Function +HsPcfOut pcf(uint viewid : SV_ViewID) { + HsPcfOut output; + output = (HsPcfOut)0; + return output; +} + +[domain("quad")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[outputcontrolpoints(NumOutPoints)] +[patchconstantfunc("pcf")] +HsCpOut main(InputPatch patch, + uint id : SV_OutputControlPointID, + uint viewid : SV_ViewID) { + HsCpOut output; + output.bar = 0; + return output; +} + +)", + L"main", L"hs_6_6", false /*does not use view id*/); + +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_MSDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +#define MAX_VERT 32 +#define MAX_PRIM 16 +#define NUM_THREADS 32 +struct MeshPerVertex { + float4 position : SV_Position; + float color[4] : COLOR; +}; + +struct MeshPerPrimitive { + float normal : NORMAL; + float malnor : MALNOR; + float alnorm : ALNORM; + float ormaln : ORMALN; + int layer[6] : LAYER; +}; + +struct MeshPayload { + float normal; + float malnor; + float alnorm; + float ormaln; + int layer[6]; +}; + +groupshared float gsMem[MAX_PRIM]; + +[numthreads(NUM_THREADS, 1, 1)] +[outputtopology("triangle")] +void main( + out indices uint3 primIndices[MAX_PRIM], + out vertices MeshPerVertex verts[MAX_VERT], + out primitives MeshPerPrimitive prims[MAX_PRIM], + in payload MeshPayload mpl, + in uint tig : SV_GroupIndex, + in uint vid : SV_ViewID + ) +{ + SetMeshOutputCounts(MAX_VERT, MAX_PRIM); + MeshPerVertex ov; + if (vid % 2) { + ov.position = float4(4.0,5.0,6.0,7.0); + ov.color[0] = 4.0; + ov.color[1] = 5.0; + ov.color[2] = 6.0; + ov.color[3] = 7.0; + } else { + ov.position = float4(14.0,15.0,16.0,17.0); + ov.color[0] = 14.0; + ov.color[1] = 15.0; + ov.color[2] = 16.0; + ov.color[3] = 17.0; + } + if (tig % 3) { + primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2); + MeshPerPrimitive op; + op.normal = mpl.normal; + op.malnor = gsMem[tig / 3 + 1]; + op.alnorm = mpl.alnorm; + op.ormaln = mpl.ormaln; + op.layer[0] = mpl.layer[0]; + op.layer[1] = mpl.layer[1]; + op.layer[2] = mpl.layer[2]; + op.layer[3] = mpl.layer[3]; + op.layer[4] = mpl.layer[4]; + op.layer[5] = mpl.layer[5]; + gsMem[tig / 3] = op.normal; + prims[tig / 3] = op; + } + verts[tig] = ov; +} +)", + L"main", L"ms_6_5", true /*does use view id*/); + +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_MSNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +#define MAX_VERT 32 +#define MAX_PRIM 16 +#define NUM_THREADS 32 +struct MeshPerVertex { + float4 position : SV_Position; + float color[4] : COLOR; +}; + +struct MeshPerPrimitive { + float normal : NORMAL; + float malnor : MALNOR; + float alnorm : ALNORM; + float ormaln : ORMALN; + int layer[6] : LAYER; +}; + +struct MeshPayload { + float normal; + float malnor; + float alnorm; + float ormaln; + int layer[6]; +}; + +groupshared float gsMem[MAX_PRIM]; + +[numthreads(NUM_THREADS, 1, 1)] +[outputtopology("triangle")] +void main( + out indices uint3 primIndices[MAX_PRIM], + out vertices MeshPerVertex verts[MAX_VERT], + out primitives MeshPerPrimitive prims[MAX_PRIM], + in payload MeshPayload mpl, + in uint tig : SV_GroupIndex, + in uint vid : SV_ViewID + ) +{ + SetMeshOutputCounts(MAX_VERT, MAX_PRIM); + MeshPerVertex ov; + if (false) { + ov.position = float4(4.0,5.0,6.0,7.0); + ov.color[0] = 4.0; + ov.color[1] = 5.0; + ov.color[2] = 6.0; + ov.color[3] = 7.0; + } else { + ov.position = float4(14.0,15.0,16.0,17.0); + ov.color[0] = 14.0; + ov.color[1] = 15.0; + ov.color[2] = 16.0; + ov.color[3] = 17.0; + } + if (tig % 3) { + primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2); + MeshPerPrimitive op; + op.normal = mpl.normal; + op.malnor = gsMem[tig / 3 + 1]; + op.alnorm = mpl.alnorm; + op.ormaln = mpl.ormaln; + op.layer[0] = mpl.layer[0]; + op.layer[1] = mpl.layer[1]; + op.layer[2] = mpl.layer[2]; + op.layer[3] = mpl.layer[3]; + op.layer[4] = mpl.layer[4]; + op.layer[5] = mpl.layer[5]; + gsMem[tig / 3] = op.normal; + prims[tig / 3] = op; + } + verts[tig] = ov; +} +)", + L"main", L"ms_6_5", false /*does not use view id*/); + +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_DSDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +struct PSInput +{ + float4 position : SV_POSITION; + float4 color : COLOR; +}; + +struct HS_CONSTANT_DATA_OUTPUT +{ + float Edges[3] : SV_TessFactor; + float Inside : SV_InsideTessFactor; +}; + +struct BEZIER_CONTROL_POINT +{ + float3 vPosition : BEZIERPOS; + float4 color : COLOR; +}; + +[domain("tri")] +PSInput DSMain(HS_CONSTANT_DATA_OUTPUT input, + float3 UV : SV_DomainLocation, + uint viewID : SV_ViewID, + const OutputPatch bezpatch) +{ + PSInput Output; + + Output.position = float4( + bezpatch[0].vPosition * UV.x + viewID + + bezpatch[1].vPosition * UV.y + + bezpatch[2].vPosition * UV.z + ,1); + Output.color = float4( + bezpatch[0].color * UV.x + + bezpatch[1].color * UV.y + + bezpatch[2].color * UV.z + ); + return Output; +} + +)", + L"DSMain", L"ds_6_5", true /*does use view id*/); + +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_DSNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +struct PSInput +{ + float4 position : SV_POSITION; + float4 color : COLOR; +}; + +struct HS_CONSTANT_DATA_OUTPUT +{ + float Edges[3] : SV_TessFactor; + float Inside : SV_InsideTessFactor; +}; + +struct BEZIER_CONTROL_POINT +{ + float3 vPosition : BEZIERPOS; + float4 color : COLOR; +}; + +[domain("tri")] +PSInput DSMain(HS_CONSTANT_DATA_OUTPUT input, + float3 UV : SV_DomainLocation, + uint viewID : SV_ViewID, + const OutputPatch bezpatch) +{ + PSInput Output; + + Output.position = float4( + bezpatch[0].vPosition * UV.x + + bezpatch[1].vPosition * UV.y + + bezpatch[2].vPosition * UV.z + ,1); + Output.color = float4( + bezpatch[0].color * UV.x + + bezpatch[1].color * UV.y + + bezpatch[2].color * UV.z + ); + return Output; +} + +)", + L"DSMain", L"ds_6_5", false /*does not use view id*/); + +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_VSDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +struct VertexShaderInput +{ + float3 pos : POSITION; + float3 color : COLOR0; + uint viewID : SV_ViewID; +}; +struct VertexShaderOutput +{ + float4 pos : SV_POSITION; + float3 color : COLOR0; +}; + +VertexShaderOutput main(VertexShaderInput input) +{ + VertexShaderOutput output; + output.pos = float4(input.pos, 1.0f); + if (input.viewID % 2) { + output.color = float3(1.0f, 0.0f, 1.0f); + } else { + output.color = float3(0.0f, 1.0f, 0.0f); + } + return output; +} + +)", + L"main", L"vs_6_5", true /*does use view id*/); + +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_VSNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( +struct VertexShaderInput +{ + float3 pos : POSITION; + float3 color : COLOR0; + uint viewID : SV_ViewID; +}; +struct VertexShaderOutput +{ + float4 pos : SV_POSITION; + float3 color : COLOR0; +}; + +VertexShaderOutput main(VertexShaderInput input) +{ + VertexShaderOutput output; + output.pos = float4(input.pos, 1.0f); + if (false) { + output.color = float3(1.0f, 0.0f, 1.0f); + } else { + output.color = float3(0.0f, 1.0f, 0.0f); + } + return output; +} + +)", + L"main", L"vs_6_5", false /*does not use view id*/); + +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_GSNonDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( + + +struct MyStructIn +{ +// float4 pos : SV_Position; + float4 a : AAA; + float2 b : BBB; + float4 c[3] : CCC; + //uint d : SV_RenderTargetIndex; + float4 pos : SV_Position; +}; + +struct MyStructOut +{ + float4 pos : SV_Position; + float2 out_a : OUT_AAA; + uint d : SV_RenderTargetArrayIndex; +}; + +[maxvertexcount(18)] +void main(triangleadj MyStructIn array[6], inout TriangleStream OutputStream0) +{ + float4 r = array[1].a + array[2].b.x + array[3].pos; + r += array[r.x].c[r.y].w; + MyStructOut output = (MyStructOut)0; + output.pos = array[r.x].a; + output.out_a = array[r.y].b; + output.d = array[r.x].a + 3; + OutputStream0.Append(output); + OutputStream0.RestartStrip(); +} + + +)", + L"main", L"gs_6_5", false /*does not use view id*/); + +} + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesViewId_GSDependent) +{ + ComparePSV0BeforeAndAfterOptimization( + R"( + + +struct MyStructIn +{ +// float4 pos : SV_Position; + float4 a : AAA; + float2 b : BBB; + float4 c[3] : CCC; + float4 pos : SV_Position; +}; + +struct MyStructOut +{ + float4 pos : SV_Position; + float2 out_a : OUT_AAA; + uint d : SV_RenderTargetArrayIndex; +}; + +[maxvertexcount(18)] +void main(triangleadj MyStructIn array[6], inout TriangleStream OutputStream0, uint viewid : SV_ViewID) +{ + float4 r = array[1].a + array[2].b.x + array[3].pos; + r += array[r.x].c[r.y].w + viewid; + MyStructOut output = (MyStructOut)0; + output.pos = array[r.x].a; + output.out_a = array[r.y].b; + output.d = array[r.x].a + 3; + OutputStream0.Append(output); + OutputStream0.RestartStrip(); +} + + +)", + L"main", L"gs_6_5", true /*does use view id*/, 4); + +} + +using namespace llvm; +using namespace hlsl; + +static std::unique_ptr + GetDxilModuleFromStatsBlobInContainer(LLVMContext & Context, IDxcBlob *pBlob) { + const char *pBlobContent = + reinterpret_cast(pBlob->GetBufferPointer()); + unsigned blobSize = pBlob->GetBufferSize(); + const DxilContainerHeader *pContainerHeader = + IsDxilContainerLike(pBlobContent, blobSize); + const DxilPartHeader *pPartHeader = + GetDxilPartByType(pContainerHeader, DFCC_ShaderStatistics); + const DxilProgramHeader *pReflProgramHeader = + reinterpret_cast(GetDxilPartData(pPartHeader)); + (void)IsValidDxilProgramHeader(pReflProgramHeader, pPartHeader->PartSize); + const char *pReflBitcode; + uint32_t reflBitcodeLength; + GetDxilProgramBitcode((const DxilProgramHeader *)pReflProgramHeader, + &pReflBitcode, &reflBitcodeLength); + std::string DiagStr; + + return hlsl::dxilutil::LoadModuleFromBitcode( + llvm::StringRef(pReflBitcode, reflBitcodeLength), Context, DiagStr); +} + +template +void CompareResources(const vector> &original, + const vector> &optimized) { + + VERIFY_ARE_EQUAL(original.size(), optimized.size()); + for (size_t i = 0; i < original.size(); ++i) { + auto const &originalRes = original.at(i); + auto const &optimizedRes = optimized.at(i); + VERIFY_ARE_EQUAL(originalRes->GetClass(), optimizedRes->GetClass()); + VERIFY_ARE_EQUAL(originalRes->GetGlobalName(), optimizedRes->GetGlobalName()); + if (originalRes->GetGlobalSymbol() != nullptr && + optimizedRes->GetGlobalSymbol() != nullptr) { + VERIFY_ARE_EQUAL(originalRes->GetGlobalSymbol()->getName(), + optimizedRes->GetGlobalSymbol()->getName()); + } + if (originalRes->GetHLSLType() != nullptr && + optimizedRes->GetHLSLType() != nullptr) { + VERIFY_ARE_EQUAL(originalRes->GetHLSLType()->getTypeID(), + optimizedRes->GetHLSLType()->getTypeID()); + if (originalRes->GetHLSLType()->getTypeID() == Type::PointerTyID) { + auto originalPointedType = originalRes->GetHLSLType()->getArrayElementType(); + auto optimizedPointedType = optimizedRes->GetHLSLType()->getArrayElementType(); + VERIFY_ARE_EQUAL(originalPointedType->getTypeID(), + optimizedPointedType->getTypeID()); + if (optimizedPointedType->getTypeID() == Type::StructTyID) { + VERIFY_ARE_EQUAL(originalPointedType->getStructName(), + optimizedPointedType->getStructName()); + } + } + } + VERIFY_ARE_EQUAL(originalRes->GetID(), optimizedRes->GetID()); + VERIFY_ARE_EQUAL(originalRes->GetKind(), optimizedRes->GetKind()); + VERIFY_ARE_EQUAL(originalRes->GetLowerBound(), optimizedRes->GetLowerBound()); + VERIFY_ARE_EQUAL(originalRes->GetRangeSize(), optimizedRes->GetRangeSize()); + VERIFY_ARE_EQUAL(originalRes->GetSpaceID(), optimizedRes->GetSpaceID()); + VERIFY_ARE_EQUAL(originalRes->GetUpperBound(), optimizedRes->GetUpperBound()); + } +} + +void OptimizerTest::CompareSTATBeforeAndAfterOptimization( + const char *source, const wchar_t *entry, const wchar_t *profile, + bool usesViewId, int streamCount /*= 1*/) { + + + auto originalContainer = Compile(source, entry, profile); + + auto optimizedContainer = RunHlslEmitAndReturnContainer(originalContainer); + + ::llvm::sys::fs::MSFileSystem *msfPtr; + IFT(CreateMSFileSystemForDisk(&msfPtr)); + std::unique_ptr<::llvm::sys::fs::MSFileSystem> msf(msfPtr); + + ::llvm::sys::fs::AutoPerThreadSystem pts(msf.get()); + + LLVMContext originalContext, optimizedContext; + auto originalStatModule = + GetDxilModuleFromStatsBlobInContainer(originalContext, originalContainer); + auto optimizedStatModule = + GetDxilModuleFromStatsBlobInContainer(optimizedContext, optimizedContainer); + + + auto &originalDM = originalStatModule->GetOrCreateDxilModule(); + auto & optimizeDM = optimizedStatModule->GetOrCreateDxilModule(); + + + CompareResources(originalDM.GetCBuffers(), optimizeDM.GetCBuffers()); + CompareResources(originalDM.GetSRVs(), optimizeDM.GetSRVs()); + CompareResources(originalDM.GetUAVs(), optimizeDM.GetUAVs()); + CompareResources(originalDM.GetSamplers(), optimizeDM.GetSamplers()); + +} + + +TEST_F(OptimizerTest, OptimizerWhenPassedContainerPreservesResourceStats_PSMultiCBTex2D) +{ + CompareSTATBeforeAndAfterOptimization( + R"( + +#define MOD4(x) ((x)&3) +#ifndef MAX_POINTS +#define MAX_POINTS 32 +#endif +#define MAX_BONE_MATRICES 80 + +//-------------------------------------------------------------------------------------- +// Textures +//-------------------------------------------------------------------------------------- +Texture2D g_txHeight : register( t0 ); // Height and Bump texture +Texture2D g_txDiffuse : register( t1 ); // Diffuse texture +Texture2D g_txSpecular : register( t2 ); // Specular texture + +//-------------------------------------------------------------------------------------- +// Samplers +//-------------------------------------------------------------------------------------- +SamplerState g_samLinear : register( s0 ); +SamplerState g_samPoint : register( s0 ); + +//-------------------------------------------------------------------------------------- +// Constant Buffers +//-------------------------------------------------------------------------------------- +cbuffer cbTangentStencilConstants : register( b0 ) +{ + float g_TanM[1024]; // Tangent patch stencils precomputed by the application + float g_fCi[16]; // Valence coefficients precomputed by the application +}; + +cbuffer cbPerMesh : register( b1 ) +{ + matrix g_mConstBoneWorld[MAX_BONE_MATRICES]; +}; + +cbuffer cbPerFrame : register( b2 ) +{ + matrix g_mViewProjection; + float3 g_vCameraPosWorld; + float g_fTessellationFactor; + float g_fDisplacementHeight; + float3 g_vSolidColor; +}; + +cbuffer cbPerSubset : register( b3 ) +{ + int g_iPatchStartIndex; +} + +//-------------------------------------------------------------------------------------- +Buffer g_ValencePrefixBuffer : register( t0 ); + +//-------------------------------------------------------------------------------------- +struct VS_CONTROL_POINT_OUTPUT +{ + float3 vPosition : WORLDPOS; + float2 vUV : TEXCOORD0; + float3 vTangent : TANGENT; +}; + +struct BEZIER_CONTROL_POINT +{ + float3 vPosition : BEZIERPOS; +}; + +struct PS_INPUT +{ + float3 vWorldPos : POSITION; + float3 vNormal : NORMAL; + float2 vUV : TEXCOORD; + float3 vTangent : TANGENT; + float3 vBiTangent : BITANGENT; +}; + + +//-------------------------------------------------------------------------------------- +// Smooth shading pixel shader section +//-------------------------------------------------------------------------------------- + +float3 safe_normalize( float3 vInput ) +{ + float len2 = dot( vInput, vInput ); + if( len2 > 0 ) + { + return vInput * rsqrt( len2 ); + } + return vInput; +} + +static const float g_fSpecularExponent = 32.0f; +static const float g_fSpecularIntensity = 0.6f; +static const float g_fNormalMapIntensity = 1.5f; + +float2 ComputeDirectionalLight( float3 vWorldPos, float3 vWorldNormal, float3 vDirLightDir ) +{ + // Result.x is diffuse illumination, Result.y is specular illumination + float2 Result = float2( 0, 0 ); + Result.x = pow( saturate( dot( vWorldNormal, -vDirLightDir ) ), 2 ); + + float3 vPointToCamera = normalize( g_vCameraPosWorld - vWorldPos ); + float3 vHalfAngle = normalize( vPointToCamera - vDirLightDir ); + Result.y = pow( saturate( dot( vHalfAngle, vWorldNormal ) ), g_fSpecularExponent ); + + return Result; +} + +float3 ColorGamma( float3 Input ) +{ + return pow( Input, 2.2f ); +} + +float4 main( PS_INPUT Input ) : SV_TARGET +{ + float4 vNormalMapSampleRaw = g_txHeight.Sample( g_samLinear, Input.vUV ); + float3 vNormalMapSampleBiased = ( vNormalMapSampleRaw.xyz * 2 ) - 1; + vNormalMapSampleBiased.xy *= g_fNormalMapIntensity; + float3 vNormalMapSample = normalize( vNormalMapSampleBiased ); + + float3 vNormal = safe_normalize( Input.vNormal ) * vNormalMapSample.z; + vNormal += safe_normalize( Input.vTangent ) * vNormalMapSample.x; + vNormal += safe_normalize( Input.vBiTangent ) * vNormalMapSample.y; + + //float3 vColor = float3( 1, 1, 1 ); + float3 vColor = g_txDiffuse.Sample( g_samLinear, Input.vUV ).rgb; + float vSpecular = g_txSpecular.Sample( g_samLinear, Input.vUV ).r * g_fSpecularIntensity; + + const float3 DirLightDirections[4] = + { + // key light + normalize( float3( -63.345150, -58.043934, 27.785097 ) ), + // fill light + normalize( float3( 23.652107, -17.391443, 54.972504 ) ), + // back light 1 + normalize( float3( 20.470509, -22.939510, -33.929531 ) ), + // back light 2 + normalize( float3( -31.003685, 24.242104, -41.352859 ) ), + }; + + const float3 DirLightColors[4] = + { + // key light + ColorGamma( float3( 1.0f, 0.964f, 0.706f ) * 1.0f ), + // fill light + ColorGamma( float3( 0.446f, 0.641f, 1.0f ) * 1.0f ), + // back light 1 + ColorGamma( float3( 1.0f, 0.862f, 0.419f ) * 1.0f ), + // back light 2 + ColorGamma( float3( 0.405f, 0.630f, 1.0f ) * 1.0f ), + }; + + float3 fLightColor = 0; + for( int i = 0; i < 4; ++i ) + { + float2 LightDiffuseSpecular = ComputeDirectionalLight( Input.vWorldPos, vNormal, DirLightDirections[i] ); + fLightColor += DirLightColors[i] * vColor * LightDiffuseSpecular.x; + fLightColor += DirLightColors[i] * LightDiffuseSpecular.y * vSpecular; + } + + return float4( fLightColor, 1 ); +} + + +)", + L"main", L"ps_6_5", false /*does not use view id*/, 4); + +} + From 24c4be7b156fba2edbbf3d103d0f59ddc7c0915b Mon Sep 17 00:00:00 2001 From: Jeff Noyle Date: Mon, 12 Dec 2022 17:03:27 -0800 Subject: [PATCH 3/7] Fix hitgroup metadata argument order (cherry picked from commit 21cf36aad104929f600a04ecd37270f5dc2999b0) --- lib/DXIL/DxilMetadataHelper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/DXIL/DxilMetadataHelper.cpp b/lib/DXIL/DxilMetadataHelper.cpp index f10c9a616a..cff086f94f 100644 --- a/lib/DXIL/DxilMetadataHelper.cpp +++ b/lib/DXIL/DxilMetadataHelper.cpp @@ -1908,7 +1908,7 @@ Metadata *DxilMDHelper::EmitSubobject(const DxilSubobject &obj) { case DXIL::SubobjectKind::HitGroup: { llvm::StringRef Intersection, AnyHit, ClosestHit; DXIL::HitGroupType hgType; - IFTBOOL(obj.GetHitGroup(hgType, Intersection, AnyHit, ClosestHit), + IFTBOOL(obj.GetHitGroup(hgType, AnyHit, ClosestHit, Intersection), DXC_E_INCORRECT_DXIL_METADATA); Args.emplace_back(Uint32ToConstMD((uint32_t)hgType)); Args.emplace_back(MDString::get(m_Ctx, Intersection)); From 89b5b539b439fe99232666e52c122df51590f0b4 Mon Sep 17 00:00:00 2001 From: Joshua Batista Date: Thu, 15 Dec 2022 15:52:51 -0800 Subject: [PATCH 4/7] add barycentrics ordering check onto existing barycentrics test (#4635) * add barycentrics ordering check onto existing barycentrics test, with extra shader (cherry picked from commit ee0994e58eab2614a2244dfee868954727a89e9f) --- tools/clang/test/HLSL/ShaderOpArith.xml | 44 +++++-- tools/clang/unittests/HLSL/ExecutionTest.cpp | 119 ++++++++++++++----- 2 files changed, 128 insertions(+), 35 deletions(-) diff --git a/tools/clang/test/HLSL/ShaderOpArith.xml b/tools/clang/test/HLSL/ShaderOpArith.xml index 4eef3d085b..277e8cf8fa 100644 --- a/tools/clang/test/HLSL/ShaderOpArith.xml +++ b/tools/clang/test/HLSL/ShaderOpArith.xml @@ -1493,12 +1493,9 @@ RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT) - - { { 0.0f, 1.0f , 0.0f }, { 1.0f, 0.0f, 0.0f, 1.0f } }, - { { 1.0f, -1.0f , 0.0f }, { 0.0f, 1.0f, 0.0f, 1.0f } }, - { { -1.0f, -1.0f , 0.0f }, { 0.0f, 0.0f, 1.0f, 1.0f } } - - + + + @@ -1535,7 +1532,40 @@ float4 vColor2 = GetAttributeAtVertex(input.color, 2); return bary.x * vColor0 + bary.y * vColor1 + bary.z * vColor2; } - ]]> + ]]> + + + + + + + + diff --git a/tools/clang/unittests/HLSL/ExecutionTest.cpp b/tools/clang/unittests/HLSL/ExecutionTest.cpp index 3b8315645d..0390db0304 100644 --- a/tools/clang/unittests/HLSL/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSL/ExecutionTest.cpp @@ -8988,22 +8988,7 @@ TEST_F(ExecutionTest, CBufferTestHalf) { } } -TEST_F(ExecutionTest, BarycentricsTest) { - WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - CComPtr pStream; - ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); - - CComPtr pDevice; - if (!CreateDevice(&pDevice, D3D_SHADER_MODEL_6_1)) - return; - - if (!DoesDeviceSupportBarycentrics(pDevice)) { - WEX::Logging::Log::Comment(L"Device does not support barycentrics."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return; - } - - std::shared_ptr test = RunShaderOpTest(pDevice, m_support, pStream, "Barycentrics", nullptr); +void TestBarycentricVariant(bool checkOrdering, std::shared_ptr test){ MappedData data; D3D12_RESOURCE_DESC &D = test->ShaderOp->GetResourceByName("RTarget")->Desc; UINT width = (UINT)D.Width; @@ -9011,7 +8996,7 @@ TEST_F(ExecutionTest, BarycentricsTest) { UINT pixelSize = GetByteSizeForFormat(D.Format); test->Test->GetReadBackData("RTarget", &data); - //const uint8_t *pPixels = (uint8_t *)data.data(); + const float *pPixels = (float *)data.data(); // Get the vertex of barycentric coordinate using VBuffer MappedData triangleData; @@ -9025,15 +9010,16 @@ TEST_F(ExecutionTest, BarycentricsTest) { XMFLOAT2 p0(pTriangleData[0], pTriangleData[1]); XMFLOAT2 p1(pTriangleData[triangleVertexSizeInFloat], pTriangleData[triangleVertexSizeInFloat + 1]); XMFLOAT2 p2(pTriangleData[triangleVertexSizeInFloat * 2], pTriangleData[triangleVertexSizeInFloat * 2 + 1]); - + + // Seems like the 3 floats must add up to 1 to get accurate results. XMFLOAT3 barycentricWeights[4] = { - XMFLOAT3(0.3333f, 0.3333f, 0.3333f), + XMFLOAT3(0.4f, 0.2f, 0.4f), XMFLOAT3(0.5f, 0.25f, 0.25f), XMFLOAT3(0.25f, 0.5f, 0.25f), XMFLOAT3(0.25f, 0.25f, 0.50f) - }; + }; - float tolerance = 0.001f; + float tolerance = 0.02f; for (unsigned i = 0; i < sizeof(barycentricWeights) / sizeof(XMFLOAT3); ++i) { float w0 = barycentricWeights[i].x; float w1 = barycentricWeights[i].y; @@ -9041,17 +9027,89 @@ TEST_F(ExecutionTest, BarycentricsTest) { float x1 = w0 * p0.x + w1 * p1.x + w2 * p2.x; float y1 = w0 * p0.y + w1 * p1.y + w2 * p2.y; // map from x1 y1 to rtv pixels - int pixelX = (int)((x1 + 1) * (width - 1) / 2); - int pixelY = (int)((1 - y1) * (height - 1) / 2); + int pixelX = (int)round((x1 + 1) * (width - 1) / 2.0); + int pixelY = (int)round((1 - y1) * (height - 1) / 2.0); int offset = pixelSize * (pixelX + pixelY * width) / sizeof(pPixels[0]); LogCommentFmt(L"location %u %u, value %f, %f, %f", pixelX, pixelY, pPixels[offset], pPixels[offset + 1], pPixels[offset + 2]); - VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset], w0, tolerance)); - VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 1], w1, tolerance)); - VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 2], w2, tolerance)); + if (!checkOrdering){ + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset], w0, tolerance)); + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 1], w1, tolerance)); + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 2], w2, tolerance)); + } + else{ + // If the ordering constraint is met, then this pixel's RGBA should be all 1.0's + // since the shader only returns float4<1.0,1.0,1.0,1.0> when this condition is met. + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset] , 0.0, tolerance)); + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 1], 0.5, tolerance)); + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 2], 1.0, tolerance)); + VERIFY_IS_TRUE(CompareFloatEpsilon(pPixels[offset + 3], 1.0, tolerance)); + } + } +} + +st::ShaderOpTest::TInitCallbackFn MakeBarycentricsResourceInitCallbackFn(int &vertexShift){ + return [&](LPCSTR Name, std::vector& Data, st::ShaderOp* pShaderOp) { + std::vector bary = { 0.0f, 1.0f , 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, -1.0f , 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, + -1.0f, -1.0f , 0.0f, 0.0f, 0.0f, 1.0f, 1.0f }; + const int barysize = 21; + + UNREFERENCED_PARAMETER(pShaderOp); + VERIFY_IS_TRUE(0 == _stricmp(Name, "VBuffer")); + size_t size = sizeof(float) * barysize; + Data.resize(size); + float* vb = (float*)Data.data(); + for (size_t i = 0; i < barysize; ++i) { + float* p = &vb[i]; + float tempfloat = bary[(i + (7 * vertexShift)) % barysize]; + *p = tempfloat; + } + }; + +} + +TEST_F(ExecutionTest, BarycentricsTest) { + WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + CComPtr pStream; + ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); + + CComPtr pDevice; + if (!CreateDevice(&pDevice, D3D_SHADER_MODEL_6_1)) + return; + + if (!DoesDeviceSupportBarycentrics(pDevice)) { + WEX::Logging::Log::Comment(L"Device does not support barycentrics."); + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); + return; } - //SavePixelsToFile(pPixels, DXGI_FORMAT_R32G32B32A32_FLOAT, width, height, L"barycentric.bmp"); + + DXASSERT_NOMSG(pStream != nullptr); + std::shared_ptr ShaderOpSet = + std::make_shared(); + st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get()); + st::ShaderOp* pShaderOp = + ShaderOpSet->GetShaderOp("Barycentrics"); + + int test_iteration = 0; + auto ResourceCallbackFnNoShift = MakeBarycentricsResourceInitCallbackFn(test_iteration); + + std::shared_ptr test = RunShaderOpTestAfterParse(pDevice, m_support, "Barycentrics", ResourceCallbackFnNoShift, ShaderOpSet); + TestBarycentricVariant(false, test); + + // Now test that barycentric ordering is consistent + LogCommentFmt(L"Now testing that the barycentric ordering constraint is upheld for each pixel..."); + pShaderOp->VS = pShaderOp->GetString("VSordering"); + pShaderOp->PS = pShaderOp->GetString("PSordering"); + for(; test_iteration < 3; test_iteration++) + { + auto ResourceCallbackFn = MakeBarycentricsResourceInitCallbackFn(test_iteration); + + std::shared_ptr test2 = RunShaderOpTestAfterParse(pDevice, m_support, "Barycentrics", ResourceCallbackFn, ShaderOpSet); + TestBarycentricVariant(true, test2); + } } + static const char RawBufferTestShaderDeclarations[] = "// Note: COMPONENT_TYPE and COMPONENT_SIZE will be defined via compiler option -D\r\n" "typedef COMPONENT_TYPE scalar; \r\n" @@ -11488,7 +11546,12 @@ TEST_F(ExecutionTest, IsNormalTest) { st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get()); st::ShaderOp *pShaderOp = ShaderOpSet->GetShaderOp("IsNormal"); vector fallbackRootValues = pShaderOp->RootValues; - + + D3D_SHADER_MODEL sm = D3D_SHADER_MODEL_6_0; + LogCommentFmt(L"\r\nVerifying isNormal in shader " + L"model 6.%1u", + ((UINT)sm & 0x0f)); + size_t count = Validation_Input->size(); auto ShaderInitFn = MakeShaderReplacementCallback( From b6a9ac06a2b301615bf6d51f4f6e9842fc79f264 Mon Sep 17 00:00:00 2001 From: Tex Riddell Date: Wed, 14 Dec 2022 17:40:03 -0800 Subject: [PATCH 5/7] ConvertFloat32ToFloat16: Use DirectXMath conversion functions (#4855) Custom half <-> float conversion functions had problems in multiple scenarios. This PR changes them into a wrapper, using the DirectXMath conversion functions instead. (cherry picked from commit 6acd11bfc31566820fef7a8faf79c726651a837d) --- include/dxc/Test/HlslTestUtils.h | 88 +-------------------- tools/clang/unittests/HLSL/ShaderOpTest.cpp | 8 ++ 2 files changed, 11 insertions(+), 85 deletions(-) diff --git a/include/dxc/Test/HlslTestUtils.h b/include/dxc/Test/HlslTestUtils.h index 7d15eccca1..548781a668 100644 --- a/include/dxc/Test/HlslTestUtils.h +++ b/include/dxc/Test/HlslTestUtils.h @@ -406,91 +406,9 @@ inline bool isnanFloat16(uint16_t val) { (val & FLOAT16_BIT_MANTISSA) != 0; } -inline uint16_t ConvertFloat32ToFloat16(float val) { - union Bits { - uint32_t u_bits; - float f_bits; - }; - - static const uint32_t SignMask = 0x8000; - - // Minimum f32 value representable in f16 format without denormalizing - static const uint32_t Min16in32 = 0x38800000; - - // Maximum f32 value (next to infinity) - static const uint32_t Max32 = 0x7f7FFFFF; - - // Mask for f32 mantissa - static const uint32_t Fraction32Mask = 0x007FFFFF; - - // pow(2,24) - static const uint32_t DenormalRatio = 0x4B800000; - - static const uint32_t NormalDelta = 0x38000000; - - Bits bits; - bits.f_bits = val; - uint32_t sign = bits.u_bits & (SignMask << 16); - Bits Abs; - Abs.u_bits = bits.u_bits ^ sign; - - bool isLessThanNormal = Abs.f_bits < *(const float*)&Min16in32; - bool isInfOrNaN = Abs.u_bits > Max32; - - if (isLessThanNormal) { - // Compute Denormal result - return (uint16_t)(Abs.f_bits * *(const float*)(&DenormalRatio)) | (uint16_t)(sign >> 16); - } - else if (isInfOrNaN) { - // Compute Inf or Nan result - uint32_t Fraction = Abs.u_bits & Fraction32Mask; - uint16_t IsNaN = Fraction == 0 ? 0 : 0xffff; - return (IsNaN & FLOAT16_BIT_MANTISSA) | FLOAT16_BIT_EXP | (uint16_t)(sign >> 16); - } - else { - // Compute Normal result - return (uint16_t)((Abs.u_bits - NormalDelta) >> 13) | (uint16_t)(sign >> 16); - } -} - -inline float ConvertFloat16ToFloat32(uint16_t x) { - union Bits { - float f_bits; - uint32_t u_bits; - }; - - uint32_t Sign = (x & FLOAT16_BIT_SIGN) << 16; - - // nan -> exponent all set and mantisa is non zero - // +/-inf -> exponent all set and mantissa is zero - // denorm -> exponent zero and significand nonzero - uint32_t Abs = (x & 0x7fff); - uint32_t IsNormal = Abs > FLOAT16_BIGGEST_DENORM; - uint32_t IsInfOrNaN = Abs > FLOAT16_BIGGEST_NORMAL; - - // Signless Result for normals - uint32_t DenormRatio = 0x33800000; - float DenormResult = Abs * (*(float*)&DenormRatio); - - uint32_t AbsShifted = Abs << 13; - // Signless Result for normals - uint32_t NormalResult = AbsShifted + 0x38000000; - // Signless Result for int & nans - uint32_t InfResult = AbsShifted + 0x70000000; - - Bits bits; - bits.u_bits = 0; - if (IsInfOrNaN) - bits.u_bits |= InfResult; - else if (IsNormal) - bits.u_bits |= NormalResult; - else - bits.f_bits = DenormResult; - bits.u_bits |= Sign; - return bits.f_bits; -} -uint16_t ConvertFloat32ToFloat16(float val); -float ConvertFloat16ToFloat32(uint16_t val); +// These are defined in ShaderOpTest.cpp using DirectXPackedVector functions. +uint16_t ConvertFloat32ToFloat16(float val) throw(); +float ConvertFloat16ToFloat32(uint16_t val) throw(); inline bool CompareFloatULP(const float &fsrc, const float &fref, int ULPTolerance, hlsl::DXIL::Float32DenormMode mode = hlsl::DXIL::Float32DenormMode::Any) { diff --git a/tools/clang/unittests/HLSL/ShaderOpTest.cpp b/tools/clang/unittests/HLSL/ShaderOpTest.cpp index 66fce18efe..7d19149245 100644 --- a/tools/clang/unittests/HLSL/ShaderOpTest.cpp +++ b/tools/clang/unittests/HLSL/ShaderOpTest.cpp @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -40,6 +41,13 @@ /////////////////////////////////////////////////////////////////////////////// // Useful helper functions. +uint16_t ConvertFloat32ToFloat16(float Value) throw() { + return DirectX::PackedVector::XMConvertFloatToHalf(Value); +} +float ConvertFloat16ToFloat32(uint16_t Value) throw() { + return DirectX::PackedVector::XMConvertHalfToFloat(Value); +} + static st::OutputStringFn g_OutputStrFn; static void * g_OutputStrFnCtx; From f2dee04b210eff2c13e20ac7d03aeae563ee586c Mon Sep 17 00:00:00 2001 From: Helena Kotas Date: Thu, 15 Dec 2022 07:02:02 -0800 Subject: [PATCH 6/7] Include TestConfig.h only if DEFAULT_TEST_DIR is not defined (#4884) TestConfig.h is not available in HLK test build. This change enables skipping of the include. (cherry picked from commit e7aac8e0f5a0ff3e3dbc16eaf7a69a5739a97265) --- include/dxc/Test/HlslTestUtils.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/dxc/Test/HlslTestUtils.h b/include/dxc/Test/HlslTestUtils.h index 548781a668..06c16c5cbc 100644 --- a/include/dxc/Test/HlslTestUtils.h +++ b/include/dxc/Test/HlslTestUtils.h @@ -26,9 +26,12 @@ #include "WEXAdapter.h" #endif #include "dxc/Support/Unicode.h" -#include "dxc/Test/TestConfig.h" #include "dxc/DXIL/DxilConstants.h" // DenormMode +#ifndef DEFAULT_TEST_DIR +#include "dxc/Test/TestConfig.h" +#endif + using namespace std; #ifndef HLSLDATAFILEPARAM From 1f1b73729d38a7894621f64fcd5d8fd193c4822f Mon Sep 17 00:00:00 2001 From: Helena Kotas Date: Fri, 16 Dec 2022 10:41:38 -0800 Subject: [PATCH 7/7] Do not include TestConfig.h for all HLK build (#4887) TestConfig.h is not available in HLK test build. The test library uses _HLK_CONF define to distinguish between Exec tests-only and HLK-only code. (cherry picked from commit 5decc4aa974fa2c60ee299e6aedcb6b20998da7f) --- include/dxc/Test/HlslTestUtils.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/dxc/Test/HlslTestUtils.h b/include/dxc/Test/HlslTestUtils.h index 06c16c5cbc..cc656a92ab 100644 --- a/include/dxc/Test/HlslTestUtils.h +++ b/include/dxc/Test/HlslTestUtils.h @@ -28,7 +28,9 @@ #include "dxc/Support/Unicode.h" #include "dxc/DXIL/DxilConstants.h" // DenormMode -#ifndef DEFAULT_TEST_DIR +#ifdef _HLK_CONF +#define DEFAULT_TEST_DIR "" +#else #include "dxc/Test/TestConfig.h" #endif