diff --git a/src/D3D12Hook.cpp b/src/D3D12Hook.cpp index baea8a8a..76e4ec51 100644 --- a/src/D3D12Hook.cpp +++ b/src/D3D12Hook.cpp @@ -2,12 +2,15 @@ #include #include #include +#include #include #include #include #include #include +#include +#include #include "REFramework.hpp" @@ -16,16 +19,111 @@ #include "D3D12Hook.hpp" static D3D12Hook* g_d3d12_hook = nullptr; +thread_local bool g_inside_d3d12_hook = false; D3D12Hook::~D3D12Hook() { unhook(); } +void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9) { + if (g_inside_d3d12_hook) { + spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x} (inside D3D12 hook)", (uintptr_t)_ReturnAddress()); + + auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook; + return hook->get_original()(rcx, rdx, r8, r9); + } + + std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; + + spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress()); + + bool hook_was_nullptr = g_d3d12_hook == nullptr; + + if (g_d3d12_hook != nullptr) { + g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around + g_d3d12_hook->unhook(); // Removes all vtable hooks + } + + auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook; + const auto result = hook->get_original()(rcx, rdx, r8, r9); + + // Re-hooks present after the above function creates the swapchain + // This allows the hook to immediately still function + // rather than waiting on the hook monitor to notice the hook isn't working + if (!hook_was_nullptr) { + g_framework->hook_d3d12(); + } + + return result; +} + +void D3D12Hook::hook_streamline(HMODULE dlssg_module) try { + if (D3D12Hook::s_streamline.setup) { + return; + } + + std::scoped_lock _{D3D12Hook::s_streamline.hook_mutex}; + + if (D3D12Hook::s_streamline.setup) { + return; + } + + spdlog::info("[Streamline] Hooking Streamline"); + + if (dlssg_module == nullptr) { + dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll"); + } + + if (dlssg_module == nullptr) { + spdlog::error("[Streamline] Failed to get sl.dlss_g.dll module handle"); + return; + } + + const auto str = utility::scan_string(dlssg_module, "linkSwapchainToCmdQueue"); + + if (!str) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue"); + return; + } + + const auto str_ref = utility::scan_displacement_reference(dlssg_module, *str); + + if (!str_ref) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue reference"); + return; + } + + const auto fn = utility::find_function_start_with_call(*str_ref); + + if (!fn) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue function"); + return; + } + + D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook = std::make_unique(*fn, (uintptr_t)&Streamline::link_swapchain_to_cmd_queue); + + if (D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook->create()) { + spdlog::info("[Streamline] Hooked linkSwapchainToCmdQueue"); + } else { + spdlog::error("[Streamline] Failed to hook linkSwapchainToCmdQueue"); + } + + D3D12Hook::s_streamline.setup = true; +} catch(...) { + spdlog::error("[Streamline] Failed to hook Streamline"); +} + bool D3D12Hook::hook() { spdlog::info("Hooking D3D12"); g_d3d12_hook = this; + g_inside_d3d12_hook = true; + + utility::ScopeGuard guard{[]() { + g_inside_d3d12_hook = false; + }}; + IDXGISwapChain1* swap_chain1{ nullptr }; IDXGISwapChain3* swap_chain{ nullptr }; ID3D12Device* device{ nullptr }; @@ -330,7 +428,9 @@ bool D3D12Hook::hook() { return false; } - utility::ThreadSuspender suspender{}; + hook_streamline(); + + //utility::ThreadSuspender suspender{}; try { spdlog::info("Initializing hooks"); @@ -341,20 +441,20 @@ bool D3D12Hook::hook() { m_is_phase_1 = true; auto& present_fn = (*(void***)target_swapchain)[8]; // Present - m_present_hook = std::make_unique(&present_fn, (void*)&D3D12Hook::present); + m_present_hook = std::make_unique(&present_fn, &D3D12Hook::present); m_hooked = true; } catch (const std::exception& e) { spdlog::error("Failed to initialize hooks: {}", e.what()); m_hooked = false; } - suspender.resume(); + //suspender.resume(); - device->Release(); command_queue->Release(); - factory->Release(); swap_chain1->Release(); swap_chain->Release(); + device->Release(); + factory->Release(); if (hwnd) { ::DestroyWindow(hwnd); @@ -368,6 +468,8 @@ bool D3D12Hook::hook() { } bool D3D12Hook::unhook() { + std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; + if (!m_hooked) { return true; } @@ -385,57 +487,67 @@ bool D3D12Hook::unhook() { thread_local int32_t g_present_depth = 0; -HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags) { +HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9) { std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; auto d3d12 = g_d3d12_hook; - HWND swapchain_wnd{nullptr}; - swap_chain->GetHwnd(&swapchain_wnd); - decltype(D3D12Hook::present)* present_fn{nullptr}; - //if (d3d12->m_is_phase_1) { + if (d3d12->m_is_phase_1) { present_fn = d3d12->m_present_hook->get_original(); - /*} else { + } else { present_fn = d3d12->m_swapchain_hook->get_method(8); - }*/ + } + + HWND swapchain_wnd{nullptr}; + swap_chain->GetHwnd(&swapchain_wnd); if (d3d12->m_is_phase_1 && WindowFilter::get().is_filtered(swapchain_wnd)) { //present_fn = d3d12->m_present_hook->get_original(); - return present_fn(swap_chain, sync_interval, flags); + return present_fn(swap_chain, sync_interval, flags, r9); } if (!d3d12->m_is_phase_1 && swap_chain != d3d12->m_swapchain_hook->get_instance()) { - return present_fn(swap_chain, sync_interval, flags); + return present_fn(swap_chain, sync_interval, flags, r9); } if (d3d12->m_is_phase_1) { - //d3d12->m_present_hook.reset(); + // Remove the present hook, we will just rely on the vtable hook below + // because we don't want to cause any conflicts with other hooks + // vtable hooks are the least intrusive + // And doing a global pointer replacement seems to have + // conflicts with Streamline's hooks, causing unexplainable crashes + d3d12->m_present_hook.reset(); // vtable hook the swapchain instead of global hooking // this seems safer for whatever reason // if we globally hook the vtable pointers, it causes all sorts of weird conflicts with other hooks // dont hook present though via this hook so other hooks dont get confused d3d12->m_swapchain_hook = std::make_unique(swap_chain); - //d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present); + //d3d12->m_swapchain_hook->hook_method(2, (uintptr_t)&D3D12Hook::release); + d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present); d3d12->m_swapchain_hook->hook_method(13, (uintptr_t)&D3D12Hook::resize_buffers); d3d12->m_swapchain_hook->hook_method(14, (uintptr_t)&D3D12Hook::resize_target); d3d12->m_is_phase_1 = false; + + present_fn = d3d12->m_swapchain_hook->get_method(8); } d3d12->m_inside_present = true; d3d12->m_swap_chain = swap_chain; - swap_chain->GetDevice(IID_PPV_ARGS(&d3d12->m_device)); + { + Microsoft::WRL::ComPtr temp_device{}; + swap_chain->GetDevice(IID_PPV_ARGS(&temp_device)); + d3d12->m_device = temp_device.Get(); + } - if (d3d12->m_device != nullptr) { - if (d3d12->m_using_proton_swapchain) { - const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset); - d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset); - } else { - d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset); - } + if (d3d12->m_using_proton_swapchain) { + const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset); + d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset); + } else { + d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset); } if (d3d12->m_swapchain_0 == nullptr) { @@ -462,7 +574,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva spdlog::info("Attempting to call real present function"); ++g_present_depth; - const auto result = present_fn(swap_chain, sync_interval, flags); + const auto result = present_fn(swap_chain, sync_interval, flags, r9); --g_present_depth; if (result != S_OK) { @@ -485,7 +597,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva auto result = S_OK; if (!d3d12->m_ignore_next_present) { - result = present_fn(swap_chain, sync_interval, flags); + result = present_fn(swap_chain, sync_interval, flags, r9); if (result != S_OK) { spdlog::error("Present failed: {:x}", result); diff --git a/src/D3D12Hook.hpp b/src/D3D12Hook.hpp index 1a3315ac..a115df46 100644 --- a/src/D3D12Hook.hpp +++ b/src/D3D12Hook.hpp @@ -10,6 +10,7 @@ #include #include "utility/PointerHook.hpp" +#include "utility/FunctionHook.hpp" #include "utility/VtableHook.hpp" class D3D12Hook @@ -88,7 +89,7 @@ class D3D12Hook bool is_proton_swapchain() const { return m_using_proton_swapchain; } - + bool is_framegen_swapchain() const { return m_using_frame_generation_swapchain; } @@ -97,6 +98,8 @@ class D3D12Hook m_ignore_next_present = true; } + static void hook_streamline(HMODULE dlssg_module = nullptr); + protected: ID3D12Device4* m_device{ nullptr }; IDXGISwapChain3* m_swap_chain{ nullptr }; @@ -119,16 +122,27 @@ class D3D12Hook bool m_ignore_next_present{false}; std::unique_ptr m_present_hook{}; + //std::unique_ptr m_release_hook{}; std::unique_ptr m_swapchain_hook{}; //std::unique_ptr m_create_swap_chain_hook{}; + struct Streamline { + static void* link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9); + + std::unique_ptr link_swapchain_to_cmd_queue_hook{}; + std::mutex hook_mutex{}; + bool setup{ false }; + }; + + static inline Streamline s_streamline{}; + OnPresentFn m_on_present{ nullptr }; OnPresentFn m_on_post_present{ nullptr }; OnResizeBuffersFn m_on_resize_buffers{ nullptr }; OnResizeTargetFn m_on_resize_target{ nullptr }; //OnCreateSwapChainFn m_on_create_swap_chain{ nullptr }; - static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags); + static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9); static HRESULT WINAPI resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags); static HRESULT WINAPI resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters); //static HRESULT WINAPI create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain); diff --git a/src/REFramework.cpp b/src/REFramework.cpp index 1c5847b0..7e4472c4 100644 --- a/src/REFramework.cpp +++ b/src/REFramework.cpp @@ -199,6 +199,12 @@ try { if (NotificationData->Loaded.BaseDllName != nullptr && NotificationData->Loaded.BaseDllName->Buffer != nullptr) { std::wstring base_dll_name = NotificationData->Loaded.BaseDllName->Buffer; spdlog::info("LdrRegisterDllNotification: Loaded: {}", utility::narrow(base_dll_name)); + + if (base_dll_name.find(L"sl.dlss_g.dll") != std::wstring::npos) { + spdlog::info("LdrRegisterDllNotification: Detected DLSS DLL loaded"); + + D3D12Hook::hook_streamline((HMODULE)NotificationData->Loaded.DllBase); + } } if (g_current_game_path && NotificationData->Loaded.FullDllName != nullptr && NotificationData->Loaded.FullDllName->Buffer != nullptr) { diff --git a/src/REFramework.hpp b/src/REFramework.hpp index c51dc813..aa77c8a0 100644 --- a/src/REFramework.hpp +++ b/src/REFramework.hpp @@ -144,9 +144,11 @@ class REFramework { void draw_ui(); void draw_about(); +public: bool hook_d3d11(); bool hook_d3d12(); +private: bool initialize(); bool initialize_game_data(); bool initialize_windows_message_hook();