Skip to content

Commit

Permalink
more features for ComObject
Browse files Browse the repository at this point in the history
  • Loading branch information
Arlie Davis committed May 15, 2024
1 parent dbcb9cd commit 95833f3
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 51 deletions.
136 changes: 106 additions & 30 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, Ref};
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, InterfaceRef};
use core::ffi::c_void;
use core::mem::ManuallyDrop;
use core::ptr::NonNull;
use std::borrow::Borrow;

// This is implemented on user types that are marked with #[implement].
#[allow(missing_docs)]
Expand All @@ -13,10 +15,10 @@ pub unsafe trait ComImpl {
/// This trait is implemented by ComObject implementation obejcts (e.g. `MyApp_Impl`).
pub trait ComObjectInterface<I: Interface> {
/// Gets a borrowed interface on the ComObject.
fn get_interface(&self) -> Ref<'_, I>;
fn get_interface(&self) -> InterfaceRef<'_, I>;
}

/// A counted pointed to a type that implements COM interfaces, where the object has been
/// A counted pointer to a type that implements COM interfaces, where the object has been
/// placed in the heap (boxed).
///
/// This type exists so that you can place an object into the heap and query for COM interfaces,
Expand All @@ -25,26 +27,11 @@ pub struct ComObject<T: ComImpl> {
ptr: NonNull<T::Impl>,
}

impl<T: Default + ComImpl> Default for ComObject<T> {
fn default() -> Self {
Self::new(T::default())
}
}

impl<T: ComImpl> Drop for ComObject<T> {
fn drop(&mut self) {
unsafe {
T::Impl::Release(self.ptr.as_ptr());
}
}
}

impl<T: ComImpl> ComObject<T> {
/// Allocates a heap cell (box) and moves `obj` into it. Returns a counted pointer to `obj`.
pub fn new(value: T) -> Self {
unsafe {
let box_ = T::Impl::new_box(value);

Self { ptr: NonNull::new_unchecked(Box::into_raw(box_)) }
}
}
Expand All @@ -65,6 +52,12 @@ impl<T: ComImpl> ComObject<T> {
unsafe { self.ptr.as_ref() }
}

// Note that we _do not_ provide a way to get a mutable reference to the outer box.
// It's ok to return &mut T, but not &mut T::Impl. That would allow someone to replace the
// contents of the entire object (box and reference count), which could lead to UB.
// This could maybe be solved by returning Pin<&mut T::Impl>, but that requires some
// additional thinking.

/// Gets a mutable reference to the object stored in the box, if the reference count
/// is exactly 1. If there are multiple references to this object then this returns `None`.
#[inline(always)]
Expand All @@ -80,8 +73,8 @@ impl<T: ComImpl> ComObject<T> {
}

/// If this object has only a single object reference (i.e. this `ComObject` is the only
/// reference to the heap allocation), then this method will destroy the `ComObject` and will
/// return the inner implementation object, wrapped in `Ok`.
/// reference to the heap allocation), then this method will extract the inner `T`
/// (and return it in an `Ok`) and then free the heap allocation.
///
/// If there is more than one reference to this object, then this returns `Err(self)`.
#[inline(always)]
Expand All @@ -108,28 +101,50 @@ impl<T: ComImpl> ComObject<T> {
}
}

/// Gets a reference to an interface that is implemented by this ComObject.
/// Gets a borrowed reference to an interface that is implemented by this ComObject.
///
/// The returned reference does not have an additional reference count.
/// You can AddRef it by calling clone().
pub fn borrow_interface<I: Interface>(&self) -> Ref<'_, I>
/// You can AddRef it by calling to_owned().
pub fn as_interface<I: Interface>(&self) -> InterfaceRef<'_, I>
where
T::Impl: ComObjectInterface<I>,
{
self.get_box().get_interface()
}

/// Gets a counted reference to an interface that is implemented by this ComObject.
pub fn get_interface<I: Interface>(&self) -> I
/// Gets an owned (counted) reference to an interface that is implemented by this ComObject.
pub fn to_interface<I: Interface>(&self) -> I
where
I: Clone,
T::Impl: ComObjectInterface<I>,
{
let interface_ref: Ref<'_, I> = self.get_box().get_interface();
let interface_inner: &I = interface_ref.ok().unwrap();
interface_inner.clone()
self.as_interface::<I>().to_owned()
}

/// Converts this `ComObject` into an interface that it implements.
pub fn into_interface<I: Interface>(self) -> I
where
T::Impl: ComObjectInterface<I>,
{
unsafe {
let raw: *mut c_void = self.get_box().get_interface().as_raw();
core::mem::forget(self);
I::from_raw(raw)
}
}
}

impl<T: ComImpl + Default> Default for ComObject<T> {
fn default() -> Self {
Self::new(T::default())
}
}

impl<T: ComImpl> Drop for ComObject<T> {
fn drop(&mut self) {
unsafe {
T::Impl::Release(self.ptr.as_ptr());
}
}
}

impl<T: ComImpl> Clone for ComObject<T> {
Expand All @@ -152,7 +167,7 @@ where
}
}

impl<T: ComImpl<Impl = T>> core::ops::Deref for ComObject<T> {
impl<T: ComImpl> core::ops::Deref for ComObject<T> {
type Target = T::Impl;

#[inline(always)]
Expand All @@ -161,8 +176,69 @@ impl<T: ComImpl<Impl = T>> core::ops::Deref for ComObject<T> {
}
}

// There is no DerefMut implementation because we cannot statically guarantee
// that the reference count is 1, which is a requirement for getting exclusive
// access to the contents of the object. Use get_mut() for dynamically-checked
// exclusive access.

impl<T: ComImpl> From<T> for ComObject<T> {
fn from(value: T) -> ComObject<T> {
ComObject::new(value)
}
}

// Delegate hashing, if implemented.
impl<T: ComImpl + core::hash::Hash> core::hash::Hash for ComObject<T> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.get().hash(state);
}
}

// If T is Send (or Sync) then the ComObject<T> is also Send (or Sync).
// Since the actual object storage is in the heap, the object is never moved.
unsafe impl<T: ComImpl + Sync> Send for ComObject<T> {}
unsafe impl<T: ComImpl + Sync> Sync for ComObject<T> {}

impl<T: ComImpl + PartialEq> PartialEq for ComObject<T> {
fn eq(&self, other: &ComObject<T>) -> bool {
let inner_self: &T = self.get();
let other_self: &T = other.get();
inner_self == other_self
}
}

impl<T: ComImpl + Eq> Eq for ComObject<T> {}

impl<T: ComImpl + PartialOrd> PartialOrd for ComObject<T> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as PartialOrd>::partial_cmp(inner_self, other_self)
}
}

impl<T: ComImpl + Ord> Ord for ComObject<T> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
let inner_self: &T = self.get();
let other_self: &T = other.get();
<T as Ord>::cmp(inner_self, other_self)
}
}

impl<T: ComImpl + core::fmt::Debug> core::fmt::Debug for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Debug>::fmt(self.get(), f)
}
}

impl<T: ComImpl + core::fmt::Display> core::fmt::Display for ComObject<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<T as core::fmt::Display>::fmt(self.get(), f)
}
}

impl<T: ComImpl> Borrow<T> for ComObject<T> {
fn borrow(&self) -> &T {
self.get()
}
}
4 changes: 2 additions & 2 deletions crates/libs/core/src/imp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub use define_interface;
pub use std::boxed::Box;

#[doc(hidden)]
pub const E_POINTER: crate::HRESULT = crate::HRESULT(-2147467261);
pub const E_POINTER: crate::HRESULT = crate::HRESULT(0x80004003u32 as i32);

#[doc(hidden)]
pub const E_NOINTERFACE: crate::HRESULT = crate::HRESULT(-2147467262);
pub const E_NOINTERFACE: crate::HRESULT = crate::HRESULT(0x80004002u32 as i32);
91 changes: 85 additions & 6 deletions crates/libs/core/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::*;
use core::ffi::c_void;
use core::marker::PhantomData;
use core::ptr::NonNull;

/// Provides low-level access to an interface vtable.
///
Expand Down Expand Up @@ -38,14 +41,14 @@ pub unsafe trait Interface: Sized + Clone {

/// Returns the raw COM interface pointer. The resulting pointer continues to be owned by the `Interface` implementation.
#[inline(always)]
fn as_raw(&self) -> *mut std::ffi::c_void {
fn as_raw(&self) -> *mut c_void {
// SAFETY: implementors of this trait must guarantee that the implementing type has a pointer in-memory representation
unsafe { std::mem::transmute_copy(self) }
}

/// Returns the raw COM interface pointer and releases ownership. It the caller's responsibility to release the COM interface pointer.
#[inline(always)]
fn into_raw(self) -> *mut std::ffi::c_void {
fn into_raw(self) -> *mut c_void {
// SAFETY: implementors of this trait must guarantee that the implementing type has a pointer in-memory representation
let raw = self.as_raw();
std::mem::forget(self);
Expand All @@ -58,7 +61,7 @@ pub unsafe trait Interface: Sized + Clone {
///
/// The `raw` pointer must be owned by the caller and represent a valid COM interface pointer. In other words,
/// it must point to a vtable beginning with the `IUnknown` function pointers and match the vtable of `Interface`.
unsafe fn from_raw(raw: *mut std::ffi::c_void) -> Self {
unsafe fn from_raw(raw: *mut c_void) -> Self {
std::mem::transmute_copy(&raw)
}

Expand All @@ -69,7 +72,7 @@ pub unsafe trait Interface: Sized + Clone {
/// The `raw` pointer must be a valid COM interface pointer. In other words, it must point to a vtable
/// beginning with the `IUnknown` function pointers and match the vtable of `Interface`.
#[inline(always)]
unsafe fn from_raw_borrowed(raw: &*mut std::ffi::c_void) -> Option<&Self> {
unsafe fn from_raw_borrowed(raw: &*mut c_void) -> Option<&Self> {
if raw.is_null() {
None
} else {
Expand Down Expand Up @@ -117,17 +120,93 @@ pub unsafe trait Interface: Sized + Clone {
///
/// `interface` must be a non-null, valid pointer for writing an interface pointer.
#[inline(always)]
unsafe fn query(&self, iid: *const GUID, interface: *mut *mut std::ffi::c_void) -> HRESULT {
unsafe fn query(&self, iid: *const GUID, interface: *mut *mut c_void) -> HRESULT {
if Self::UNKNOWN {
(self.assume_vtable::<IUnknown>().QueryInterface)(self.as_raw(), iid, interface)
} else {
panic!("Non-COM interfaces cannot be queried.")
}
}

/// Creates an `InterfaceRef` for this reference. The `InterfaceRef` tracks lifetimes statically,
/// and eliminates the need for dynamic reference count adjustments (AddRef/Release).
fn to_ref(&self) -> InterfaceRef<'_, Self> {
InterfaceRef::from_interface(self)
}
}

/// # Safety
#[doc(hidden)]
pub unsafe fn from_raw_borrowed<T: Interface>(raw: &*mut std::ffi::c_void) -> Option<&T> {
pub unsafe fn from_raw_borrowed<T: Interface>(raw: &*mut c_void) -> Option<&T> {
T::from_raw_borrowed(raw)
}

/// This has the same memory representation as `IFoo`, but represents a borrowed interface pointer.
///
/// This type has no `Drop` impl; it does not AddRef/Release the given interface. However, because
/// it has a lifetime parameter, it always represents a non-null pointer to an interface.
#[repr(transparent)]
pub struct InterfaceRef<'a, I>(NonNull<c_void>, PhantomData<&'a I>);

impl<'a, I> Copy for InterfaceRef<'a, I> {}

impl<'a, I> Clone for InterfaceRef<'a, I> {
fn clone(&self) -> Self {
*self
}
}

impl<'a, I: core::fmt::Debug + Interface> core::fmt::Debug for InterfaceRef<'a, I> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
<I as core::fmt::Debug>::fmt(&**self, f)
}
}

impl<'a, I: Interface> InterfaceRef<'a, I> {
/// Creates an `InterfaceRef` from a raw pointer. _This is extremely dangerous, since there
/// is no lifetime tracking at all!_
///
/// # Safety
/// The caller must guarantee that the `'a` lifetime parameter is bound by context to a correct
/// lifetime.
#[inline(always)]
pub unsafe fn from_raw(ptr: NonNull<c_void>) -> Self {
Self(ptr, PhantomData)
}

/// Creates an `InterfaceRef` from an interface reference. This safely associates the lifetime
/// of the interface reference with the `'a` parameter of `InterfaceRef`. This allows for
/// lifetime checking _without_ calling AddRef/Release on the underlying lifetime, which can
/// improve efficiency.
#[inline(always)]
pub fn from_interface(interface: &I) -> Self {
unsafe {
// SAFETY: new_unchecked() should be valid because Interface::as_raw should always
// return a non-null pointer.
Self(NonNull::new_unchecked(interface.as_raw()), PhantomData)
}
}

/// Calls AddRef on the underlying COM interface and returns an "owned" (counted) reference.
#[inline(always)]
pub fn to_owned(self) -> I {
let interface: &I = &*self;
interface.clone()
}
}

impl<'a, 'i: 'a, I: Interface> From<&'i I> for InterfaceRef<'a, I> {
#[inline(always)]
fn from(interface: &'a I) -> InterfaceRef<'a, I> {
InterfaceRef::from_interface(interface)
}
}

impl<'a, I: Interface> core::ops::Deref for InterfaceRef<'a, I> {
type Target = I;

#[inline(always)]
fn deref(&self) -> &I {
unsafe { core::mem::transmute(self) }
}
}
3 changes: 2 additions & 1 deletion crates/libs/core/src/ref.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::*;
use core::marker::PhantomData;

/// A borrowed type with the same memory layout as the type itself that can be used to construct ABI-compatible function signatures.
#[repr(transparent)]
pub struct Ref<'a, T: Type<T>>(T::Abi, std::marker::PhantomData<&'a T>);
pub struct Ref<'a, T: Type<T>>(T::Abi, PhantomData<&'a T>);

impl<'a, T: Type<T, Default = Option<T>, Abi = *mut std::ffi::c_void>> Ref<'a, T> {
/// Converts the argument to a [Result<&T>] reference.
Expand Down
6 changes: 3 additions & 3 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}

impl #generics ::windows_core::imp::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints {
fn get_interface(&self) -> ::windows_core::Ref<'_, #interface_ident> {
fn get_interface(&self) -> ::windows_core::InterfaceRef<'_, #interface_ident> {
unsafe {
let vtable_ptr = &self.vtables.#offset;
::core::mem::transmute(vtable_ptr)
let interface_ptr = &self.vtables.#offset;
::core::mem::transmute(interface_ptr)
}
}
}
Expand Down
Loading

0 comments on commit 95833f3

Please sign in to comment.