From 08246f5f3828ebbdc87e11d5da89f3b7b09a542e Mon Sep 17 00:00:00 2001 From: Andrea Venuta Date: Wed, 6 Nov 2024 08:39:31 +0100 Subject: [PATCH] Choose command queue associated to swap chain in DirectX 12 hook (#209) Create a builder-like data structure that has to have an IDXGISwapChain set first, and an ID3D12CommandQueue set later, provided it is associated with the previously-set swap chain. The heuristic used is to check the first 512 pointers in the memory pointed to by the swap chain, dereference each of them, and see if any of them matches the command queue that gets passed in. In practice, the command queue is among the first few pointers. A bit of extra care is given in checking for memory readability with VirtualQuery before dereferencing. There is some overhead but it is probably negligible, and only happens the first time the methods are hooked, so it's way less than a frame in total. Some concern might be raised by the fact that now we need to lock a mutex twice per frame, but practically that will happen almost always on the same thread and parking_lot has good performance for that case. Besides, we already do that successfully with the Pipeline. 
--------- Co-authored-by: Rico --- .cargo/config.toml | 3 ++ .gitignore | 6 --- Cargo.toml | 2 + src/hooks/dx12.rs | 122 +++++++++++++++++++++++++++++++++++++-------- src/util.rs | 114 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 219 insertions(+), 28 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..597848c9 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,3 @@ +[alias] +c = "xwin clippy --target x86_64-pc-windows-msvc --all" +t = "xwin test --target x86_64-pc-windows-msvc" diff --git a/.gitignore b/.gitignore index 0ca77fac..72a19a86 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,6 @@ Cargo.lock /target **/*.rs.bk -/reference_code -/hudhook.log -/lib/test_sample/*.exe -/lib/test_sample/*.obj -/tests/test_sample.exe .idea -.cargo vkd3d-proton.cache* diff --git a/Cargo.toml b/Cargo.toml index d3aa9f02..1b3508a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,11 +80,13 @@ features = [ "Win32_Graphics_Gdi", "Win32_Graphics_OpenGL", "Win32_Security", + "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_Diagnostics_ToolHelp", "Win32_System_LibraryLoader", "Win32_System_Memory", + "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_UI_Input_KeyboardAndMouse", diff --git a/src/hooks/dx12.rs b/src/hooks/dx12.rs index 0ff80221..3359729b 100644 --- a/src/hooks/dx12.rs +++ b/src/hooks/dx12.rs @@ -7,7 +7,7 @@ use std::sync::OnceLock; use imgui::Context; use once_cell::sync::OnceCell; use parking_lot::Mutex; -use tracing::{error, trace}; +use tracing::{debug, error, trace, warn}; use windows::core::{Error, Interface, Result, HRESULT}; use windows::Win32::Foundation::BOOL; use windows::Win32::Graphics::Direct3D::D3D_FEATURE_LEVEL_11_0; @@ -56,22 +56,101 @@ struct Trampolines { static mut TRAMPOLINES: OnceLock = OnceLock::new(); +enum InitializationContext { + Empty, + 
WithSwapChain(IDXGISwapChain3), + Complete(IDXGISwapChain3, ID3D12CommandQueue), + Done, +} + +impl InitializationContext { + // Transition to a state where the swap chain is set. Ignore other mutations. + fn insert_swap_chain(&mut self, swap_chain: &IDXGISwapChain3) { + *self = match mem::replace(self, InitializationContext::Empty) { + InitializationContext::Empty => { + InitializationContext::WithSwapChain(swap_chain.clone()) + }, + s => s, + } + } + + // Transition to a complete state if the swap chain is set and the command queue + // is associated with it. + fn insert_command_queue(&mut self, command_queue: &ID3D12CommandQueue) { + *self = match mem::replace(self, InitializationContext::Empty) { + InitializationContext::WithSwapChain(swap_chain) => { + if unsafe { Self::check_command_queue(&swap_chain, command_queue) } { + trace!( + "Found command queue matching swap chain {swap_chain:?} at \ + {command_queue:?}" + ); + InitializationContext::Complete(swap_chain, command_queue.clone()) + } else { + InitializationContext::WithSwapChain(swap_chain) + } + }, + s => s, + } + } + + // Retrieve the values if the context is complete. + fn get(&self) -> Option<(IDXGISwapChain3, ID3D12CommandQueue)> { + if let InitializationContext::Complete(swap_chain, command_queue) = self { + Some((swap_chain.clone(), command_queue.clone())) + } else { + None + } + } + + // Mark the context as done so no further operations are executed on it. + fn done(&mut self) { + if let InitializationContext::Complete(..) 
= self { + *self = InitializationContext::Done; + } + } + + unsafe fn check_command_queue( + swap_chain: &IDXGISwapChain3, + command_queue: &ID3D12CommandQueue, + ) -> bool { + let swap_chain_ptr = swap_chain.as_raw() as *mut *mut c_void; + let readable_ptrs = util::readable_region(swap_chain_ptr, 512); + + match readable_ptrs.iter().position(|&ptr| ptr == command_queue.as_raw()) { + Some(idx) => { + debug!( + "Found command queue pointer in swap chain struct at offset +0x{:x}", + idx * mem::size_of::(), + ); + true + }, + None => { + warn!( + "Couldn't find command queue pointer in swap chain struct ({} out of 512 \ + pointers were readable)", + readable_ptrs.len() + ); + false + }, + } + } +} + +static INITIALIZATION_CONTEXT: Mutex = + Mutex::new(InitializationContext::Empty); static mut PIPELINE: OnceCell>> = OnceCell::new(); -static mut COMMAND_QUEUE: OnceCell = OnceCell::new(); static mut RENDER_LOOP: OnceCell> = OnceCell::new(); -unsafe fn init_pipeline( - swap_chain: &IDXGISwapChain3, -) -> Result>> { - let Some(command_queue) = COMMAND_QUEUE.get() else { - error!("Command queue not yet initialized"); +unsafe fn init_pipeline() -> Result>> { + let Some((swap_chain, command_queue)) = ({ INITIALIZATION_CONTEXT.lock().get() }) else { + error!("Initialization context incomplete"); return Err(Error::from_hresult(HRESULT(-1))); }; let hwnd = util::try_out_param(|v| swap_chain.GetDesc(v)).map(|desc| desc.OutputWindow)?; let mut ctx = Context::create(); - let engine = D3D12RenderEngine::new(command_queue, &mut ctx)?; + let engine = D3D12RenderEngine::new(&command_queue, &mut ctx)?; let Some(render_loop) = RENDER_LOOP.take() else { error!("Render loop not yet initialized"); @@ -83,12 +162,16 @@ unsafe fn init_pipeline( e })?; + { + INITIALIZATION_CONTEXT.lock().done(); + } + Ok(Mutex::new(pipeline)) } fn render(swap_chain: &IDXGISwapChain3) -> Result<()> { unsafe { - let pipeline = PIPELINE.get_or_try_init(|| init_pipeline(swap_chain))?; + let pipeline = 
PIPELINE.get_or_try_init(|| init_pipeline())?; let Some(mut pipeline) = pipeline.try_lock() else { error!("Could not lock pipeline"); @@ -111,6 +194,10 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl( sync_interval: u32, flags: u32, ) -> HRESULT { + { + INITIALIZATION_CONTEXT.lock().insert_swap_chain(&swap_chain); + } + let Trampolines { dxgi_swap_chain_present, .. } = TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized"); @@ -148,20 +235,13 @@ unsafe extern "system" fn d3d12_command_queue_execute_command_lists_impl( {command_lists:p}) invoked", ); + { + INITIALIZATION_CONTEXT.lock().insert_command_queue(&command_queue); + } + let Trampolines { d3d12_command_queue_execute_command_lists, .. } = TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized"); - COMMAND_QUEUE - .get_or_try_init(|| unsafe { - let desc = command_queue.GetDesc(); - if desc.Type == D3D12_COMMAND_LIST_TYPE_DIRECT { - Ok(command_queue.clone()) - } else { - Err(()) - } - }) - .ok(); - d3d12_command_queue_execute_command_lists(command_queue, num_command_lists, command_lists); } @@ -309,7 +389,7 @@ impl Hooks for ImguiDx12Hooks { unsafe fn unhook(&mut self) { TRAMPOLINES.take(); PIPELINE.take().map(|p| p.into_inner().take()); - COMMAND_QUEUE.take(); RENDER_LOOP.take(); // should already be null + *INITIALIZATION_CONTEXT.lock() = InitializationContext::Empty; } } diff --git a/src/util.rs b/src/util.rs index 8c023e13..b5313a81 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,7 @@ //! General-purpose utilities. These are used across the [`crate`] but have //! proven useful in client code as well. 
-use std::ffi::OsString; +use std::ffi::{c_void, OsString}; use std::fmt::Display; use std::mem::ManuallyDrop; use std::os::windows::ffi::OsStringExt; @@ -26,6 +26,11 @@ use windows::Win32::System::LibraryLoader::{ GetModuleFileNameW, GetModuleHandleExA, GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, }; +use windows::Win32::System::Memory::{ + VirtualQuery, MEMORY_BASIC_INFORMATION, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, + PAGE_PROTECTION_FLAGS, PAGE_READONLY, PAGE_READWRITE, +}; +use windows::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO}; use windows::Win32::System::Threading::{CreateEventExW, WaitForSingleObjectEx, CREATE_EVENT}; use windows::Win32::UI::WindowsAndMessaging::GetClientRect; @@ -280,3 +285,110 @@ impl Fence { Ok(()) } } + +/// Returns a slice of **up to** `limit` elements of type `T` starting at `ptr`. +/// +/// If the memory protection of some pages in this region prevents reading from +/// it, the slice is truncated to the first `N` consecutive readable elements. +/// +/// # Safety +/// +/// - `ptr` must not be a null pointer and must be properly aligned. +/// - Ignoring memory protection, the memory at `ptr` must be valid for at least +/// `limit` elements of type `T` (see [`std::slice::from_raw_parts`]). +pub unsafe fn readable_region(ptr: *const T, limit: usize) -> &'static [T] { + /// Check if the page pointed to by `ptr` is readable. 
+ unsafe fn is_readable( + ptr: *const c_void, + memory_basic_info: &mut MEMORY_BASIC_INFORMATION, + ) -> bool { + // If the page protection has any of these flags set, we can read from it + const PAGE_READABLE: PAGE_PROTECTION_FLAGS = PAGE_PROTECTION_FLAGS( + PAGE_READONLY.0 | PAGE_READWRITE.0 | PAGE_EXECUTE_READ.0 | PAGE_EXECUTE_READWRITE.0, + ); + + (unsafe { + VirtualQuery(Some(ptr), memory_basic_info, size_of::()) + } != 0) + && (memory_basic_info.Protect & PAGE_READABLE).0 != 0 + } + + // This is probably 0x1000 (4096) bytes + let page_size_bytes = { + let mut system_info = SYSTEM_INFO::default(); + unsafe { GetSystemInfo(&mut system_info) }; + system_info.dwPageSize as usize + }; + let page_align_mask = page_size_bytes - 1; + + // Calculate the starting address of the first and last pages that need to be + // readable in order to read `limit` elements of type `T` from `ptr` + let first_page_addr = (ptr as usize) & !page_align_mask; + let last_page_addr = (ptr as usize + (limit * size_of::()) - 1) & !page_align_mask; + + let mut memory_basic_info = MEMORY_BASIC_INFORMATION::default(); + for page_addr in (first_page_addr..=last_page_addr).step_by(page_size_bytes) { + if unsafe { is_readable(page_addr as _, &mut memory_basic_info) } { + continue; + } + + // If this page is not readable, we can read from `ptr` + // up to (not including) the start of this page + // + // Note: `page_addr` can be less than `ptr` if `ptr` is not page-aligned + let num_readable = page_addr.saturating_sub(ptr as usize) / size_of::(); + + // SAFETY: + // - `ptr` is a valid pointer to `limit` elements of type `T` + // - `num_readable` is always less than or equal to `limit` + return std::slice::from_raw_parts(ptr, num_readable); + } + + // SAFETY: + // - `ptr` is a valid pointer to `limit` elements of type `T` and is properly + // aligned + std::slice::from_raw_parts(ptr, limit) +} + +#[cfg(test)] +mod tests { + use windows::Win32::System::Memory::{VirtualAlloc, VirtualProtect, 
MEM_COMMIT, PAGE_NOACCESS}; + + use super::*; + + #[test] + fn test_readable_region() -> windows::core::Result<()> { + const PAGE_SIZE: usize = 0x1000; + + let region = unsafe { VirtualAlloc(None, 2 * PAGE_SIZE, MEM_COMMIT, PAGE_READWRITE) }; + if region.is_null() { + return Err(windows::core::Error::from_win32()); + } + + // Make the second page unreadable + let mut old_protect = PAGE_PROTECTION_FLAGS::default(); + unsafe { + VirtualProtect( + (region as usize + PAGE_SIZE) as _, + PAGE_SIZE, + PAGE_NOACCESS, + &mut old_protect, + ) + }?; + assert_eq!(old_protect, PAGE_READWRITE); + + let slice = unsafe { readable_region::(region as _, PAGE_SIZE) }; + assert_eq!(slice.len(), PAGE_SIZE); + + let slice = unsafe { readable_region::(region as _, PAGE_SIZE + 1) }; + assert_eq!(slice.len(), PAGE_SIZE); + + let slice = unsafe { readable_region::((region as usize + PAGE_SIZE) as _, 1) }; + assert!(slice.is_empty()); + + let slice = unsafe { readable_region::((region as usize + PAGE_SIZE - 1) as _, 2) }; + assert_eq!(slice.len(), 1); + + Ok(()) + } +}