diff --git a/crates/libs/core/src/com_object.rs b/crates/libs/core/src/com_object.rs index 2e3176dfc0..cbaf496e2b 100644 --- a/crates/libs/core/src/com_object.rs +++ b/crates/libs/core/src/com_object.rs @@ -1,4 +1,4 @@ -use crate::{AsImpl, IUnknown, IUnknownImpl, Interface}; +use crate::{AsImpl, IUnknown, IUnknownImpl, Interface, Ref}; use core::mem::ManuallyDrop; use core::ptr::NonNull; @@ -9,6 +9,13 @@ pub unsafe trait ComImpl { type Impl: IUnknownImpl; } +/// Describes the COM interfaces that a specific ComObject implements. +/// This trait is implemented by ComObject implementation obejcts (e.g. `MyApp_Impl`). +pub trait ComObjectInterface { + /// Gets a borrowed interface on the ComObject. + fn get_interface(&self) -> Ref<'_, I>; +} + /// A counted pointed to a type that implements COM interfaces, where the object has been /// placed in the heap (boxed). /// @@ -49,7 +56,13 @@ impl ComObject { /// this method to explicitly get a reference to the contents. #[inline(always)] pub fn get(&self) -> &T { - unsafe { self.ptr.as_ref().get_impl() } + self.get_box().get_impl() + } + + /// Gets a reference to the shared object's heap box. + #[inline(always)] + pub fn get_box(&self) -> &T::Impl { + unsafe { self.ptr.as_ref() } } /// Gets a mutable reference to the object stored in the box, if the reference count @@ -94,6 +107,29 @@ impl ComObject { unknown.cast() } } + + /// Gets a reference to an interface that is implemented by this ComObject. + /// + /// The returned reference does not have an additional reference count. + /// You can AddRef it by calling clone(). + pub fn borrow_interface(&self) -> Ref<'_, I> + where + T::Impl: ComObjectInterface, + { + self.get_box().get_interface() + } + + /// Gets a counted reference to an interface that is implemented by this ComObject. + pub fn get_interface(&self) -> I + where + I: Clone, + T::Impl: ComObjectInterface, + { + let interface_ref: Ref<'_, I> = self.get_box().get_interface(); + let interface_inner: &I = interface_ref.ok().unwrap(); + interface_inner.clone() + } + } impl Clone for ComObject { @@ -116,12 +152,12 @@ where } } -impl core::ops::Deref for ComObject { - type Target = T; +impl> core::ops::Deref for ComObject { + type Target = T::Impl; #[inline(always)] fn deref(&self) -> &Self::Target { - self.get() + self.get_box() } } diff --git a/crates/libs/core/src/imp/mod.rs b/crates/libs/core/src/imp/mod.rs index c83f8b299c..b0186cf248 100644 --- a/crates/libs/core/src/imp/mod.rs +++ b/crates/libs/core/src/imp/mod.rs @@ -10,7 +10,7 @@ mod sha1; mod waiter; mod weak_ref_count; -pub use crate::com_object::ComImpl; +pub use crate::com_object::{ComImpl, ComObjectInterface}; pub use bindings::*; pub use can_into::*; pub use com_bindings::*; diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index b63d9cba48..7908d27371 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -97,6 +97,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: unsafe { ::core::mem::transmute(vtable_ptr) } } } + + impl #generics ::windows_core::imp::ComObjectInterface<#interface_ident> for #impl_ident::#generics where #constraints { + fn get_interface(&self) -> ::windows_core::Ref<'_, #interface_ident> { + unsafe { + let vtable_ptr = &self.vtables.#offset; + ::core::mem::transmute(vtable_ptr) + } + } + } + impl #generics ::windows_core::AsImpl<#original_ident::#generics> for #interface_ident where #constraints { // SAFETY: the offset is guranteed to be in bounds, and the implementation struct // is guaranteed to live at least as long as `self`. @@ -256,6 +266,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: ::core::ptr::NonNull::new_unchecked(::core::ptr::addr_of!((*this).this) as *const #original_ident::#generics as *mut #original_ident::#generics) } } + + impl #generics ::core::ops::Deref for #impl_ident::#generics where #constraints { + type Target = #original_ident::#generics; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + ::windows_core::IUnknownImpl::get_impl(self) + } + } + #(#conversions)* }; diff --git a/crates/tests/implement/tests/com_object.rs b/crates/tests/implement/tests/com_object.rs index 6892ea04f7..ab3ee8d897 100644 --- a/crates/tests/implement/tests/com_object.rs +++ b/crates/tests/implement/tests/com_object.rs @@ -75,7 +75,7 @@ fn basic() { assert_eq!(unsafe { ifoo.get_x() }, 42); // check lifetimes - let tombstone = app.tombstone.clone(); + let tombstone = app.get().tombstone.clone(); assert!(!tombstone.is_dead()); drop(app); @@ -91,11 +91,12 @@ fn basic() { #[test] fn casting() { let app: ComObject = MyApp::new(42); + let tombstone = app.get().tombstone.clone(); + let ifoo: IFoo = app.cast().unwrap(); - assert_eq!(unsafe { app.get_x() }, 42); + assert_eq!(unsafe { app.get().get_x() }, 42); // check lifetimes - let tombstone = app.tombstone.clone(); assert!(!tombstone.is_dead()); drop(app); @@ -139,7 +140,7 @@ fn get_mut() { #[test] fn try_take() { let app: ComObject = MyApp::new(42); - let tombstone = app.tombstone.clone(); + let tombstone = app.get().tombstone.clone(); // refcount = 1 let app2 = app.clone(); @@ -169,6 +170,14 @@ fn try_take() { } } +#[test] +fn object_interfaces() { + let app = MyApp::new(42); + let ifoo_ref = app.borrow_interface::(); + let ifoo = ifoo_ref.ok().unwrap(); + assert_eq!(unsafe { ifoo.get_x() }, 42); +} + #[test] fn construct_with_com_object_new() { // Tests that we can construct using ComObject::new().