Skip to content

Commit

Permalink
Choose command queue associated to swap chain in DirectX 12 hook (#209)
Browse files Browse the repository at this point in the history
Create a builder-like data structure that has to have a IDXGISwapChain
set first, and a ID3D12CommandQueue set later, provided it is associated
with the previously-set swap chain. The heuristic used is to check
the first 512 pointers in the memory pointed to by the swap chain,
dereference each of them, and see if any of them matches the command queue
that gets passed in. In practice, the command queue is among the first
few pointers. A bit of extra care is given in checking for memory
readability with VirtualQuery before dereferencing. There is some
overhead but it is probably negligible, and only happens the first time
the methods are hooked, so it's way less than a frame in total.

Some concern might be raised by the fact that now we need to lock a
mutex twice per frame, but practically that will happen almost always
on the same thread and parking_lot has good performance for that case.
Besides, we already do that successfully with the Pipeline.

---------

Co-authored-by: Rico <[email protected]>
  • Loading branch information
veeenu and cryeprecision authored Nov 6, 2024
1 parent 216c680 commit 08246f5
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 28 deletions.
3 changes: 3 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[alias]
c = "xwin clippy --target x86_64-pc-windows-msvc --all"
t = "xwin test --target x86_64-pc-windows-msvc"
6 changes: 0 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
Cargo.lock
/target
**/*.rs.bk
/reference_code
/hudhook.log
/lib/test_sample/*.exe
/lib/test_sample/*.obj
/tests/test_sample.exe

.idea
.cargo
vkd3d-proton.cache*
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ features = [
"Win32_Graphics_Gdi",
"Win32_Graphics_OpenGL",
"Win32_Security",
"Win32_System_Com",
"Win32_System_Console",
"Win32_System_Diagnostics_Debug",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_UI_Input_KeyboardAndMouse",
Expand Down
122 changes: 101 additions & 21 deletions src/hooks/dx12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::OnceLock;
use imgui::Context;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use tracing::{error, trace};
use tracing::{debug, error, trace, warn};
use windows::core::{Error, Interface, Result, HRESULT};
use windows::Win32::Foundation::BOOL;
use windows::Win32::Graphics::Direct3D::D3D_FEATURE_LEVEL_11_0;
Expand Down Expand Up @@ -56,22 +56,101 @@ struct Trampolines {

static mut TRAMPOLINES: OnceLock<Trampolines> = OnceLock::new();

enum InitializationContext {
Empty,
WithSwapChain(IDXGISwapChain3),
Complete(IDXGISwapChain3, ID3D12CommandQueue),
Done,
}

impl InitializationContext {
// Transition to a state where the swap chain is set. Ignore other mutations.
fn insert_swap_chain(&mut self, swap_chain: &IDXGISwapChain3) {
*self = match mem::replace(self, InitializationContext::Empty) {
InitializationContext::Empty => {
InitializationContext::WithSwapChain(swap_chain.clone())
},
s => s,
}
}

// Transition to a complete state if the swap chain is set and the command queue
// is associated with it.
fn insert_command_queue(&mut self, command_queue: &ID3D12CommandQueue) {
*self = match mem::replace(self, InitializationContext::Empty) {
InitializationContext::WithSwapChain(swap_chain) => {
if unsafe { Self::check_command_queue(&swap_chain, command_queue) } {
trace!(
"Found command queue matching swap chain {swap_chain:?} at \
{command_queue:?}"
);
InitializationContext::Complete(swap_chain, command_queue.clone())
} else {
InitializationContext::WithSwapChain(swap_chain)
}
},
s => s,
}
}

// Retrieve the values if the context is complete.
fn get(&self) -> Option<(IDXGISwapChain3, ID3D12CommandQueue)> {
if let InitializationContext::Complete(swap_chain, command_queue) = self {
Some((swap_chain.clone(), command_queue.clone()))
} else {
None
}
}

// Mark the context as done so no further operations are executed on it.
fn done(&mut self) {
if let InitializationContext::Complete(..) = self {
*self = InitializationContext::Done;
}
}

unsafe fn check_command_queue(
swap_chain: &IDXGISwapChain3,
command_queue: &ID3D12CommandQueue,
) -> bool {
let swap_chain_ptr = swap_chain.as_raw() as *mut *mut c_void;
let readable_ptrs = util::readable_region(swap_chain_ptr, 512);

match readable_ptrs.iter().position(|&ptr| ptr == command_queue.as_raw()) {
Some(idx) => {
debug!(
"Found command queue pointer in swap chain struct at offset +0x{:x}",
idx * mem::size_of::<usize>(),
);
true
},
None => {
warn!(
"Couldn't find command queue pointer in swap chain struct ({} out of 512 \
pointers were readable)",
readable_ptrs.len()
);
false
},
}
}
}

static INITIALIZATION_CONTEXT: Mutex<InitializationContext> =
Mutex::new(InitializationContext::Empty);
static mut PIPELINE: OnceCell<Mutex<Pipeline<D3D12RenderEngine>>> = OnceCell::new();
static mut COMMAND_QUEUE: OnceCell<ID3D12CommandQueue> = OnceCell::new();
static mut RENDER_LOOP: OnceCell<Box<dyn ImguiRenderLoop + Send + Sync>> = OnceCell::new();

unsafe fn init_pipeline(
swap_chain: &IDXGISwapChain3,
) -> Result<Mutex<Pipeline<D3D12RenderEngine>>> {
let Some(command_queue) = COMMAND_QUEUE.get() else {
error!("Command queue not yet initialized");
unsafe fn init_pipeline() -> Result<Mutex<Pipeline<D3D12RenderEngine>>> {
let Some((swap_chain, command_queue)) = ({ INITIALIZATION_CONTEXT.lock().get() }) else {
error!("Initialization context incomplete");
return Err(Error::from_hresult(HRESULT(-1)));
};

let hwnd = util::try_out_param(|v| swap_chain.GetDesc(v)).map(|desc| desc.OutputWindow)?;

let mut ctx = Context::create();
let engine = D3D12RenderEngine::new(command_queue, &mut ctx)?;
let engine = D3D12RenderEngine::new(&command_queue, &mut ctx)?;

let Some(render_loop) = RENDER_LOOP.take() else {
error!("Render loop not yet initialized");
Expand All @@ -83,12 +162,16 @@ unsafe fn init_pipeline(
e
})?;

{
INITIALIZATION_CONTEXT.lock().done();
}

Ok(Mutex::new(pipeline))
}

fn render(swap_chain: &IDXGISwapChain3) -> Result<()> {
unsafe {
let pipeline = PIPELINE.get_or_try_init(|| init_pipeline(swap_chain))?;
let pipeline = PIPELINE.get_or_try_init(|| init_pipeline())?;

let Some(mut pipeline) = pipeline.try_lock() else {
error!("Could not lock pipeline");
Expand All @@ -111,6 +194,10 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl(
sync_interval: u32,
flags: u32,
) -> HRESULT {
{
INITIALIZATION_CONTEXT.lock().insert_swap_chain(&swap_chain);
}

let Trampolines { dxgi_swap_chain_present, .. } =
TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized");

Expand Down Expand Up @@ -148,20 +235,13 @@ unsafe extern "system" fn d3d12_command_queue_execute_command_lists_impl(
{command_lists:p}) invoked",
);

{
INITIALIZATION_CONTEXT.lock().insert_command_queue(&command_queue);
}

let Trampolines { d3d12_command_queue_execute_command_lists, .. } =
TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized");

COMMAND_QUEUE
.get_or_try_init(|| unsafe {
let desc = command_queue.GetDesc();
if desc.Type == D3D12_COMMAND_LIST_TYPE_DIRECT {
Ok(command_queue.clone())
} else {
Err(())
}
})
.ok();

d3d12_command_queue_execute_command_lists(command_queue, num_command_lists, command_lists);
}

Expand Down Expand Up @@ -309,7 +389,7 @@ impl Hooks for ImguiDx12Hooks {
unsafe fn unhook(&mut self) {
TRAMPOLINES.take();
PIPELINE.take().map(|p| p.into_inner().take());
COMMAND_QUEUE.take();
RENDER_LOOP.take(); // should already be null
*INITIALIZATION_CONTEXT.lock() = InitializationContext::Empty;
}
}
114 changes: 113 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! General-purpose utilities. These are used across the [`crate`] but have
//! proven useful in client code as well.

use std::ffi::OsString;
use std::ffi::{c_void, OsString};
use std::fmt::Display;
use std::mem::ManuallyDrop;
use std::os::windows::ffi::OsStringExt;
Expand All @@ -26,6 +26,11 @@ use windows::Win32::System::LibraryLoader::{
GetModuleFileNameW, GetModuleHandleExA, GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
};
use windows::Win32::System::Memory::{
VirtualQuery, MEMORY_BASIC_INFORMATION, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE,
PAGE_PROTECTION_FLAGS, PAGE_READONLY, PAGE_READWRITE,
};
use windows::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO};
use windows::Win32::System::Threading::{CreateEventExW, WaitForSingleObjectEx, CREATE_EVENT};
use windows::Win32::UI::WindowsAndMessaging::GetClientRect;

Expand Down Expand Up @@ -280,3 +285,110 @@ impl Fence {
Ok(())
}
}

/// Returns a slice of **up to** `limit` elements of type `T` starting at `ptr`.
///
/// If the memory protection of some pages in this region prevents reading from
/// it, the slice is truncated to the first `N` consecutive readable elements.
///
/// # Safety
///
/// - `ptr` must not be a null pointer and must be properly aligned.
/// - Ignoring memory protection, the memory at `ptr` must be valid for at least
/// `limit` elements of type `T` (see [`std::slice::from_raw_parts`]).
pub unsafe fn readable_region<T>(ptr: *const T, limit: usize) -> &'static [T] {
/// Check if the page pointed to by `ptr` is readable.
unsafe fn is_readable(
ptr: *const c_void,
memory_basic_info: &mut MEMORY_BASIC_INFORMATION,
) -> bool {
// If the page protection has any of these flags set, we can read from it
const PAGE_READABLE: PAGE_PROTECTION_FLAGS = PAGE_PROTECTION_FLAGS(
PAGE_READONLY.0 | PAGE_READWRITE.0 | PAGE_EXECUTE_READ.0 | PAGE_EXECUTE_READWRITE.0,
);

(unsafe {
VirtualQuery(Some(ptr), memory_basic_info, size_of::<MEMORY_BASIC_INFORMATION>())
} != 0)
&& (memory_basic_info.Protect & PAGE_READABLE).0 != 0
}

// This is probably 0x1000 (4096) bytes
let page_size_bytes = {
let mut system_info = SYSTEM_INFO::default();
unsafe { GetSystemInfo(&mut system_info) };
system_info.dwPageSize as usize
};
let page_align_mask = page_size_bytes - 1;

// Calculate the starting address of the first and last pages that need to be
// readable in order to read `limit` elements of type `T` from `ptr`
let first_page_addr = (ptr as usize) & !page_align_mask;
let last_page_addr = (ptr as usize + (limit * size_of::<T>()) - 1) & !page_align_mask;

let mut memory_basic_info = MEMORY_BASIC_INFORMATION::default();
for page_addr in (first_page_addr..=last_page_addr).step_by(page_size_bytes) {
if unsafe { is_readable(page_addr as _, &mut memory_basic_info) } {
continue;
}

// If this page is not readable, we can read from `ptr`
// up to (not including) the start of this page
//
// Note: `page_addr` can be less than `ptr` if `ptr` is not page-aligned
let num_readable = page_addr.saturating_sub(ptr as usize) / size_of::<T>();

// SAFETY:
// - `ptr` is a valid pointer to `limit` elements of type `T`
// - `num_readable` is always less than or equal to `limit`
return std::slice::from_raw_parts(ptr, num_readable);
}

// SAFETY:
// - `ptr` is a valid pointer to `limit` elements of type `T` and is properly
// aligned
std::slice::from_raw_parts(ptr, limit)
}

#[cfg(test)]
mod tests {
use windows::Win32::System::Memory::{VirtualAlloc, VirtualProtect, MEM_COMMIT, PAGE_NOACCESS};

use super::*;

#[test]
fn test_readable_region() -> windows::core::Result<()> {
const PAGE_SIZE: usize = 0x1000;

let region = unsafe { VirtualAlloc(None, 2 * PAGE_SIZE, MEM_COMMIT, PAGE_READWRITE) };
if region.is_null() {
return Err(windows::core::Error::from_win32());
}

// Make the second page unreadable
let mut old_protect = PAGE_PROTECTION_FLAGS::default();
unsafe {
VirtualProtect(
(region as usize + PAGE_SIZE) as _,
PAGE_SIZE,
PAGE_NOACCESS,
&mut old_protect,
)
}?;
assert_eq!(old_protect, PAGE_READWRITE);

let slice = unsafe { readable_region::<u8>(region as _, PAGE_SIZE) };
assert_eq!(slice.len(), PAGE_SIZE);

let slice = unsafe { readable_region::<u8>(region as _, PAGE_SIZE + 1) };
assert_eq!(slice.len(), PAGE_SIZE);

let slice = unsafe { readable_region::<u8>((region as usize + PAGE_SIZE) as _, 1) };
assert!(slice.is_empty());

let slice = unsafe { readable_region::<u8>((region as usize + PAGE_SIZE - 1) as _, 2) };
assert_eq!(slice.len(), 1);

Ok(())
}
}

0 comments on commit 08246f5

Please sign in to comment.