From 55590998467d397bde20538d851743871a8bc356 Mon Sep 17 00:00:00 2001 From: Kenny Kerr Date: Thu, 9 May 2024 15:31:13 -0500 Subject: [PATCH] Add `Ref` and `OutRef` to enhance COM authoring support (#3025) --- crates/libs/bindgen/src/rust/handles.rs | 2 +- crates/libs/bindgen/src/rust/interfaces.rs | 4 +- crates/libs/core/src/imp/can_into.rs | 5 + crates/libs/core/src/imp/com_bindings.rs | 6 +- crates/libs/core/src/imp/mod.rs | 6 +- crates/libs/core/src/lib.rs | 8 + crates/libs/core/src/out_param.rs | 56 +++++++ crates/libs/core/src/out_ref.rs | 25 ++++ crates/libs/core/src/param.rs | 28 +--- crates/libs/core/src/param_value.rs | 23 +++ crates/libs/core/src/ref.rs | 12 ++ crates/libs/interface/Cargo.toml | 2 +- crates/libs/interface/src/lib.rs | 141 ++++++++++++++++-- .../src/Windows/Foundation/Collections/mod.rs | 56 +++---- .../windows/src/Windows/Foundation/mod.rs | 30 ++-- .../src/Windows/Win32/Foundation/mod.rs | 6 +- .../src/Windows/Win32/Graphics/Gdi/mod.rs | 12 +- .../Win32/Security/Cryptography/mod.rs | 12 +- .../Win32/UI/WindowsAndMessaging/mod.rs | 2 +- .../json_validator_winrt/src/lib.rs | 4 +- crates/tests/component/src/lib.rs | 17 +-- crates/tests/interface_core/tests/ref.rs | 133 +++++++++++++++++ crates/tests/interface_core/tests/result.rs | 44 ++++++ crates/tests/riddle/src/generic_interfaces.rs | 18 +-- crates/tests/standalone/src/b_calendar.rs | 14 +- crates/tests/standalone/src/b_uri.rs | 14 +- 26 files changed, 536 insertions(+), 144 deletions(-) create mode 100644 crates/libs/core/src/imp/can_into.rs create mode 100644 crates/libs/core/src/out_param.rs create mode 100644 crates/libs/core/src/out_ref.rs create mode 100644 crates/libs/core/src/param_value.rs create mode 100644 crates/libs/core/src/ref.rs create mode 100644 crates/tests/interface_core/tests/ref.rs create mode 100644 crates/tests/interface_core/tests/result.rs diff --git a/crates/libs/bindgen/src/rust/handles.rs b/crates/libs/bindgen/src/rust/handles.rs index 993905d6d9..844fff0dbe 100644 --- a/crates/libs/bindgen/src/rust/handles.rs +++ b/crates/libs/bindgen/src/rust/handles.rs @@ -120,7 +120,7 @@ pub fn gen_win_handle(writer: &Writer, def: metadata::TypeDef) -> TokenStream { dependency.push_str(type_name.name()); tokens.combine("e! { - impl windows_core::CanInto<#dependency> for #ident {} + impl windows_core::imp::CanInto<#dependency> for #ident {} impl From<#ident> for #dependency { fn from(value: #ident) -> Self { Self(value.0) diff --git a/crates/libs/bindgen/src/rust/interfaces.rs b/crates/libs/bindgen/src/rust/interfaces.rs index dfea95062c..0cdb6c74ca 100644 --- a/crates/libs/bindgen/src/rust/interfaces.rs +++ b/crates/libs/bindgen/src/rust/interfaces.rs @@ -126,7 +126,7 @@ fn gen_win_interface(writer: &Writer, def: metadata::TypeDef) -> TokenStream { let cfg = writer.cfg_features(&cfg.union(cfg::type_cfg(writer, ty))); tokens.combine("e! { #cfg - impl<#constraints> windows_core::CanInto<#into> for #ident {} + impl<#constraints> windows_core::imp::CanInto<#into> for #ident {} }); } } @@ -152,7 +152,7 @@ fn gen_win_interface(writer: &Writer, def: metadata::TypeDef) -> TokenStream { let cfg = writer.cfg_features(&cfg.union(cfg::type_cfg(writer, &interface.ty))); tokens.combine("e! { #cfg - impl<#constraints> windows_core::CanInto<#into> for #ident { const QUERY: bool = true; } + impl<#constraints> windows_core::imp::CanInto<#into> for #ident { const QUERY: bool = true; } }); } } diff --git a/crates/libs/core/src/imp/can_into.rs b/crates/libs/core/src/imp/can_into.rs new file mode 100644 index 0000000000..47ee7838c4 --- /dev/null +++ b/crates/libs/core/src/imp/can_into.rs @@ -0,0 +1,5 @@ +pub trait CanInto: Sized { + const QUERY: bool = false; +} + +impl CanInto for T where T: Clone {} diff --git a/crates/libs/core/src/imp/com_bindings.rs b/crates/libs/core/src/imp/com_bindings.rs index 07889c1c64..d999a00463 100644 --- a/crates/libs/core/src/imp/com_bindings.rs +++ b/crates/libs/core/src/imp/com_bindings.rs @@ -428,9 +428,9 @@ impl std::ops::Deref for IReference { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IReference {} -impl windows_core::CanInto for IReference {} -impl windows_core::CanInto for IReference { +impl windows_core::imp::CanInto for IReference {} +impl windows_core::imp::CanInto for IReference {} +impl windows_core::imp::CanInto for IReference { const QUERY: bool = true; } impl IReference { diff --git a/crates/libs/core/src/imp/mod.rs b/crates/libs/core/src/imp/mod.rs index 0dfa52d436..84f27d8f87 100644 --- a/crates/libs/core/src/imp/mod.rs +++ b/crates/libs/core/src/imp/mod.rs @@ -1,4 +1,5 @@ mod bindings; +mod can_into; mod com_bindings; mod delay_load; mod factory_cache; @@ -10,6 +11,7 @@ mod waiter; mod weak_ref_count; pub use bindings::*; +pub use can_into::*; pub use com_bindings::*; pub use delay_load::*; pub use factory_cache::*; @@ -34,7 +36,7 @@ pub fn wide_trim_end(mut wide: &[u16]) -> &[u16] { #[macro_export] macro_rules! interface_hierarchy { ($child:ident, $parent:ty) => { - impl ::windows_core::CanInto<$parent> for $child {} + impl ::windows_core::imp::CanInto<$parent> for $child {} impl ::core::convert::From<&$child> for &$parent { fn from(value: &$child) -> Self { unsafe { ::core::mem::transmute(value) } @@ -59,7 +61,7 @@ pub use interface_hierarchy; #[macro_export] macro_rules! required_hierarchy { ($child:ident, $parent:ty) => { - impl ::windows_core::CanInto<$parent> for $child { const QUERY: bool = true; } + impl ::windows_core::imp::CanInto<$parent> for $child { const QUERY: bool = true; } }; ($child:ident, $first:ty, $($rest:ty),+) => { $crate::imp::required_hierarchy!($child, $first); diff --git a/crates/libs/core/src/lib.rs b/crates/libs/core/src/lib.rs index 44d0ef46d2..997ab2afd5 100644 --- a/crates/libs/core/src/lib.rs +++ b/crates/libs/core/src/lib.rs @@ -19,7 +19,11 @@ mod guid; mod handles; mod inspectable; mod interface; +mod out_param; +mod out_ref; mod param; +mod param_value; +mod r#ref; mod runtime_name; mod runtime_type; mod scoped_interface; @@ -37,7 +41,11 @@ pub use guid::*; pub use handles::*; pub use inspectable::*; pub use interface::*; +pub use out_param::*; +pub use out_ref::*; pub use param::*; +pub use param_value::*; +pub use r#ref::*; pub use r#type::*; pub use runtime_name::*; pub use runtime_type::*; diff --git a/crates/libs/core/src/out_param.rs b/crates/libs/core/src/out_param.rs new file mode 100644 index 0000000000..e39093054a --- /dev/null +++ b/crates/libs/core/src/out_param.rs @@ -0,0 +1,56 @@ +use super::*; + +/// Provides automatic parameter conversion in cases where the Windows API expects implicit conversion support. +/// +/// This is a mutable version of [Param] meant to support out parameters. +/// There is no need to implement this trait. Blanket implementations are provided for all applicable Windows types. +pub trait OutParam::TypeKind>: Sized +where + T: Type, +{ + #[doc(hidden)] + unsafe fn borrow_mut(&self) -> OutRef<'_, T>; +} + +impl OutParam for &mut T +where + T: TypeKind + Clone + Default, +{ + unsafe fn borrow_mut(&self) -> OutRef<'_, T> { + let this: &mut T = std::mem::transmute_copy(self); + std::mem::take(this); + std::mem::transmute_copy(self) + } +} + +impl OutParam for &mut T +where + T: TypeKind + Clone + Default, +{ + unsafe fn borrow_mut(&self) -> OutRef<'_, T> { + std::mem::transmute_copy(self) + } +} + +impl OutParam for &mut Option +where + T: TypeKind + Clone, +{ + unsafe fn borrow_mut(&self) -> OutRef<'_, T> { + let this: &mut Option = std::mem::transmute_copy(self); + std::mem::take(this); + std::mem::transmute_copy(self) + } +} + +impl OutParam for Option<&mut T> +where + T: Type, +{ + unsafe fn borrow_mut(&self) -> OutRef<'_, T> { + match self { + Some(this) => std::mem::transmute_copy(this), + None => std::mem::zeroed(), + } + } +} diff --git a/crates/libs/core/src/out_ref.rs b/crates/libs/core/src/out_ref.rs new file mode 100644 index 0000000000..efd0b09d30 --- /dev/null +++ b/crates/libs/core/src/out_ref.rs @@ -0,0 +1,25 @@ +use super::*; + +/// A borrowed type with the same memory layout as the type itself that can be used to construct ABI-compatible function signatures. +/// +/// This is a mutable version of [Ref] meant to support out parameters. +#[repr(transparent)] +pub struct OutRef<'a, T: Type>(*mut T::Abi, std::marker::PhantomData<&'a T>); + +impl<'a, T: Type> OutRef<'a, T> { + /// Returns `true` if the argument is null. + pub fn is_null(&self) -> bool { + self.0.is_null() + } + + /// Overwrites a memory location with the given value without reading or dropping the old value. + pub fn write(self, value: T::Default) -> Result<()> { + if self.0.is_null() { + Err(Error::from_hresult(imp::E_POINTER)) + } else { + unsafe { *self.0 = std::mem::transmute_copy(&value) } + std::mem::forget(value); + Ok(()) + } + } +} diff --git a/crates/libs/core/src/param.rs b/crates/libs/core/src/param.rs index 81272e1ff1..132a1b6bef 100644 --- a/crates/libs/core/src/param.rs +++ b/crates/libs/core/src/param.rs @@ -28,7 +28,7 @@ where T: TypeKind + Clone, T: Interface, U: Interface, - U: CanInto, + U: imp::CanInto, { unsafe fn param(self) -> ParamValue { if U::QUERY { @@ -52,7 +52,7 @@ impl Param for U where T: TypeKind + Clone, U: TypeKind + Clone, - U: CanInto, + U: imp::CanInto, { unsafe fn param(self) -> ParamValue { ParamValue::Owned(std::mem::transmute_copy(&self)) @@ -82,27 +82,3 @@ impl Param for PSTR { ParamValue::Owned(PCSTR(self.0)) } } - -#[doc(hidden)] -pub enum ParamValue> { - Owned(T), - Borrowed(T::Abi), -} - -impl> ParamValue { - pub fn abi(&self) -> T::Abi { - unsafe { - match self { - Self::Owned(item) => std::mem::transmute_copy(item), - Self::Borrowed(borrowed) => std::mem::transmute_copy(borrowed), - } - } - } -} - -#[doc(hidden)] -pub trait CanInto: Sized { - const QUERY: bool = false; -} - -impl CanInto for T where T: Clone {} diff --git a/crates/libs/core/src/param_value.rs b/crates/libs/core/src/param_value.rs new file mode 100644 index 0000000000..cbe2aba6f2 --- /dev/null +++ b/crates/libs/core/src/param_value.rs @@ -0,0 +1,23 @@ +use super::*; + +#[doc(hidden)] +pub enum ParamValue> { + Owned(T), + Borrowed(T::Abi), +} + +impl> ParamValue { + // TODO: replace with `borrow` in windows-bindgen + pub fn abi(&self) -> T::Abi { + unsafe { + match self { + Self::Owned(item) => std::mem::transmute_copy(item), + Self::Borrowed(borrowed) => std::mem::transmute_copy(borrowed), + } + } + } + + pub fn borrow(&self) -> Ref<'_, T> { + unsafe { std::mem::transmute_copy(&self.abi()) } + } +} diff --git a/crates/libs/core/src/ref.rs b/crates/libs/core/src/ref.rs new file mode 100644 index 0000000000..c96a4c47cc --- /dev/null +++ b/crates/libs/core/src/ref.rs @@ -0,0 +1,12 @@ +use super::*; + +/// A borrowed type with the same memory layout as the type itself that can be used to construct ABI-compatible function signatures. +#[repr(transparent)] +pub struct Ref<'a, T: Type>(T::Abi, std::marker::PhantomData<&'a T>); + +impl<'a, T: Type> std::ops::Deref for Ref<'a, T> { + type Target = T::Default; + fn deref(&self) -> &Self::Target { + unsafe { std::mem::transmute(&self.0) } + } +} diff --git a/crates/libs/interface/Cargo.toml b/crates/libs/interface/Cargo.toml index a5785d9a03..301e5ae230 100644 --- a/crates/libs/interface/Cargo.toml +++ b/crates/libs/interface/Cargo.toml @@ -19,6 +19,6 @@ targets = [] proc-macro = true [dependencies] -syn = { version = "2.0", default-features = false, features = ["parsing", "proc-macro", "printing", "full", "derive"] } +syn = { version = "2.0", default-features = false, features = ["parsing", "proc-macro", "printing", "full", "derive", "clone-impls"] } quote = "1.0" proc-macro2 = "1.0" diff --git a/crates/libs/interface/src/lib.rs b/crates/libs/interface/src/lib.rs index 929889f59f..16cefaa277 100644 --- a/crates/libs/interface/src/lib.rs +++ b/crates/libs/interface/src/lib.rs @@ -125,19 +125,22 @@ impl Interface { let vis = &m.visibility; let name = &m.name; - let args = m.gen_args(); - let params = &m - .args - .iter() - .map(|a| { - let pat = &a.pat; - quote! { #pat } - }) - .collect::>(); + let generics = m.gen_consume_generics(); + let params = m.gen_consume_params(); + let args = m.gen_consume_args(); let ret = &m.ret; - quote! { - #vis unsafe fn #name(&self, #(#args),*) #ret { - (::windows_core::Interface::vtable(self).#name)(::windows_core::Interface::as_raw(self), #(#params),*) + + if m.is_result() { + quote! { + #vis unsafe fn #name<#(#generics),*>(&self, #(#params),*) #ret { + (::windows_core::Interface::vtable(self).#name)(::windows_core::Interface::as_raw(self), #(#args),*).ok() + } + } + } else { + quote! { + #vis unsafe fn #name<#(#generics),*>(&self, #(#params),*) #ret { + (::windows_core::Interface::vtable(self).#name)(::windows_core::Interface::as_raw(self), #(#args),*) + } } } }) @@ -190,8 +193,15 @@ impl Interface { let name = &m.name; let ret = &m.ret; let args = m.gen_args(); - quote! { - pub #name: unsafe extern "system" fn(this: *mut ::core::ffi::c_void, #(#args),*) #ret, + + if m.is_result() { + quote! { + pub #name: unsafe extern "system" fn(this: *mut ::core::ffi::c_void, #(#args),*) -> ::windows_core::HRESULT, + } + } else { + quote! { + pub #name: unsafe extern "system" fn(this: *mut ::core::ffi::c_void, #(#args),*) #ret, + } } }) .collect::>(); @@ -214,6 +224,13 @@ impl Interface { }) .collect::>(); let ret = &m.ret; + + let ret = if m.is_result() { + quote! { -> ::windows_core::HRESULT } + } else { + quote! { #ret } + }; + if parent_vtable.is_some() { quote! { unsafe extern "system" fn #name, Impl: #trait_name, const OFFSET: isize>(this: *mut ::core::ffi::c_void, #(#args),*) #ret { @@ -505,6 +522,25 @@ struct InterfaceMethod { } impl InterfaceMethod { + fn is_result(&self) -> bool { + if let syn::ReturnType::Type(_, ty) = &self.ret { + if let syn::Type::Path(path) = &**ty { + if let Some(segment) = path.path.segments.last() { + let ident = segment.ident.to_string(); + if ident == "Result" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if args.args.len() == 1 { + return true; + } + } + } + } + } + } + + false + } + /// Generates arguments (of the form `$pat: $type`) fn gen_args(&self) -> Vec { self.args @@ -516,6 +552,62 @@ impl InterfaceMethod { }) .collect::>() } + + fn gen_consume_generics(&self) -> Vec { + self.args + .iter() + .enumerate() + .filter_map(|(generic_index, a)| { + if let Some((ty, ident)) = a.borrow_type() { + let generic_ident = quote::format_ident!("P{generic_index}"); + if ident == "Ref" { + Some(quote! { #generic_ident: ::windows_core::Param<#ty> }) + } else { + Some(quote! { #generic_ident: ::windows_core::OutParam<#ty> }) + } + } else { + None + } + }) + .collect::>() + } + + fn gen_consume_params(&self) -> Vec { + self.args + .iter() + .enumerate() + .map(|(generic_index, a)| { + let pat = &a.pat; + + if a.borrow_type().is_some() { + let generic_ident = quote::format_ident!("P{generic_index}"); + quote! { #pat: #generic_ident } + } else { + let ty = &a.ty; + quote! { #pat: #ty } + } + }) + .collect::>() + } + + fn gen_consume_args(&self) -> Vec { + self.args + .iter() + .map(|a| { + let pat = &a.pat; + + if let Some((_, ident)) = a.borrow_type() { + if ident == "Ref" { + quote! { #pat.param().borrow() } + } else { + quote! { #pat.borrow_mut() } + } + } else { + quote! { #pat } + } + }) + .collect::>() + } } impl syn::parse::Parse for InterfaceMethod { @@ -554,3 +646,24 @@ struct InterfaceMethodArg { /// The name of the argument pub pat: Box, } + +impl InterfaceMethodArg { + fn borrow_type(&self) -> Option<(syn::Type, String)> { + if let syn::Type::Path(path) = &*self.ty { + if let Some(segment) = path.path.segments.last() { + let ident = segment.ident.to_string(); + if matches!(ident.as_str(), "Ref" | "OutRef") { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if args.args.len() == 1 { + if let Some(syn::GenericArgument::Type(ty)) = args.args.first() { + return Some((ty.clone(), ident)); + } + } + } + } + } + } + + None + } +} diff --git a/crates/libs/windows/src/Windows/Foundation/Collections/mod.rs b/crates/libs/windows/src/Windows/Foundation/Collections/mod.rs index 53bea775cd..da1a6264c6 100644 --- a/crates/libs/windows/src/Windows/Foundation/Collections/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/Collections/mod.rs @@ -9,8 +9,8 @@ impl std::ops::Deref for IIterable { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IIterable {} -impl windows_core::CanInto for IIterable {} +impl windows_core::imp::CanInto for IIterable {} +impl windows_core::imp::CanInto for IIterable {} impl IIterable { pub fn First(&self) -> windows_core::Result> { let this = self; @@ -61,8 +61,8 @@ impl std::ops::Deref for IIterator { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IIterator {} -impl windows_core::CanInto for IIterator {} +impl windows_core::imp::CanInto for IIterator {} +impl windows_core::imp::CanInto for IIterator {} impl IIterator { pub fn Current(&self) -> windows_core::Result { let this = self; @@ -134,8 +134,8 @@ impl windows_core::CanInto for IKeyValuePair {} -impl windows_core::CanInto for IKeyValuePair {} +impl windows_core::imp::CanInto for IKeyValuePair {} +impl windows_core::imp::CanInto for IKeyValuePair {} impl IKeyValuePair { pub fn Key(&self) -> windows_core::Result { let this = self; @@ -183,9 +183,9 @@ impl windows_core::CanInto for IMap {} -impl windows_core::CanInto for IMap {} -impl windows_core::CanInto>> for IMap { +impl windows_core::imp::CanInto for IMap {} +impl windows_core::imp::CanInto for IMap {} +impl windows_core::imp::CanInto>> for IMap { const QUERY: bool = true; } impl IMap { @@ -302,8 +302,8 @@ impl std::ops::Deref for IMapChangedEven unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IMapChangedEventArgs {} -impl windows_core::CanInto for IMapChangedEventArgs {} +impl windows_core::imp::CanInto for IMapChangedEventArgs {} +impl windows_core::imp::CanInto for IMapChangedEventArgs {} impl IMapChangedEventArgs { pub fn CollectionChange(&self) -> windows_core::Result { let this = self; @@ -349,9 +349,9 @@ impl windows_core::CanInto for IMapView {} -impl windows_core::CanInto for IMapView {} -impl windows_core::CanInto>> for IMapView { +impl windows_core::imp::CanInto for IMapView {} +impl windows_core::imp::CanInto for IMapView {} +impl windows_core::imp::CanInto>> for IMapView { const QUERY: bool = true; } impl IMapView { @@ -441,12 +441,12 @@ impl windows_core::CanInto for IObservableMap {} -impl windows_core::CanInto for IObservableMap {} -impl windows_core::CanInto>> for IObservableMap { +impl windows_core::imp::CanInto for IObservableMap {} +impl windows_core::imp::CanInto for IObservableMap {} +impl windows_core::imp::CanInto>> for IObservableMap { const QUERY: bool = true; } -impl windows_core::CanInto> for IObservableMap { +impl windows_core::imp::CanInto> for IObservableMap { const QUERY: bool = true; } impl IObservableMap { @@ -572,12 +572,12 @@ impl std::ops::Deref for IObservableVect unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IObservableVector {} -impl windows_core::CanInto for IObservableVector {} -impl windows_core::CanInto> for IObservableVector { +impl windows_core::imp::CanInto for IObservableVector {} +impl windows_core::imp::CanInto for IObservableVector {} +impl windows_core::imp::CanInto> for IObservableVector { const QUERY: bool = true; } -impl windows_core::CanInto> for IObservableVector { +impl windows_core::imp::CanInto> for IObservableVector { const QUERY: bool = true; } impl IObservableVector { @@ -819,9 +819,9 @@ impl std::ops::Deref for IVector { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IVector {} -impl windows_core::CanInto for IVector {} -impl windows_core::CanInto> for IVector { +impl windows_core::imp::CanInto for IVector {} +impl windows_core::imp::CanInto for IVector {} +impl windows_core::imp::CanInto> for IVector { const QUERY: bool = true; } impl IVector { @@ -1011,9 +1011,9 @@ impl std::ops::Deref for IVectorView unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IVectorView {} -impl windows_core::CanInto for IVectorView {} -impl windows_core::CanInto> for IVectorView { +impl windows_core::imp::CanInto for IVectorView {} +impl windows_core::imp::CanInto for IVectorView {} +impl windows_core::imp::CanInto> for IVectorView { const QUERY: bool = true; } impl IVectorView { diff --git a/crates/libs/windows/src/Windows/Foundation/mod.rs b/crates/libs/windows/src/Windows/Foundation/mod.rs index 88bb44a88f..c2d18ce7ae 100644 --- a/crates/libs/windows/src/Windows/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/mod.rs @@ -116,9 +116,9 @@ impl std::ops::Deref for IAsyncA unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IAsyncActionWithProgress {} -impl windows_core::CanInto for IAsyncActionWithProgress {} -impl windows_core::CanInto for IAsyncActionWithProgress { +impl windows_core::imp::CanInto for IAsyncActionWithProgress {} +impl windows_core::imp::CanInto for IAsyncActionWithProgress {} +impl windows_core::imp::CanInto for IAsyncActionWithProgress { const QUERY: bool = true; } impl IAsyncActionWithProgress { @@ -297,9 +297,9 @@ impl std::ops::Deref for IAsyncOpe unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IAsyncOperation {} -impl windows_core::CanInto for IAsyncOperation {} -impl windows_core::CanInto for IAsyncOperation { +impl windows_core::imp::CanInto for IAsyncOperation {} +impl windows_core::imp::CanInto for IAsyncOperation {} +impl windows_core::imp::CanInto for IAsyncOperation { const QUERY: bool = true; } impl IAsyncOperation { @@ -415,9 +415,9 @@ impl windows_core::CanInto for IAsyncOperationWithProgress {} -impl windows_core::CanInto for IAsyncOperationWithProgress {} -impl windows_core::CanInto for IAsyncOperationWithProgress { +impl windows_core::imp::CanInto for IAsyncOperationWithProgress {} +impl windows_core::imp::CanInto for IAsyncOperationWithProgress {} +impl windows_core::imp::CanInto for IAsyncOperationWithProgress { const QUERY: bool = true; } impl IAsyncOperationWithProgress { @@ -1030,9 +1030,9 @@ impl std::ops::Deref for IReference { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IReference {} -impl windows_core::CanInto for IReference {} -impl windows_core::CanInto for IReference { +impl windows_core::imp::CanInto for IReference {} +impl windows_core::imp::CanInto for IReference {} +impl windows_core::imp::CanInto for IReference { const QUERY: bool = true; } impl IReference { @@ -1287,9 +1287,9 @@ impl std::ops::Deref for IReferenceArray unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto for IReferenceArray {} -impl windows_core::CanInto for IReferenceArray {} -impl windows_core::CanInto for IReferenceArray { +impl windows_core::imp::CanInto for IReferenceArray {} +impl windows_core::imp::CanInto for IReferenceArray {} +impl windows_core::imp::CanInto for IReferenceArray { const QUERY: bool = true; } impl IReferenceArray { diff --git a/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs b/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs index 361e63377a..93e07590ef 100644 --- a/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs @@ -10824,7 +10824,7 @@ impl Default for HINSTANCE { impl windows_core::TypeKind for HINSTANCE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HINSTANCE {} +impl windows_core::imp::CanInto for HINSTANCE {} impl From for HMODULE { fn from(value: HINSTANCE) -> Self { Self(value.0) @@ -10887,7 +10887,7 @@ impl Default for HMODULE { impl windows_core::TypeKind for HMODULE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HMODULE {} +impl windows_core::imp::CanInto for HMODULE {} impl From for HINSTANCE { fn from(value: HMODULE) -> Self { Self(value.0) @@ -10953,7 +10953,7 @@ impl Default for HWND { impl windows_core::TypeKind for HWND { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HWND {} +impl windows_core::imp::CanInto for HWND {} impl From for HANDLE { fn from(value: HWND) -> Self { Self(value.0) diff --git a/crates/libs/windows/src/Windows/Win32/Graphics/Gdi/mod.rs b/crates/libs/windows/src/Windows/Win32/Graphics/Gdi/mod.rs index b3659db718..0e6c7c90b3 100644 --- a/crates/libs/windows/src/Windows/Win32/Graphics/Gdi/mod.rs +++ b/crates/libs/windows/src/Windows/Win32/Graphics/Gdi/mod.rs @@ -10683,7 +10683,7 @@ impl Default for HBITMAP { impl windows_core::TypeKind for HBITMAP { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HBITMAP {} +impl windows_core::imp::CanInto for HBITMAP {} impl From for HGDIOBJ { fn from(value: HBITMAP) -> Self { Self(value.0) @@ -10712,7 +10712,7 @@ impl Default for HBRUSH { impl windows_core::TypeKind for HBRUSH { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HBRUSH {} +impl windows_core::imp::CanInto for HBRUSH {} impl From for HGDIOBJ { fn from(value: HBRUSH) -> Self { Self(value.0) @@ -10780,7 +10780,7 @@ impl Default for HFONT { impl windows_core::TypeKind for HFONT { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HFONT {} +impl windows_core::imp::CanInto for HFONT {} impl From for HGDIOBJ { fn from(value: HFONT) -> Self { Self(value.0) @@ -10864,7 +10864,7 @@ impl Default for HPALETTE { impl windows_core::TypeKind for HPALETTE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HPALETTE {} +impl windows_core::imp::CanInto for HPALETTE {} impl From for HGDIOBJ { fn from(value: HPALETTE) -> Self { Self(value.0) @@ -10893,7 +10893,7 @@ impl Default for HPEN { impl windows_core::TypeKind for HPEN { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HPEN {} +impl windows_core::imp::CanInto for HPEN {} impl From for HGDIOBJ { fn from(value: HPEN) -> Self { Self(value.0) @@ -10922,7 +10922,7 @@ impl Default for HRGN { impl windows_core::TypeKind for HRGN { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HRGN {} +impl windows_core::imp::CanInto for HRGN {} impl From for HGDIOBJ { fn from(value: HRGN) -> Self { Self(value.0) diff --git a/crates/libs/windows/src/Windows/Win32/Security/Cryptography/mod.rs b/crates/libs/windows/src/Windows/Win32/Security/Cryptography/mod.rs index 915d1975ea..b280e86bb5 100644 --- a/crates/libs/windows/src/Windows/Win32/Security/Cryptography/mod.rs +++ b/crates/libs/windows/src/Windows/Win32/Security/Cryptography/mod.rs @@ -8814,7 +8814,7 @@ impl Default for BCRYPT_ALG_HANDLE { impl windows_core::TypeKind for BCRYPT_ALG_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for BCRYPT_ALG_HANDLE {} +impl windows_core::imp::CanInto for BCRYPT_ALG_HANDLE {} impl From for BCRYPT_HANDLE { fn from(value: BCRYPT_ALG_HANDLE) -> Self { Self(value.0) @@ -9209,7 +9209,7 @@ impl Default for BCRYPT_HASH_HANDLE { impl windows_core::TypeKind for BCRYPT_HASH_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for BCRYPT_HASH_HANDLE {} +impl windows_core::imp::CanInto for BCRYPT_HASH_HANDLE {} impl From for BCRYPT_HANDLE { fn from(value: BCRYPT_HASH_HANDLE) -> Self { Self(value.0) @@ -9328,7 +9328,7 @@ impl Default for BCRYPT_KEY_HANDLE { impl windows_core::TypeKind for BCRYPT_KEY_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for BCRYPT_KEY_HANDLE {} +impl windows_core::imp::CanInto for BCRYPT_KEY_HANDLE {} impl From for BCRYPT_HANDLE { fn from(value: BCRYPT_KEY_HANDLE) -> Self { Self(value.0) @@ -9663,7 +9663,7 @@ impl Default for BCRYPT_SECRET_HANDLE { impl windows_core::TypeKind for BCRYPT_SECRET_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for BCRYPT_SECRET_HANDLE {} +impl windows_core::imp::CanInto for BCRYPT_SECRET_HANDLE {} impl From for BCRYPT_HANDLE { fn from(value: BCRYPT_SECRET_HANDLE) -> Self { Self(value.0) @@ -19742,7 +19742,7 @@ impl Default for NCRYPT_KEY_HANDLE { impl windows_core::TypeKind for NCRYPT_KEY_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for NCRYPT_KEY_HANDLE {} +impl windows_core::imp::CanInto for NCRYPT_KEY_HANDLE {} impl From for NCRYPT_HANDLE { fn from(value: NCRYPT_KEY_HANDLE) -> Self { Self(value.0) @@ -19978,7 +19978,7 @@ impl Default for NCRYPT_PROV_HANDLE { impl windows_core::TypeKind for NCRYPT_PROV_HANDLE { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for NCRYPT_PROV_HANDLE {} +impl windows_core::imp::CanInto for NCRYPT_PROV_HANDLE {} impl From for NCRYPT_HANDLE { fn from(value: NCRYPT_PROV_HANDLE) -> Self { Self(value.0) diff --git a/crates/libs/windows/src/Windows/Win32/UI/WindowsAndMessaging/mod.rs b/crates/libs/windows/src/Windows/Win32/UI/WindowsAndMessaging/mod.rs index f88e68af11..ed4fc0d86d 100644 --- a/crates/libs/windows/src/Windows/Win32/UI/WindowsAndMessaging/mod.rs +++ b/crates/libs/windows/src/Windows/Win32/UI/WindowsAndMessaging/mod.rs @@ -8882,7 +8882,7 @@ impl Default for HCURSOR { impl windows_core::TypeKind for HCURSOR { type TypeKind = windows_core::CopyType; } -impl windows_core::CanInto for HCURSOR {} +impl windows_core::imp::CanInto for HCURSOR {} impl From for HICON { fn from(value: HCURSOR) -> Self { Self(value.0) diff --git a/crates/samples/components/json_validator_winrt/src/lib.rs b/crates/samples/components/json_validator_winrt/src/lib.rs index b816b0d478..7c39d6e024 100644 --- a/crates/samples/components/json_validator_winrt/src/lib.rs +++ b/crates/samples/components/json_validator_winrt/src/lib.rs @@ -64,8 +64,8 @@ fn json_from_hstring(value: &HSTRING) -> Result { } #[no_mangle] -extern "system" fn DllGetActivationFactory( - name: std::mem::ManuallyDrop, +unsafe extern "system" fn DllGetActivationFactory( + name: Ref, result: *mut *mut std::ffi::c_void, ) -> HRESULT { if result.is_null() { diff --git a/crates/tests/component/src/lib.rs b/crates/tests/component/src/lib.rs index 08785be5c7..d53dbdb1b8 100644 --- a/crates/tests/component/src/lib.rs +++ b/crates/tests/component/src/lib.rs @@ -68,21 +68,16 @@ impl IActivationFactory_Impl for ClassFactory { } } +// HRESULT __stdcall DllGetActivationFactory(HSTRING, IActivationFactory**) #[no_mangle] unsafe extern "system" fn DllGetActivationFactory( - name: std::mem::ManuallyDrop, - result: *mut *mut std::ffi::c_void, + name: Ref, + factory: OutRef, ) -> HRESULT { - let factory: Option = match (*name).to_string().as_str() { - "test_component.Class" => Some(ClassFactory.into()), - _ => None, - }; - - if let Some(factory) = factory { - *result = factory.into_raw(); - S_OK + if *name == "test_component.Class" { + factory.write(Some(ClassFactory.into())).into() } else { - *result = std::ptr::null_mut(); + _ = factory.write(None); CLASS_E_CLASSNOTAVAILABLE } } diff --git a/crates/tests/interface_core/tests/ref.rs b/crates/tests/interface_core/tests/ref.rs new file mode 100644 index 0000000000..e103dc2351 --- /dev/null +++ b/crates/tests/interface_core/tests/ref.rs @@ -0,0 +1,133 @@ +#![allow(non_snake_case)] + +use windows_core::*; + +pub const S_OK: HRESULT = HRESULT(0); +pub const S_FALSE: HRESULT = HRESULT(1); +pub const E_INVALIDARG: HRESULT = HRESULT(0x80070057_u32 as _); +pub const E_POINTER: HRESULT = HRESULT(0x80004003_u32 as _); + +#[interface("09428a59-5b40-4e4c-9175-e7a78514316d")] +unsafe trait ITest: IUnknown { + // TODO: compile error if param type is not Ref/OutRef and is not Copy + + unsafe fn usize(&self, input: usize, output: OutRef) -> HRESULT; + unsafe fn hstring(&self, input: Ref, output: OutRef) -> HRESULT; + unsafe fn interface(&self, input: Ref, output: OutRef) -> HRESULT; + unsafe fn required_input(&self, input: Ref, output: OutRef) -> HRESULT; + unsafe fn optional_output(&self, input: Ref, output: OutRef) -> HRESULT; + + unsafe fn result_usize(&self, input: usize, output: OutRef) -> Result<()>; + unsafe fn result_hstring(&self, input: Ref, output: OutRef) -> Result<()>; + unsafe fn result_interface(&self, input: Ref, output: OutRef) -> Result<()>; + unsafe fn result_required_input(&self, input: Ref, output: OutRef) -> Result<()>; +} + +#[implement(ITest)] +struct Test; + +impl ITest_Impl for Test { + unsafe fn usize(&self, input: usize, output: OutRef) -> HRESULT { + output.write(input).into() + } + unsafe fn hstring(&self, input: Ref, output: OutRef) -> HRESULT { + output.write(input.clone()).into() + } + unsafe fn interface(&self, input: Ref, output: OutRef) -> HRESULT { + output.write(input.clone()).into() + } + unsafe fn required_input(&self, input: Ref, output: OutRef) -> HRESULT { + if input.is_none() { + E_INVALIDARG + } else { + self.interface(input, output) + } + } + + unsafe fn optional_output(&self, input: Ref, output: OutRef) -> HRESULT { + if output.is_null() { + S_FALSE + } else { + self.interface(input, output) + } + } + + unsafe fn result_usize(&self, input: usize, output: OutRef) -> Result<()> { + output.write(input) + } + unsafe fn result_hstring(&self, input: Ref, output: OutRef) -> Result<()> { + output.write(input.clone()) + } + unsafe fn result_interface(&self, input: Ref, output: OutRef) -> Result<()> { + output.write(input.clone()) + } + unsafe fn result_required_input(&self, input: Ref, output: OutRef) -> Result<()> { + if input.is_none() { + E_INVALIDARG.ok() + } else { + self.result_interface(input, output) + } + } +} + +#[test] +fn test() { + unsafe { + let test: ITest = Test.into(); + + assert_eq!(test.usize(0, None), E_POINTER); + assert_eq!(test.hstring(h!("hello"), None), E_POINTER); + assert_eq!(test.interface(None, None), E_POINTER); + assert_eq!(test.required_input(None, None), E_INVALIDARG); + + let mut output = 0; + assert_eq!(test.usize(123, &mut output), S_OK); + assert_eq!(output, 123); + + let mut output = HSTRING::from("will be dropped"); + // `output` will be dropped before receiving value, avoiding a leak. + assert_eq!(test.hstring(h!("hello"), &mut output), S_OK); + assert_eq!(&output, h!("hello")); + + let mut output = None; + assert_eq!(test.interface(&test, &mut output), S_OK); + assert_eq!(output.as_ref(), Some(&test)); + + // `output` will be dropped before receiving next value, avoiding a leak. + assert_eq!(test.required_input(&test, &mut output), S_OK); + assert_eq!(output.as_ref(), Some(&test)); + + assert_eq!(test.optional_output(&test, None), S_FALSE); + assert_eq!(test.optional_output(&test, &mut output), S_OK); + assert_eq!(output, Some(test)); + } +} + +#[test] +fn test_result() { + unsafe { + let test: ITest = Test.into(); + + assert_eq!(test.result_usize(0, None), E_POINTER.ok()); + assert_eq!(test.result_hstring(h!("hello"), None), E_POINTER.ok()); + assert_eq!(test.result_interface(None, None), E_POINTER.ok()); + assert_eq!(test.result_required_input(None, None), E_INVALIDARG.ok()); + + let mut output = 0; + assert_eq!(test.result_usize(123, &mut output), Ok(())); + assert_eq!(output, 123); + + let mut output = HSTRING::from("will be dropped"); + // `output` will be dropped before receiving value, avoiding a leak. + assert_eq!(test.result_hstring(h!("hello"), &mut output), Ok(())); + assert_eq!(&output, h!("hello")); + + let mut output = None; + assert_eq!(test.result_interface(&test, &mut output), Ok(())); + assert_eq!(output.as_ref(), Some(&test)); + + // `output` will be dropped before receiving next value, avoiding a leak. + assert_eq!(test.result_required_input(&test, &mut output), Ok(())); + assert_eq!(output, Some(test)); + } +} diff --git a/crates/tests/interface_core/tests/result.rs b/crates/tests/interface_core/tests/result.rs new file mode 100644 index 0000000000..c5d8b8edc1 --- /dev/null +++ b/crates/tests/interface_core/tests/result.rs @@ -0,0 +1,44 @@ +#![allow(non_snake_case)] + +use windows_core::*; + +pub const S_OK: HRESULT = HRESULT(0); +pub const S_FALSE: HRESULT = HRESULT(1); +pub const E_INVALIDARG: HRESULT = HRESULT(0x80070057_u32 as _); + +#[interface("09428a59-5b40-4e4c-9175-e7a78514316d")] +unsafe trait ITest: IUnknown { + unsafe fn Void(&self); + unsafe fn Code(&self, code: HRESULT) -> HRESULT; + unsafe fn Result(&self, code: HRESULT) -> Result<()>; +} + +#[implement(ITest)] +struct Test; + +impl ITest_Impl for Test { + unsafe fn Void(&self) {} + unsafe fn Code(&self, code: HRESULT) -> HRESULT { + code + } + unsafe fn Result(&self, code: HRESULT) -> Result<()> { + code.ok() + } +} + +#[test] +fn test() { + unsafe { + let test: ITest = Test.into(); + + test.Void(); + + assert_eq!(test.Code(S_OK), S_OK); + assert_eq!(test.Code(S_FALSE), S_FALSE); + assert_eq!(test.Code(E_INVALIDARG), E_INVALIDARG); + + assert!(test.Result(S_OK).is_ok()); + assert!(test.Result(S_FALSE).is_ok()); + assert_eq!(test.Result(E_INVALIDARG), E_INVALIDARG.ok()); + } +} diff --git a/crates/tests/riddle/src/generic_interfaces.rs b/crates/tests/riddle/src/generic_interfaces.rs index 44bc22ee74..c177c06b74 100644 --- a/crates/tests/riddle/src/generic_interfaces.rs +++ b/crates/tests/riddle/src/generic_interfaces.rs @@ -16,11 +16,11 @@ impl std::ops::Deref for IIterable { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } @@ -75,11 +75,11 @@ impl std::ops::Deref for IIterator { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } @@ -168,11 +168,11 @@ impl - windows_core::CanInto for IKeyValuePair + windows_core::imp::CanInto for IKeyValuePair { } impl - windows_core::CanInto for IKeyValuePair + windows_core::imp::CanInto for IKeyValuePair { } impl @@ -259,15 +259,15 @@ impl - windows_core::CanInto for IMapView + windows_core::imp::CanInto for IMapView { } impl - windows_core::CanInto for IMapView + windows_core::imp::CanInto for IMapView { } impl - windows_core::CanInto>> for IMapView + windows_core::imp::CanInto>> for IMapView { const QUERY: bool = true; } diff --git a/crates/tests/standalone/src/b_calendar.rs b/crates/tests/standalone/src/b_calendar.rs index 51fd0f9079..1dce3a9d37 100644 --- a/crates/tests/standalone/src/b_calendar.rs +++ b/crates/tests/standalone/src/b_calendar.rs @@ -1653,11 +1653,11 @@ impl std::ops::Deref for IIterable { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } @@ -1726,11 +1726,11 @@ impl std::ops::Deref for IIterator { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } @@ -1873,15 +1873,15 @@ impl std::ops::Deref for IVectorView unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IVectorView { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IVectorView { } -impl windows_core::CanInto> +impl windows_core::imp::CanInto> for IVectorView { const QUERY: bool = true; diff --git a/crates/tests/standalone/src/b_uri.rs b/crates/tests/standalone/src/b_uri.rs index 4de4eed61c..2b7f7863e8 100644 --- a/crates/tests/standalone/src/b_uri.rs +++ b/crates/tests/standalone/src/b_uri.rs @@ -16,11 +16,11 @@ impl std::ops::Deref for IIterable { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterable { } @@ -89,11 +89,11 @@ impl std::ops::Deref for IIterator { unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IIterator { } @@ -393,15 +393,15 @@ impl std::ops::Deref for IVectorView unsafe { std::mem::transmute(self) } } } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IVectorView { } -impl windows_core::CanInto +impl windows_core::imp::CanInto for IVectorView { } -impl windows_core::CanInto> +impl windows_core::imp::CanInto> for IVectorView { const QUERY: bool = true;