Skip to content

Commit

Permalink
implement get_interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Arlie Davis committed May 15, 2024
1 parent 2dbb95c commit dbcb9cd
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 10 deletions.
46 changes: 41 additions & 5 deletions crates/libs/core/src/com_object.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface};
use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, Ref};
use core::mem::ManuallyDrop;
use core::ptr::NonNull;

Expand All @@ -9,6 +9,13 @@ pub unsafe trait ComImpl {
type Impl: IUnknownImpl<Impl = Self>;
}

/// Describes the COM interfaces that a specific ComObject implements.
/// This trait is implemented by ComObject implementation obejcts (e.g. `MyApp_Impl`).
pub trait ComObjectInterface<I: Interface> {
/// Gets a borrowed interface on the ComObject.
fn get_interface(&self) -> Ref<'_, I>;
}

/// A counted pointed to a type that implements COM interfaces, where the object has been
/// placed in the heap (boxed).
///
Expand Down Expand Up @@ -49,7 +56,13 @@ impl<T: ComImpl> ComObject<T> {
/// this method to explicitly get a reference to the contents.
#[inline(always)]
pub fn get(&self) -> &T {
unsafe { self.ptr.as_ref().get_impl() }
self.get_box().get_impl()
}

/// Gets a reference to the shared object's heap box.
#[inline(always)]
pub fn get_box(&self) -> &T::Impl {
unsafe { self.ptr.as_ref() }
}

/// Gets a mutable reference to the object stored in the box, if the reference count
Expand Down Expand Up @@ -94,6 +107,29 @@ impl<T: ComImpl> ComObject<T> {
unknown.cast()
}
}

/// Gets a reference to an interface that is implemented by this ComObject.
///
/// The returned reference does not have an additional reference count.
/// You can AddRef it by calling clone().
pub fn borrow_interface<I: Interface>(&self) -> Ref<'_, I>
where
T::Impl: ComObjectInterface<I>,
{
self.get_box().get_interface()
}

/// Gets a counted reference to an interface that is implemented by this ComObject.
pub fn get_interface<I: Interface>(&self) -> I
where
I: Clone,
T::Impl: ComObjectInterface<I>,
{
let interface_ref: Ref<'_, I> = self.get_box().get_interface();
let interface_inner: &I = interface_ref.ok().unwrap();
interface_inner.clone()
}

}

impl<T: ComImpl> Clone for ComObject<T> {
Expand All @@ -116,12 +152,12 @@ where
}
}

impl<T: ComImpl> core::ops::Deref for ComObject<T> {
type Target = T;
impl<T: ComImpl<Impl = T>> core::ops::Deref for ComObject<T> {
type Target = T::Impl;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.get()
self.get_box()
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/libs/core/src/imp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod sha1;
mod waiter;
mod weak_ref_count;

pub use crate::com_object::ComImpl;
pub use crate::com_object::{ComImpl, ComObjectInterface};
pub use bindings::*;
pub use can_into::*;
pub use com_bindings::*;
Expand Down
20 changes: 20 additions & 0 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
unsafe { ::core::mem::transmute(vtable_ptr) }
}
}

impl #generics ::windows_core::imp::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints {
fn get_interface(&self) -> ::windows_core::Ref<'_, #interface_ident> {
unsafe {
let vtable_ptr = &self.vtables.#offset;
::core::mem::transmute(vtable_ptr)
}
}
}

impl #generics ::windows_core::AsImpl<#original_ident::#generics> for #interface_ident where #constraints {
// SAFETY: the offset is guranteed to be in bounds, and the implementation struct
// is guaranteed to live at least as long as `self`.
Expand Down Expand Up @@ -256,6 +266,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
::core::ptr::NonNull::new_unchecked(::core::ptr::addr_of!((*this).this) as *const #original_ident::#generics as *mut #original_ident::#generics)
}
}

impl #generics ::core::ops::Deref for #impl_ident::#generics where #constraints {
type Target = #original_ident::#generics;

#[inline(always)]
fn deref(&self) -> &Self::Target {
::windows_core::IUnknownImpl::get_impl(self)
}
}

#(#conversions)*
};

Expand Down
17 changes: 13 additions & 4 deletions crates/tests/implement/tests/com_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ fn basic() {
assert_eq!(unsafe { ifoo.get_x() }, 42);

// check lifetimes
let tombstone = app.tombstone.clone();
let tombstone = app.get().tombstone.clone();
assert!(!tombstone.is_dead());

drop(app);
Expand All @@ -91,11 +91,12 @@ fn basic() {
#[test]
fn casting() {
let app: ComObject<MyApp> = MyApp::new(42);
let tombstone = app.get().tombstone.clone();

let ifoo: IFoo = app.cast().unwrap();
assert_eq!(unsafe { app.get_x() }, 42);
assert_eq!(unsafe { app.get().get_x() }, 42);

// check lifetimes
let tombstone = app.tombstone.clone();
assert!(!tombstone.is_dead());

drop(app);
Expand Down Expand Up @@ -139,7 +140,7 @@ fn get_mut() {
#[test]
fn try_take() {
let app: ComObject<MyApp> = MyApp::new(42);
let tombstone = app.tombstone.clone();
let tombstone = app.get().tombstone.clone();
// refcount = 1

let app2 = app.clone();
Expand Down Expand Up @@ -169,6 +170,14 @@ fn try_take() {
}
}

#[test]
fn object_interfaces() {
let app = MyApp::new(42);
let ifoo_ref = app.borrow_interface::<IFoo>();
let ifoo = ifoo_ref.ok().unwrap();
assert_eq!(unsafe { ifoo.get_x() }, 42);
}

#[test]
fn construct_with_com_object_new() {
// Tests that we can construct using ComObject::new().
Expand Down

0 comments on commit dbcb9cd

Please sign in to comment.