Skip to content

Commit

Permalink
Partial initialization context
Browse files Browse the repository at this point in the history
Create a builder-like data structure that has to have a IDXGISwapChain
set first, and a ID3D12CommandQueue set later, provided it is associated
with the previously-set swap chain. The heuristic used is to check
the first 512 pointers in the memory pointed to by the swap chain,
dereference each of them, and see if any of them matches the command queue
that gets passed in. In practice, the command queue is among the first
few pointers. A bit of extra care is given in checking for memory
readability with VirtualQuery before dereferencing. There is some
overhead but it is probably negligible, and only happens the first time
the methods are hooked, so it's way less than a frame in total.

Some concern might be raised by the fact that now we need to lock a
mutex twice per frame, but practically that will happen almost always
on the same thread and parking_lot has good performance for that case.
Besides, we already do that successfully with the Pipeline.
  • Loading branch information
veeenu committed Oct 30, 2024
1 parent 9d33710 commit 55700fd
Showing 1 changed file with 111 additions and 38 deletions.
149 changes: 111 additions & 38 deletions src/hooks/dx12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@

use std::ffi::c_void;
use std::mem;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

use imgui::Context;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use tracing::{error, info, trace};
use windows::core::IUnknown;
use tracing::{error, trace};
use windows::core::{Error, Interface, Result, HRESULT};
use windows::Win32::Foundation::{BOOL, E_NOINTERFACE};
use windows::Win32::Foundation::BOOL;
use windows::Win32::Graphics::Direct3D::D3D_FEATURE_LEVEL_11_0;
use windows::Win32::Graphics::Direct3D12::{
D3D12CreateDevice, ID3D12CommandList, ID3D12CommandQueue, ID3D12Device, ID3D12Resource,
Expand All @@ -27,7 +24,7 @@ use windows::Win32::Graphics::Dxgi::{
DXGI_SWAP_CHAIN_FLAG_ALLOW_MODE_SWITCH, DXGI_SWAP_EFFECT_FLIP_DISCARD,
DXGI_USAGE_RENDER_TARGET_OUTPUT,
};
use windows::Win32::System::Memory::IsBadReadPtr;
use windows::Win32::System::Memory::{VirtualQuery, MEMORY_BASIC_INFORMATION, PAGE_READWRITE};

use super::DummyHwnd;
use crate::mh::MhHook;
Expand Down Expand Up @@ -60,22 +57,103 @@ struct Trampolines {

static mut TRAMPOLINES: OnceLock<Trampolines> = OnceLock::new();

enum InitializationContext {
Empty,
WithSwapChain(IDXGISwapChain3),
Complete(IDXGISwapChain3, ID3D12CommandQueue),
Done,
}

impl InitializationContext {
// Transition to a state where the swap chain is set. Ignore other mutations.
fn insert_swap_chain(&mut self, swap_chain: &IDXGISwapChain3) {
*self = match mem::replace(self, InitializationContext::Empty) {
InitializationContext::Empty => {
InitializationContext::WithSwapChain(swap_chain.clone())
},
s => s,
}
}

// Transition to a complete state if the swap chain is set and the command queue
// is associated with it.
fn insert_command_queue(&mut self, command_queue: &ID3D12CommandQueue) {
*self = match mem::replace(self, InitializationContext::Empty) {
InitializationContext::WithSwapChain(swap_chain) => {
if unsafe { Self::check_command_queue(&swap_chain, command_queue) } {
trace!(
"Found command queue matching swap chain {swap_chain:?} at \
{command_queue:?}"
);
InitializationContext::Complete(swap_chain, command_queue.clone())
} else {
InitializationContext::WithSwapChain(swap_chain)
}
},
s => s,
}
}

// Retrieve the values if the context is complete.
fn get(&self) -> Option<(IDXGISwapChain3, ID3D12CommandQueue)> {
if let InitializationContext::Complete(swap_chain, command_queue) = self {
Some((swap_chain.clone(), command_queue.clone()))
} else {
None
}
}

// Mark the context as done so no further operations are executed on it.
fn done(&mut self) {
if let InitializationContext::Complete(..) = self {
*self = InitializationContext::Done;
}
}

unsafe fn check_command_queue(
swap_chain: &IDXGISwapChain3,
command_queue: &ID3D12CommandQueue,
) -> bool {
let swap_chain_ptr = swap_chain.as_raw() as *mut *mut c_void;
let mut mbi = MEMORY_BASIC_INFORMATION::default();

for i in 0..512 {
let command_queue_ptr = swap_chain_ptr.add(i);
if VirtualQuery(Some(command_queue_ptr as *const c_void), &mut mbi, size_of_val(&mbi))
== 0
{
continue;
}

trace!(
"Offset {i} pointer {command_queue_ptr:p} against {:p} has protect {:?}",
command_queue.as_raw(),
mbi.Protect
);

if !mbi.Protect.contains(PAGE_READWRITE) {
continue;
}

if *command_queue_ptr == command_queue.as_raw() {
return true;
}
}

false
}
}

static INITIALIZATION_CONTEXT: Mutex<InitializationContext> =
Mutex::new(InitializationContext::Empty);
static mut PIPELINE: OnceCell<Mutex<Pipeline<D3D12RenderEngine>>> = OnceCell::new();
static mut RENDER_LOOP: OnceCell<Box<dyn ImguiRenderLoop + Send + Sync>> = OnceCell::new();
static COMMAND_QUEUE_OFFSET: AtomicUsize = AtomicUsize::new(0);

unsafe fn init_pipeline(
swap_chain: &IDXGISwapChain3,
) -> Result<Mutex<Pipeline<D3D12RenderEngine>>> {
let command_queue_offset = COMMAND_QUEUE_OFFSET.load(Ordering::SeqCst);
if command_queue_offset == 0 {
error!("Could not find command queue offset into IDXGISwapChain3");
return Err(Error::from_hresult(HRESULT(-1)));
}

let command_queue = ID3D12CommandQueue::from_raw(
*(swap_chain.as_raw() as *mut *mut c_void).add(command_queue_offset),
);
unsafe fn init_pipeline() -> Result<Mutex<Pipeline<D3D12RenderEngine>>> {
let Some((swap_chain, command_queue)) = ({ INITIALIZATION_CONTEXT.lock().get() }) else {
error!("Initialization context incomplete");
return Err(Error::from_hresult(HRESULT(-1)));
};

let hwnd = util::try_out_param(|v| swap_chain.GetDesc(v)).map(|desc| desc.OutputWindow)?;

Expand All @@ -92,12 +170,16 @@ unsafe fn init_pipeline(
e
})?;

{
INITIALIZATION_CONTEXT.lock().done();
}

Ok(Mutex::new(pipeline))
}

fn render(swap_chain: &IDXGISwapChain3) -> Result<()> {
unsafe {
let pipeline = PIPELINE.get_or_try_init(|| init_pipeline(swap_chain))?;
let pipeline = PIPELINE.get_or_try_init(|| init_pipeline())?;

let Some(mut pipeline) = pipeline.try_lock() else {
error!("Could not lock pipeline");
Expand All @@ -120,6 +202,10 @@ unsafe extern "system" fn dxgi_swap_chain_present_impl(
sync_interval: u32,
flags: u32,
) -> HRESULT {
{
INITIALIZATION_CONTEXT.lock().insert_swap_chain(&swap_chain);
}

let Trampolines { dxgi_swap_chain_present, .. } =
TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized");

Expand Down Expand Up @@ -157,6 +243,10 @@ unsafe extern "system" fn d3d12_command_queue_execute_command_lists_impl(
{command_lists:p}) invoked",
);

{
INITIALIZATION_CONTEXT.lock().insert_command_queue(&command_queue);
}

let Trampolines { d3d12_command_queue_execute_command_lists, .. } =
TRAMPOLINES.get().expect("DirectX 12 trampolines uninitialized");

Expand Down Expand Up @@ -219,24 +309,6 @@ fn get_target_addrs() -> (
},
};

let swap_chain_ptr = swap_chain.as_raw() as *mut *mut c_void;
let command_queue_ptr = command_queue.as_raw();
let command_queue_offset = (0..512).find(|&i| unsafe {
let ptr = swap_chain_ptr.add(i as usize);
trace!("Trying command queue offset {ptr:p} as {:p} == {command_queue_ptr:p}", *ptr);
if command_queue_ptr == *ptr {
trace!("Found command queue ptr at offset {i}");
true
} else {
false
}
});

match command_queue_offset {
Some(offset) => COMMAND_QUEUE_OFFSET.store(offset, Ordering::SeqCst),
None => panic!("Could not find command queue offset in IDXGISwapChain3"),
}

let present_ptr: DXGISwapChainPresentType =
unsafe { mem::transmute(swap_chain.vtable().Present) };
let resize_buffers_ptr: DXGISwapChainResizeBuffersType =
Expand Down Expand Up @@ -326,5 +398,6 @@ impl Hooks for ImguiDx12Hooks {
TRAMPOLINES.take();
PIPELINE.take().map(|p| p.into_inner().take());
RENDER_LOOP.take(); // should already be null
*INITIALIZATION_CONTEXT.lock() = InitializationContext::Empty;
}
}

0 comments on commit 55700fd

Please sign in to comment.