Skip to content

Commit

Permalink
Merge branch 'praydog:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyhodge authored Jun 3, 2024
2 parents ae70e36 + 5dc5998 commit f4bc210
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 28 deletions.
164 changes: 138 additions & 26 deletions src/D3D12Hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
#include <future>
#include <unordered_set>
#include <stacktrace>
#include <wrl/client.h>

#include <spdlog/spdlog.h>
#include <utility/Thread.hpp>
#include <utility/Module.hpp>
#include <utility/String.hpp>
#include <utility/RTTI.hpp>
#include <utility/Scan.hpp>
#include <utility/ScopeGuard.hpp>

#include "REFramework.hpp"

Expand All @@ -16,16 +19,111 @@
#include "D3D12Hook.hpp"

static D3D12Hook* g_d3d12_hook = nullptr;
thread_local bool g_inside_d3d12_hook = false;

D3D12Hook::~D3D12Hook() {
unhook();
}

void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9) {
if (g_inside_d3d12_hook) {
spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x} (inside D3D12 hook)", (uintptr_t)_ReturnAddress());

auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook;
return hook->get_original<decltype(link_swapchain_to_cmd_queue)>()(rcx, rdx, r8, r9);
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress());

bool hook_was_nullptr = g_d3d12_hook == nullptr;

if (g_d3d12_hook != nullptr) {
g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around
g_d3d12_hook->unhook(); // Removes all vtable hooks
}

auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook;
const auto result = hook->get_original<decltype(link_swapchain_to_cmd_queue)>()(rcx, rdx, r8, r9);

// Re-hooks present after the above function creates the swapchain
// This allows the hook to immediately still function
// rather than waiting on the hook monitor to notice the hook isn't working
if (!hook_was_nullptr) {
g_framework->hook_d3d12();
}

return result;
}

void D3D12Hook::hook_streamline(HMODULE dlssg_module) try {
if (D3D12Hook::s_streamline.setup) {
return;
}

std::scoped_lock _{D3D12Hook::s_streamline.hook_mutex};

if (D3D12Hook::s_streamline.setup) {
return;
}

spdlog::info("[Streamline] Hooking Streamline");

if (dlssg_module == nullptr) {
dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll");
}

if (dlssg_module == nullptr) {
spdlog::error("[Streamline] Failed to get sl.dlss_g.dll module handle");
return;
}

const auto str = utility::scan_string(dlssg_module, "linkSwapchainToCmdQueue");

if (!str) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue");
return;
}

const auto str_ref = utility::scan_displacement_reference(dlssg_module, *str);

if (!str_ref) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue reference");
return;
}

const auto fn = utility::find_function_start_with_call(*str_ref);

if (!fn) {
spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue function");
return;
}

D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook = std::make_unique<FunctionHook>(*fn, (uintptr_t)&Streamline::link_swapchain_to_cmd_queue);

if (D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook->create()) {
spdlog::info("[Streamline] Hooked linkSwapchainToCmdQueue");
} else {
spdlog::error("[Streamline] Failed to hook linkSwapchainToCmdQueue");
}

D3D12Hook::s_streamline.setup = true;
} catch(...) {
spdlog::error("[Streamline] Failed to hook Streamline");
}

bool D3D12Hook::hook() {
spdlog::info("Hooking D3D12");

g_d3d12_hook = this;

g_inside_d3d12_hook = true;

utility::ScopeGuard guard{[]() {
g_inside_d3d12_hook = false;
}};

IDXGISwapChain1* swap_chain1{ nullptr };
IDXGISwapChain3* swap_chain{ nullptr };
ID3D12Device* device{ nullptr };
Expand Down Expand Up @@ -330,7 +428,9 @@ bool D3D12Hook::hook() {
return false;
}

utility::ThreadSuspender suspender{};
hook_streamline();

//utility::ThreadSuspender suspender{};

try {
spdlog::info("Initializing hooks");
Expand All @@ -341,20 +441,20 @@ bool D3D12Hook::hook() {
m_is_phase_1 = true;

auto& present_fn = (*(void***)target_swapchain)[8]; // Present
m_present_hook = std::make_unique<PointerHook>(&present_fn, (void*)&D3D12Hook::present);
m_present_hook = std::make_unique<PointerHook>(&present_fn, &D3D12Hook::present);
m_hooked = true;
} catch (const std::exception& e) {
spdlog::error("Failed to initialize hooks: {}", e.what());
m_hooked = false;
}

suspender.resume();
//suspender.resume();

device->Release();
command_queue->Release();
factory->Release();
swap_chain1->Release();
swap_chain->Release();
device->Release();
factory->Release();

if (hwnd) {
::DestroyWindow(hwnd);
Expand All @@ -368,6 +468,8 @@ bool D3D12Hook::hook() {
}

bool D3D12Hook::unhook() {
std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

if (!m_hooked) {
return true;
}
Expand All @@ -385,57 +487,67 @@ bool D3D12Hook::unhook() {

thread_local int32_t g_present_depth = 0;

HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags) {
HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9) {
std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

auto d3d12 = g_d3d12_hook;

HWND swapchain_wnd{nullptr};
swap_chain->GetHwnd(&swapchain_wnd);

decltype(D3D12Hook::present)* present_fn{nullptr};

//if (d3d12->m_is_phase_1) {
if (d3d12->m_is_phase_1) {
present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
/*} else {
} else {
present_fn = d3d12->m_swapchain_hook->get_method<decltype(D3D12Hook::present)*>(8);
}*/
}

HWND swapchain_wnd{nullptr};
swap_chain->GetHwnd(&swapchain_wnd);

if (d3d12->m_is_phase_1 && WindowFilter::get().is_filtered(swapchain_wnd)) {
//present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
return present_fn(swap_chain, sync_interval, flags);
return present_fn(swap_chain, sync_interval, flags, r9);
}

if (!d3d12->m_is_phase_1 && swap_chain != d3d12->m_swapchain_hook->get_instance()) {
return present_fn(swap_chain, sync_interval, flags);
return present_fn(swap_chain, sync_interval, flags, r9);
}

if (d3d12->m_is_phase_1) {
//d3d12->m_present_hook.reset();
// Remove the present hook, we will just rely on the vtable hook below
// because we don't want to cause any conflicts with other hooks
// vtable hooks are the least intrusive
// And doing a global pointer replacement seems to have
// conflicts with Streamline's hooks, causing unexplainable crashes
d3d12->m_present_hook.reset();

// vtable hook the swapchain instead of global hooking
// this seems safer for whatever reason
// if we globally hook the vtable pointers, it causes all sorts of weird conflicts with other hooks
// dont hook present though via this hook so other hooks dont get confused
d3d12->m_swapchain_hook = std::make_unique<VtableHook>(swap_chain);
//d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present);
//d3d12->m_swapchain_hook->hook_method(2, (uintptr_t)&D3D12Hook::release);
d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present);
d3d12->m_swapchain_hook->hook_method(13, (uintptr_t)&D3D12Hook::resize_buffers);
d3d12->m_swapchain_hook->hook_method(14, (uintptr_t)&D3D12Hook::resize_target);
d3d12->m_is_phase_1 = false;

present_fn = d3d12->m_swapchain_hook->get_method<decltype(D3D12Hook::present)*>(8);
}

d3d12->m_inside_present = true;
d3d12->m_swap_chain = swap_chain;

swap_chain->GetDevice(IID_PPV_ARGS(&d3d12->m_device));
{
Microsoft::WRL::ComPtr<ID3D12Device4> temp_device{};
swap_chain->GetDevice(IID_PPV_ARGS(&temp_device));
d3d12->m_device = temp_device.Get();
}

if (d3d12->m_device != nullptr) {
if (d3d12->m_using_proton_swapchain) {
const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset);
d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset);
} else {
d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset);
}
if (d3d12->m_using_proton_swapchain) {
const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset);
d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset);
} else {
d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset);
}

if (d3d12->m_swapchain_0 == nullptr) {
Expand All @@ -462,7 +574,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva
spdlog::info("Attempting to call real present function");

++g_present_depth;
const auto result = present_fn(swap_chain, sync_interval, flags);
const auto result = present_fn(swap_chain, sync_interval, flags, r9);
--g_present_depth;

if (result != S_OK) {
Expand All @@ -485,7 +597,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva
auto result = S_OK;

if (!d3d12->m_ignore_next_present) {
result = present_fn(swap_chain, sync_interval, flags);
result = present_fn(swap_chain, sync_interval, flags, r9);

if (result != S_OK) {
spdlog::error("Present failed: {:x}", result);
Expand Down
18 changes: 16 additions & 2 deletions src/D3D12Hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <dxgi1_4.h>

#include "utility/PointerHook.hpp"
#include "utility/FunctionHook.hpp"
#include "utility/VtableHook.hpp"

class D3D12Hook
Expand Down Expand Up @@ -88,7 +89,7 @@ class D3D12Hook
bool is_proton_swapchain() const {
return m_using_proton_swapchain;
}

bool is_framegen_swapchain() const {
return m_using_frame_generation_swapchain;
}
Expand All @@ -97,6 +98,8 @@ class D3D12Hook
m_ignore_next_present = true;
}

static void hook_streamline(HMODULE dlssg_module = nullptr);

protected:
ID3D12Device4* m_device{ nullptr };
IDXGISwapChain3* m_swap_chain{ nullptr };
Expand All @@ -119,16 +122,27 @@ class D3D12Hook
bool m_ignore_next_present{false};

std::unique_ptr<PointerHook> m_present_hook{};
//std::unique_ptr<PointerHook> m_release_hook{};
std::unique_ptr<VtableHook> m_swapchain_hook{};
//std::unique_ptr<FunctionHook> m_create_swap_chain_hook{};

struct Streamline {
static void* link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9);

std::unique_ptr<FunctionHook> link_swapchain_to_cmd_queue_hook{};
std::mutex hook_mutex{};
bool setup{ false };
};

static inline Streamline s_streamline{};

OnPresentFn m_on_present{ nullptr };
OnPresentFn m_on_post_present{ nullptr };
OnResizeBuffersFn m_on_resize_buffers{ nullptr };
OnResizeTargetFn m_on_resize_target{ nullptr };
//OnCreateSwapChainFn m_on_create_swap_chain{ nullptr };

static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags);
static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9);
static HRESULT WINAPI resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags);
static HRESULT WINAPI resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters);
//static HRESULT WINAPI create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain);
Expand Down
6 changes: 6 additions & 0 deletions src/REFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ try {
if (NotificationData->Loaded.BaseDllName != nullptr && NotificationData->Loaded.BaseDllName->Buffer != nullptr) {
std::wstring base_dll_name = NotificationData->Loaded.BaseDllName->Buffer;
spdlog::info("LdrRegisterDllNotification: Loaded: {}", utility::narrow(base_dll_name));

if (base_dll_name.find(L"sl.dlss_g.dll") != std::wstring::npos) {
spdlog::info("LdrRegisterDllNotification: Detected DLSS DLL loaded");

D3D12Hook::hook_streamline((HMODULE)NotificationData->Loaded.DllBase);
}
}

if (g_current_game_path && NotificationData->Loaded.FullDllName != nullptr && NotificationData->Loaded.FullDllName->Buffer != nullptr) {
Expand Down
2 changes: 2 additions & 0 deletions src/REFramework.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ class REFramework {
void draw_ui();
void draw_about();

public:
bool hook_d3d11();
bool hook_d3d12();

private:
bool initialize();
bool initialize_game_data();
bool initialize_windows_message_hook();
Expand Down

0 comments on commit f4bc210

Please sign in to comment.