Skip to content

Commit

Permalink
Check memory protection per page (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
cryeprecision authored Nov 4, 2024
1 parent 55700fd commit 2871d5e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ features = [
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_UI_Input_KeyboardAndMouse",
Expand Down
46 changes: 19 additions & 27 deletions src/hooks/dx12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::OnceLock;
use imgui::Context;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use tracing::{error, trace};
use tracing::{debug, error, trace, warn};
use windows::core::{Error, Interface, Result, HRESULT};
use windows::Win32::Foundation::BOOL;
use windows::Win32::Graphics::Direct3D::D3D_FEATURE_LEVEL_11_0;
Expand All @@ -24,7 +24,6 @@ use windows::Win32::Graphics::Dxgi::{
DXGI_SWAP_CHAIN_FLAG_ALLOW_MODE_SWITCH, DXGI_SWAP_EFFECT_FLIP_DISCARD,
DXGI_USAGE_RENDER_TARGET_OUTPUT,
};
use windows::Win32::System::Memory::{VirtualQuery, MEMORY_BASIC_INFORMATION, PAGE_READWRITE};

use super::DummyHwnd;
use crate::mh::MhHook;
Expand Down Expand Up @@ -115,32 +114,25 @@ impl InitializationContext {
command_queue: &ID3D12CommandQueue,
) -> bool {
let swap_chain_ptr = swap_chain.as_raw() as *mut *mut c_void;
let mut mbi = MEMORY_BASIC_INFORMATION::default();

for i in 0..512 {
let command_queue_ptr = swap_chain_ptr.add(i);
if VirtualQuery(Some(command_queue_ptr as *const c_void), &mut mbi, size_of_val(&mbi))
== 0
{
continue;
}

trace!(
"Offset {i} pointer {command_queue_ptr:p} against {:p} has protect {:?}",
command_queue.as_raw(),
mbi.Protect
);

if !mbi.Protect.contains(PAGE_READWRITE) {
continue;
}

if *command_queue_ptr == command_queue.as_raw() {
return true;
}
let readable_ptrs = util::readable_region(swap_chain_ptr, 512);

match readable_ptrs.iter().position(|&ptr| ptr == command_queue.as_raw()) {
Some(idx) => {
debug!(
"Found command queue pointer in swap chain struct at offset +0x{:x}",
idx * mem::size_of::<usize>(),
);
true
},
None => {
warn!(
"Couldn't find command queue pointer in swap chain struct ({} out of 512 \
pointers were readable)",
readable_ptrs.len()
);
false
},
}

false
}
}

Expand Down
114 changes: 113 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! General-purpose utilities. These are used across the [`crate`] but have
//! proven useful in client code as well.

use std::ffi::OsString;
use std::ffi::{c_void, OsString};
use std::fmt::Display;
use std::mem::ManuallyDrop;
use std::os::windows::ffi::OsStringExt;
Expand All @@ -26,6 +26,11 @@ use windows::Win32::System::LibraryLoader::{
GetModuleFileNameW, GetModuleHandleExA, GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
};
use windows::Win32::System::Memory::{
VirtualQuery, MEMORY_BASIC_INFORMATION, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE,
PAGE_PROTECTION_FLAGS, PAGE_READONLY, PAGE_READWRITE,
};
use windows::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO};
use windows::Win32::System::Threading::{CreateEventExW, WaitForSingleObjectEx, CREATE_EVENT};
use windows::Win32::UI::WindowsAndMessaging::GetClientRect;

Expand Down Expand Up @@ -280,3 +285,110 @@ impl Fence {
Ok(())
}
}

/// Returns a slice of **up to** `limit` elements of type `T` starting at `ptr`.
///
/// If the memory protection of some pages in this region prevents reading from
/// it, the slice is truncated to the first `N` consecutive readable elements.
///
/// # Safety
///
/// - `ptr` must not be a null pointer and must be properly aligned.
/// - Ignoring memory protection, the memory at `ptr` must be valid for at least
/// `limit` elements of type `T` (see [`std::slice::from_raw_parts`]).
pub unsafe fn readable_region<T>(ptr: *const T, limit: usize) -> &'static [T] {
/// Check if the page pointed to by `ptr` is readable.
unsafe fn is_readable(
ptr: *const c_void,
memory_basic_info: &mut MEMORY_BASIC_INFORMATION,
) -> bool {
// If the page protection has any of these flags set, we can read from it
const PAGE_READABLE: PAGE_PROTECTION_FLAGS = PAGE_PROTECTION_FLAGS(
PAGE_READONLY.0 | PAGE_READWRITE.0 | PAGE_EXECUTE_READ.0 | PAGE_EXECUTE_READWRITE.0,
);

(unsafe {
VirtualQuery(Some(ptr), memory_basic_info, size_of::<MEMORY_BASIC_INFORMATION>())
} != 0)
&& (memory_basic_info.Protect & PAGE_READABLE).0 != 0
}

// This is probably 0x1000 (4096) bytes
let page_size_bytes = {
let mut system_info = SYSTEM_INFO::default();
unsafe { GetSystemInfo(&mut system_info) };
system_info.dwPageSize as usize
};
let page_align_mask = page_size_bytes - 1;

// Calculate the starting address of the first and last pages that need to be
// readable in order to read `limit` elements of type `T` from `ptr`
let first_page_addr = (ptr as usize) & !page_align_mask;
let last_page_addr = (ptr as usize + (limit * size_of::<T>()) - 1) & !page_align_mask;

let mut memory_basic_info = MEMORY_BASIC_INFORMATION::default();
for page_addr in (first_page_addr..=last_page_addr).step_by(page_size_bytes) {
if unsafe { is_readable(page_addr as _, &mut memory_basic_info) } {
continue;
}

// If this page is not readable, we can read from `ptr`
// up to (not including) the start of this page
//
// Note: `page_addr` can be less than `ptr` if `ptr` is not page-aligned
let num_readable = page_addr.saturating_sub(ptr as usize) / size_of::<T>();

// SAFETY:
// - `ptr` is a valid pointer to `limit` elements of type `T`
// - `num_readable` is always less than or equal to `limit`
return std::slice::from_raw_parts(ptr, num_readable);
}

// SAFETY:
// - `ptr` is a valid pointer to `limit` elements of type `T` and is properly
// aligned
std::slice::from_raw_parts(ptr, limit)
}

#[cfg(test)]
mod tests {
use windows::Win32::System::Memory::{VirtualAlloc, VirtualProtect, MEM_COMMIT, PAGE_NOACCESS};

use super::*;

#[test]
fn test_readable_region() -> windows::core::Result<()> {
const PAGE_SIZE: usize = 0x1000;

let region = unsafe { VirtualAlloc(None, 2 * PAGE_SIZE, MEM_COMMIT, PAGE_READWRITE) };
if region.is_null() {
return Err(windows::core::Error::from_win32());
}

// Make the second page unreadable
let mut old_protect = PAGE_PROTECTION_FLAGS::default();
unsafe {
VirtualProtect(
(region as usize + PAGE_SIZE) as _,
PAGE_SIZE,
PAGE_NOACCESS,
&mut old_protect,
)
}?;
assert_eq!(old_protect, PAGE_READWRITE);

let slice = unsafe { readable_region::<u8>(region as _, PAGE_SIZE) };
assert_eq!(slice.len(), PAGE_SIZE);

let slice = unsafe { readable_region::<u8>(region as _, PAGE_SIZE + 1) };
assert_eq!(slice.len(), PAGE_SIZE);

let slice = unsafe { readable_region::<u8>((region as usize + PAGE_SIZE) as _, 1) };
assert!(slice.is_empty());

let slice = unsafe { readable_region::<u8>((region as usize + PAGE_SIZE - 1) as _, 2) };
assert_eq!(slice.len(), 1);

Ok(())
}
}

0 comments on commit 2871d5e

Please sign in to comment.