From 95833f3d602eb0e77c0010e6702ab9a17f6f8e6d Mon Sep 17 00:00:00 2001 From: Arlie Davis Date: Tue, 14 May 2024 22:49:46 -0700 Subject: [PATCH] more features for ComObject --- crates/libs/core/src/com_object.rs | 136 ++++++++++++++++----- crates/libs/core/src/imp/mod.rs | 4 +- crates/libs/core/src/interface.rs | 91 +++++++++++++- crates/libs/core/src/ref.rs | 3 +- crates/libs/implement/src/lib.rs | 6 +- crates/tests/implement/tests/com_object.rs | 109 +++++++++++++++-- 6 files changed, 298 insertions(+), 51 deletions(-) diff --git a/crates/libs/core/src/com_object.rs b/crates/libs/core/src/com_object.rs index cbaf496e2b..fc60fe6e34 100644 --- a/crates/libs/core/src/com_object.rs +++ b/crates/libs/core/src/com_object.rs @@ -1,6 +1,8 @@ -use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, Ref}; +use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef}; +use core::ffi::c_void; use core::mem::ManuallyDrop; use core::ptr::NonNull; +use std::borrow::Borrow; // This is implemented on user types that are marked with #[implement]. #[allow(missing_docs)] @@ -13,10 +15,10 @@ pub unsafe trait ComImpl { /// This trait is implemented by ComObject implementation obejcts (e.g. `MyApp_Impl`). pub trait ComObjectInterface { /// Gets a borrowed interface on the ComObject. - fn get_interface(&self) -> Ref<'_, I>; + fn get_interface(&self) -> InterfaceRef<'_, I>; } -/// A counted pointed to a type that implements COM interfaces, where the object has been +/// A counted pointer to a type that implements COM interfaces, where the object has been /// placed in the heap (boxed). /// /// This type exists so that you can place an object into the heap and query for COM interfaces, @@ -25,26 +27,11 @@ pub struct ComObject { ptr: NonNull, } -impl Default for ComObject { - fn default() -> Self { - Self::new(T::default()) - } -} - -impl Drop for ComObject { - fn drop(&mut self) { - unsafe { - T::Impl::Release(self.ptr.as_ptr()); - } - } -} - impl ComObject { /// Allocates a heap cell (box) and moves `obj` into it. Returns a counted pointer to `obj`. pub fn new(value: T) -> Self { unsafe { let box_ = T::Impl::new_box(value); - Self { ptr: NonNull::new_unchecked(Box::into_raw(box_)) } } } @@ -65,6 +52,12 @@ impl ComObject { unsafe { self.ptr.as_ref() } } + // Note that we _do not_ provide a way to get a mutable reference to the outer box. + // It's ok to return &mut T, but not &mut T::Impl. That would allow someone to replace the + // contents of the entire object (box and reference count), which could lead to UB. + // This could maybe be solved by returning Pin<&mut T::Impl>, but that requires some + // additional thinking. + /// Gets a mutable reference to the object stored in the box, if the reference count /// is exactly 1. If there are multiple references to this object then this returns `None`. #[inline(always)] @@ -80,8 +73,8 @@ impl ComObject { } /// If this object has only a single object reference (i.e. this `ComObject` is the only - /// reference to the heap allocation), then this method will destroy the `ComObject` and will - /// return the inner implementation object, wrapped in `Ok`. + /// reference to the heap allocation), then this method will extract the inner `T` + /// (and return it in an `Ok`) and then free the heap allocation. /// /// If there is more than one reference to this object, then this returns `Err(self)`. #[inline(always)] @@ -108,28 +101,50 @@ impl ComObject { } } - /// Gets a reference to an interface that is implemented by this ComObject. + /// Gets a borrowed reference to an interface that is implemented by this ComObject. /// /// The returned reference does not have an additional reference count. - /// You can AddRef it by calling clone(). - pub fn borrow_interface(&self) -> Ref<'_, I> + /// You can AddRef it by calling to_owned(). + pub fn as_interface(&self) -> InterfaceRef<'_, I> where T::Impl: ComObjectInterface, { self.get_box().get_interface() } - /// Gets a counted reference to an interface that is implemented by this ComObject. - pub fn get_interface(&self) -> I + /// Gets an owned (counted) reference to an interface that is implemented by this ComObject. + pub fn to_interface(&self) -> I where - I: Clone, T::Impl: ComObjectInterface, { - let interface_ref: Ref<'_, I> = self.get_box().get_interface(); - let interface_inner: &I = interface_ref.ok().unwrap(); - interface_inner.clone() + self.as_interface::().to_owned() } + /// Converts this `ComObject` into an interface that it implements. + pub fn into_interface(self) -> I + where + T::Impl: ComObjectInterface, + { + unsafe { + let raw: *mut c_void = self.get_box().get_interface().as_raw(); + core::mem::forget(self); + I::from_raw(raw) + } + } +} + +impl Default for ComObject { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl Drop for ComObject { + fn drop(&mut self) { + unsafe { + T::Impl::Release(self.ptr.as_ptr()); + } + } } impl Clone for ComObject { @@ -152,7 +167,7 @@ where } } -impl> core::ops::Deref for ComObject { +impl core::ops::Deref for ComObject { type Target = T::Impl; #[inline(always)] @@ -161,8 +176,69 @@ impl> core::ops::Deref for ComObject { } } +// There is no DerefMut implementation because we cannot statically guarantee +// that the reference count is 1, which is a requirement for getting exclusive +// access to the contents of the object. Use get_mut() for dynamically-checked +// exclusive access. + impl From for ComObject { fn from(value: T) -> ComObject { ComObject::new(value) } } + +// Delegate hashing, if implemented. +impl core::hash::Hash for ComObject { + fn hash(&self, state: &mut H) { + self.get().hash(state); + } +} + +// If T is Send (or Sync) then the ComObject is also Send (or Sync). +// Since the actual object storage is in the heap, the object is never moved. +unsafe impl Send for ComObject {} +unsafe impl Sync for ComObject {} + +impl PartialEq for ComObject { + fn eq(&self, other: &ComObject) -> bool { + let inner_self: &T = self.get(); + let other_self: &T = other.get(); + inner_self == other_self + } +} + +impl Eq for ComObject {} + +impl PartialOrd for ComObject { + fn partial_cmp(&self, other: &Self) -> Option { + let inner_self: &T = self.get(); + let other_self: &T = other.get(); + ::partial_cmp(inner_self, other_self) + } +} + +impl Ord for ComObject { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + let inner_self: &T = self.get(); + let other_self: &T = other.get(); + ::cmp(inner_self, other_self) + } +} + +impl core::fmt::Debug for ComObject { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + ::fmt(self.get(), f) + } +} + +impl core::fmt::Display for ComObject { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + ::fmt(self.get(), f) + } +} + +impl Borrow for ComObject { + fn borrow(&self) -> &T { + self.get() + } +} diff --git a/crates/libs/core/src/imp/mod.rs b/crates/libs/core/src/imp/mod.rs index b0186cf248..06f756f70f 100644 --- a/crates/libs/core/src/imp/mod.rs +++ b/crates/libs/core/src/imp/mod.rs @@ -104,7 +104,7 @@ pub use define_interface; pub use std::boxed::Box; #[doc(hidden)] -pub const E_POINTER: crate::HRESULT = crate::HRESULT(-2147467261); +pub const E_POINTER: crate::HRESULT = crate::HRESULT(0x80004003u32 as i32); #[doc(hidden)] -pub const E_NOINTERFACE: crate::HRESULT = crate::HRESULT(-2147467262); +pub const E_NOINTERFACE: crate::HRESULT = crate::HRESULT(0x80004002u32 as i32); diff --git a/crates/libs/core/src/interface.rs b/crates/libs/core/src/interface.rs index 1ef3586e70..e9758c47fd 100644 --- a/crates/libs/core/src/interface.rs +++ b/crates/libs/core/src/interface.rs @@ -1,4 +1,7 @@ use super::*; +use core::ffi::c_void; +use core::marker::PhantomData; +use core::ptr::NonNull; /// Provides low-level access to an interface vtable. /// @@ -38,14 +41,14 @@ pub unsafe trait Interface: Sized + Clone { /// Returns the raw COM interface pointer. The resulting pointer continues to be owned by the `Interface` implementation. #[inline(always)] - fn as_raw(&self) -> *mut std::ffi::c_void { + fn as_raw(&self) -> *mut c_void { // SAFETY: implementors of this trait must guarantee that the implementing type has a pointer in-memory representation unsafe { std::mem::transmute_copy(self) } } /// Returns the raw COM interface pointer and releases ownership. It the caller's responsibility to release the COM interface pointer. #[inline(always)] - fn into_raw(self) -> *mut std::ffi::c_void { + fn into_raw(self) -> *mut c_void { // SAFETY: implementors of this trait must guarantee that the implementing type has a pointer in-memory representation let raw = self.as_raw(); std::mem::forget(self); @@ -58,7 +61,7 @@ pub unsafe trait Interface: Sized + Clone { /// /// The `raw` pointer must be owned by the caller and represent a valid COM interface pointer. In other words, /// it must point to a vtable beginning with the `IUnknown` function pointers and match the vtable of `Interface`. - unsafe fn from_raw(raw: *mut std::ffi::c_void) -> Self { + unsafe fn from_raw(raw: *mut c_void) -> Self { std::mem::transmute_copy(&raw) } @@ -69,7 +72,7 @@ pub unsafe trait Interface: Sized + Clone { /// The `raw` pointer must be a valid COM interface pointer. In other words, it must point to a vtable /// beginning with the `IUnknown` function pointers and match the vtable of `Interface`. #[inline(always)] - unsafe fn from_raw_borrowed(raw: &*mut std::ffi::c_void) -> Option<&Self> { + unsafe fn from_raw_borrowed(raw: &*mut c_void) -> Option<&Self> { if raw.is_null() { None } else { @@ -117,17 +120,93 @@ pub unsafe trait Interface: Sized + Clone { /// /// `interface` must be a non-null, valid pointer for writing an interface pointer. #[inline(always)] - unsafe fn query(&self, iid: *const GUID, interface: *mut *mut std::ffi::c_void) -> HRESULT { + unsafe fn query(&self, iid: *const GUID, interface: *mut *mut c_void) -> HRESULT { if Self::UNKNOWN { (self.assume_vtable::().QueryInterface)(self.as_raw(), iid, interface) } else { panic!("Non-COM interfaces cannot be queried.") } } + + /// Creates an `InterfaceRef` for this reference. The `InterfaceRef` tracks lifetimes statically, + /// and eliminates the need for dynamic reference count adjustments (AddRef/Release). + fn to_ref(&self) -> InterfaceRef<'_, Self> { + InterfaceRef::from_interface(self) + } } /// # Safety #[doc(hidden)] -pub unsafe fn from_raw_borrowed(raw: &*mut std::ffi::c_void) -> Option<&T> { +pub unsafe fn from_raw_borrowed(raw: &*mut c_void) -> Option<&T> { T::from_raw_borrowed(raw) } + +/// This has the same memory representation as `IFoo`, but represents a borrowed interface pointer. +/// +/// This type has no `Drop` impl; it does not AddRef/Release the given interface. However, because +/// it has a lifetime parameter, it always represents a non-null pointer to an interface. +#[repr(transparent)] +pub struct InterfaceRef<'a, I>(NonNull, PhantomData<&'a I>); + +impl<'a, I> Copy for InterfaceRef<'a, I> {} + +impl<'a, I> Clone for InterfaceRef<'a, I> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, I: core::fmt::Debug + Interface> core::fmt::Debug for InterfaceRef<'a, I> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + ::fmt(&**self, f) + } +} + +impl<'a, I: Interface> InterfaceRef<'a, I> { + /// Creates an `InterfaceRef` from a raw pointer. _This is extremely dangerous, since there + /// is no lifetime tracking at all!_ + /// + /// # Safety + /// The caller must guarantee that the `'a` lifetime parameter is bound by context to a correct + /// lifetime. + #[inline(always)] + pub unsafe fn from_raw(ptr: NonNull) -> Self { + Self(ptr, PhantomData) + } + + /// Creates an `InterfaceRef` from an interface reference. This safely associates the lifetime + /// of the interface reference with the `'a` parameter of `InterfaceRef`. This allows for + /// lifetime checking _without_ calling AddRef/Release on the underlying lifetime, which can + /// improve efficiency. + #[inline(always)] + pub fn from_interface(interface: &I) -> Self { + unsafe { + // SAFETY: new_unchecked() should be valid because Interface::as_raw should always + // return a non-null pointer. + Self(NonNull::new_unchecked(interface.as_raw()), PhantomData) + } + } + + /// Calls AddRef on the underlying COM interface and returns an "owned" (counted) reference. + #[inline(always)] + pub fn to_owned(self) -> I { + let interface: &I = &*self; + interface.clone() + } +} + +impl<'a, 'i: 'a, I: Interface> From<&'i I> for InterfaceRef<'a, I> { + #[inline(always)] + fn from(interface: &'a I) -> InterfaceRef<'a, I> { + InterfaceRef::from_interface(interface) + } +} + +impl<'a, I: Interface> core::ops::Deref for InterfaceRef<'a, I> { + type Target = I; + + #[inline(always)] + fn deref(&self) -> &I { + unsafe { core::mem::transmute(self) } + } +} diff --git a/crates/libs/core/src/ref.rs b/crates/libs/core/src/ref.rs index b1025fb985..b8d82a8cca 100644 --- a/crates/libs/core/src/ref.rs +++ b/crates/libs/core/src/ref.rs @@ -1,8 +1,9 @@ use super::*; +use core::marker::PhantomData; /// A borrowed type with the same memory layout as the type itself that can be used to construct ABI-compatible function signatures. #[repr(transparent)] -pub struct Ref<'a, T: Type>(T::Abi, std::marker::PhantomData<&'a T>); +pub struct Ref<'a, T: Type>(T::Abi, PhantomData<&'a T>); impl<'a, T: Type, Abi = *mut std::ffi::c_void>> Ref<'a, T> { /// Converts the argument to a [Result<&T>] reference. diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 7908d27371..a0dd2984e0 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -99,10 +99,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } impl #generics ::windows_core::imp::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints { - fn get_interface(&self) -> ::windows_core::Ref<'_, #interface_ident> { + fn get_interface(&self) -> ::windows_core::InterfaceRef<'_, #interface_ident> { unsafe { - let vtable_ptr = &self.vtables.#offset; - ::core::mem::transmute(vtable_ptr) + let interface_ptr = &self.vtables.#offset; + ::core::mem::transmute(interface_ptr) } } } diff --git a/crates/tests/implement/tests/com_object.rs b/crates/tests/implement/tests/com_object.rs index ab3ee8d897..677d5ce60b 100644 --- a/crates/tests/implement/tests/com_object.rs +++ b/crates/tests/implement/tests/com_object.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use std::sync::Arc; use windows_core::{implement, interface, ComObject, IUnknown, IUnknown_Vtbl}; @@ -13,6 +14,24 @@ struct MyApp { tombstone: Arc, } +impl Borrow for MyApp { + fn borrow(&self) -> &u32 { + &self.x + } +} + +impl core::fmt::Debug for MyApp { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "x = {}", self.x) + } +} + +impl core::fmt::Display for MyApp { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "x = {}", self.x) + } +} + impl Default for MyApp { fn default() -> Self { Self { @@ -22,6 +41,27 @@ impl Default for MyApp { } } +impl std::hash::Hash for MyApp { + fn hash(&self, state: &mut H) { + self.x.hash(state); + } +} + +impl PartialEq for MyApp { + fn eq(&self, other: &u32) -> bool { + self.x == *other + } +} + +impl PartialEq for MyApp { + fn eq(&self, other: &MyApp) -> bool { + self.x == other.x + } +} + +impl Eq for MyApp {} + +/// This lets us detect when an object has been dropped. #[derive(Default)] struct Tombstone { cell: AtomicBool, @@ -75,7 +115,7 @@ fn basic() { assert_eq!(unsafe { ifoo.get_x() }, 42); // check lifetimes - let tombstone = app.get().tombstone.clone(); + let tombstone = app.tombstone.clone(); assert!(!tombstone.is_dead()); drop(app); @@ -91,10 +131,10 @@ fn basic() { #[test] fn casting() { let app: ComObject = MyApp::new(42); - let tombstone = app.get().tombstone.clone(); + let tombstone = app.tombstone.clone(); let ifoo: IFoo = app.cast().unwrap(); - assert_eq!(unsafe { app.get().get_x() }, 42); + assert_eq!(unsafe { app.get_x() }, 42); // check lifetimes assert!(!tombstone.is_dead()); @@ -120,12 +160,12 @@ fn clone() { #[test] fn get_mut() { let mut app: ComObject = MyApp::new(42); - assert_eq!(app.get().get_x_direct(), 42); + assert_eq!(app.get_x_direct(), 42); // refcount = 1 app.get_mut().unwrap().set_x(50); - assert_eq!(app.get().get_x_direct(), 50); + assert_eq!(app.get_x_direct(), 50); let app2 = app.clone(); // refcount = 2 @@ -140,7 +180,7 @@ fn get_mut() { #[test] fn try_take() { let app: ComObject = MyApp::new(42); - let tombstone = app.get().tombstone.clone(); + let tombstone = app.tombstone.clone(); // refcount = 1 let app2 = app.clone(); @@ -171,11 +211,32 @@ fn try_take() { } #[test] -fn object_interfaces() { +fn as_interface() { + let app = MyApp::new(42); + let tombstone = app.tombstone.clone(); + + let ifoo = app.as_interface::(); + assert_eq!(unsafe { ifoo.get_x() }, 42); + assert!(!tombstone.is_dead()); + + drop(app); + assert!(tombstone.is_dead()); +} + +#[test] +fn to_interface() { let app = MyApp::new(42); - let ifoo_ref = app.borrow_interface::(); - let ifoo = ifoo_ref.ok().unwrap(); + let tombstone = app.tombstone.clone(); + + let ifoo = app.to_interface::(); assert_eq!(unsafe { ifoo.get_x() }, 42); + assert!(!tombstone.is_dead()); + + drop(app); + assert!(!tombstone.is_dead()); + + drop(ifoo); + assert!(tombstone.is_dead()); } #[test] @@ -199,3 +260,33 @@ fn construct_with_into() { consume(MyApp::default().into()) } + +#[test] +fn debug() { + let app = MyApp::new(100); + let s = format!("{:?}", app); + assert_eq!(s, "x = 100"); +} + +#[test] +fn display() { + let app = MyApp::new(200); + let s = format!("{}", app); + assert_eq!(s, "x = 200"); +} + +#[cfg(todo)] +#[test] +fn hashable() { + use std::collections::HashMap; + + let mut map: HashMap, &'static str> = HashMap::new(); + + map.insert(MyApp::new(100), "hello"); + map.insert(MyApp::new(200), "world"); + + let i: &&str = map.get(&100).unwrap(); + assert_eq!(*i, "hello"); + + assert!(map.get(&333).is_none()); +}