From a6b49e94f4c6c1781d58e0c3496c8302a3533892 Mon Sep 17 00:00:00 2001 From: Kenny Kerr Date: Thu, 1 Aug 2024 09:52:11 -0700 Subject: [PATCH] Add precise registry types and allocation-free queries and updates (#3184) --- crates/libs/registry/src/data.rs | 129 +++++++ crates/libs/registry/src/key.rs | 340 ++++++------------ crates/libs/registry/src/key_iterator.rs | 1 - crates/libs/registry/src/lib.rs | 33 +- crates/libs/registry/src/pcwstr.rs | 41 +++ crates/libs/registry/src/type.rs | 43 ++- crates/libs/registry/src/value.rs | 126 ++++++- crates/libs/registry/src/value_iterator.rs | 99 +---- crates/libs/strings/src/hstring.rs | 5 + crates/libs/strings/src/hstring_builder.rs | 11 + crates/libs/strings/src/pcwstr.rs | 6 + crates/tests/registry/tests/bad_string.rs | 24 +- crates/tests/registry/tests/bytes.rs | 24 ++ .../tests/{sys_interop.rs => interop.rs} | 45 ++- crates/tests/registry/tests/raw.rs | 35 ++ crates/tests/registry/tests/string.rs | 35 ++ crates/tests/registry/tests/u32.rs | 15 + crates/tests/registry/tests/u64.rs | 15 + crates/tests/registry/tests/value.rs | 45 +++ crates/tests/registry/tests/values.rs | 52 +-- .../tests/registry/tests/windows_interop.rs | 40 --- 21 files changed, 714 insertions(+), 450 deletions(-) create mode 100644 crates/libs/registry/src/data.rs create mode 100644 crates/libs/registry/src/pcwstr.rs create mode 100644 crates/tests/registry/tests/bytes.rs rename crates/tests/registry/tests/{sys_interop.rs => interop.rs} (50%) create mode 100644 crates/tests/registry/tests/raw.rs create mode 100644 crates/tests/registry/tests/string.rs create mode 100644 crates/tests/registry/tests/u32.rs create mode 100644 crates/tests/registry/tests/u64.rs create mode 100644 crates/tests/registry/tests/value.rs delete mode 100644 crates/tests/registry/tests/windows_interop.rs diff --git a/crates/libs/registry/src/data.rs b/crates/libs/registry/src/data.rs new file mode 100644 index 0000000000..fa1660f391 --- /dev/null +++ b/crates/libs/registry/src/data.rs @@ -0,0 +1,129 @@ +use super::*; + +// Minimal `Vec` replacement providing at least `u16` alignment so that it can be used for wide strings. +pub struct Data { + ptr: *mut u8, + len: usize, +} + +impl Data { + // Creates a buffer with the specified length of zero bytes. + pub fn new(len: usize) -> Result { + unsafe { + let bytes = Self::alloc(len)?; + + if len > 0 { + core::ptr::write_bytes(bytes.ptr, 0, len); + } + + Ok(bytes) + } + } + + // Returns the buffer as a slice of u16 for reading wide characters. The slice trims off any trailing zero bytes. + pub fn as_wide(&self) -> &[u16] { + if self.ptr.is_null() { + &[] + } else { + let mut wide = + unsafe { core::slice::from_raw_parts(self.ptr as *const u16, self.len / 2) }; + + while wide.last() == Some(&0) { + wide = &wide[..wide.len() - 1]; + } + + wide + } + } + + // Creates a buffer by copying the bytes from the slice. + pub fn from_slice(slice: &[u8]) -> Result { + unsafe { + let bytes = Self::alloc(slice.len())?; + + if !slice.is_empty() { + core::ptr::copy_nonoverlapping(slice.as_ptr(), bytes.ptr, slice.len()); + } + + Ok(bytes) + } + } + + // Allocates an uninitialized buffer. + unsafe fn alloc(len: usize) -> Result { + if len == 0 { + Ok(Self { + ptr: null_mut(), + len: 0, + }) + } else { + // This pointer will have at least 8 byte alignment. + let ptr = HeapAlloc(GetProcessHeap(), 0, len) as *mut u8; + + if ptr.is_null() { + Err(Error::from_hresult(HRESULT::from_win32(ERROR_OUTOFMEMORY))) + } else { + Ok(Self { ptr, len }) + } + } + } +} + +impl Drop for Data { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { + HeapFree(GetProcessHeap(), 0, self.ptr as *mut _); + } + } + } +} + +impl Deref for Data { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + if self.ptr.is_null() { + &[] + } else { + unsafe { core::slice::from_raw_parts(self.ptr, self.len) } + } + } +} + +impl core::ops::DerefMut for Data { + fn deref_mut(&mut self) -> &mut [u8] { + if self.ptr.is_null() { + &mut [] + } else { + unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) } + } + } +} + +impl Clone for Data { + fn clone(&self) -> Self { + Self::from_slice(self).unwrap() + } +} + +impl PartialEq for Data { + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl Eq for Data {} + +impl core::fmt::Debug for Data { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.deref().fmt(f) + } +} + +impl TryFrom<[u8; N]> for Data { + type Error = Error; + fn try_from(from: [u8; N]) -> Result { + Self::from_slice(&from) + } +} diff --git a/crates/libs/registry/src/key.rs b/crates/libs/registry/src/key.rs index 59eb84f211..f6b5191951 100644 --- a/crates/libs/registry/src/key.rs +++ b/crates/libs/registry/src/key.rs @@ -1,5 +1,4 @@ use super::*; -use core::ptr::{null, null_mut}; /// A registry key. #[repr(transparent)] @@ -8,14 +7,14 @@ pub struct Key(pub(crate) HKEY); impl Default for Key { fn default() -> Self { - Self(core::ptr::null_mut()) + Self(null_mut()) } } impl Key { /// Creates a registry key. If the key already exists, the function opens it. pub fn create>(&self, path: T) -> Result { - let mut handle = core::ptr::null_mut(); + let mut handle = null_mut(); let result = unsafe { RegCreateKeyExW( @@ -36,7 +35,7 @@ impl Key { /// Opens a registry key. pub fn open>(&self, path: T) -> Result { - let mut handle = core::ptr::null_mut(); + let mut handle = null_mut(); let result = unsafe { RegOpenKeyExW(self.0, pcwstr(path).as_ptr(), 0, KEY_READ, &mut handle) }; @@ -83,19 +82,17 @@ impl Key { /// Sets the name and value in the registry key. pub fn set_u32>(&self, name: T, value: u32) -> Result<()> { - unsafe { self.set_value(name, REG_DWORD, &value as *const _ as _, 4) } + self.set_bytes(name, Type::U32, &value.to_le_bytes()) } /// Sets the name and value in the registry key. pub fn set_u64>(&self, name: T, value: u64) -> Result<()> { - unsafe { self.set_value(name, REG_QWORD, &value as *const _ as _, 8) } + self.set_bytes(name, Type::U64, &value.to_le_bytes()) } /// Sets the name and value in the registry key. pub fn set_string>(&self, name: T, value: T) -> Result<()> { - let value = pcwstr(value); - - unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) } + self.set_bytes(name, Type::String, pcwstr(value).as_bytes()) } /// Sets the name and value in the registry key. @@ -104,275 +101,166 @@ impl Key { name: T, value: &windows_strings::HSTRING, ) -> Result<()> { - unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) } + self.set_bytes(name, Type::String, value.as_bytes()) } /// Sets the name and value in the registry key. - pub fn set_multi_string>(&self, name: T, value: &[T]) -> Result<()> { - let mut packed = value.iter().fold(vec![0u16; 0], |mut packed, value| { - packed.append(&mut pcwstr(value)); - packed - }); + pub fn set_expand_string>(&self, name: T, value: T) -> Result<()> { + self.set_bytes(name, Type::ExpandString, pcwstr(value).as_bytes()) + } - packed.push(0); + /// Sets the name and value in the registry key. + pub fn set_expand_hstring>( + &self, + name: T, + value: &windows_strings::HSTRING, + ) -> Result<()> { + self.set_bytes(name, Type::ExpandString, value.as_bytes()) + } + + /// Sets the name and value in the registry key. + pub fn set_multi_string>(&self, name: T, value: &[T]) -> Result<()> { + let value = multi_pcwstr(value); + self.set_bytes(name, Type::MultiString, value.as_bytes()) + } - unsafe { self.set_value(name, REG_MULTI_SZ, packed.as_ptr() as _, packed.len() * 2) } + /// Sets the name and value in the registry key. + pub fn set_value>(&self, name: T, value: &Value) -> Result<()> { + self.set_bytes(name, value.ty(), value) } /// Sets the name and value in the registry key. - pub fn set_bytes>(&self, name: T, value: &[u8]) -> Result<()> { - unsafe { self.set_value(name, REG_BINARY, value.as_ptr() as _, value.len()) } + pub fn set_bytes>(&self, name: T, ty: Type, value: &[u8]) -> Result<()> { + unsafe { self.raw_set_bytes(pcwstr(name).as_raw(), ty, value) } } /// Gets the type for the name in the registry key. pub fn get_type>(&self, name: T) -> Result { - let name = pcwstr(name); - let mut ty = 0; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - &mut ty, - null_mut(), - null_mut(), - ) - }; - - win32_error(result)?; - - Ok(match ty { - REG_DWORD => Type::U32, - REG_QWORD => Type::U64, - REG_BINARY => Type::Bytes, - REG_SZ | REG_EXPAND_SZ => Type::String, - REG_MULTI_SZ => Type::MultiString, - rest => Type::Unknown(rest), - }) + let (ty, _) = unsafe { self.raw_get_info(pcwstr(name).as_raw())? }; + Ok(ty) } /// Gets the value for the name in the registry key. pub fn get_value>(&self, name: T) -> Result { let name = pcwstr(name); - let mut ty = 0; - let mut len = 0; - - let result = unsafe { - RegQueryValueExW(self.0, name.as_ptr(), null(), &mut ty, null_mut(), &mut len) - }; - - win32_error(result)?; - - match ty { - REG_DWORD if len == 4 => { - let mut value = 0u32; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - &mut value as *mut _ as _, - &mut len, - ) - }; - - win32_error(result)?; - Ok(Value::U32(value)) - } - REG_QWORD if len == 8 => { - let mut value = 0u64; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - &mut value as *mut _ as _, - &mut len, - ) - }; - - win32_error(result)?; - Ok(Value::U64(value)) - } - REG_SZ | REG_EXPAND_SZ => { - let mut value = vec![0u16; len as usize / 2]; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - value.as_mut_ptr() as _, - &mut len, - ) - }; - - win32_error(result)?; - Ok(Value::String(String::from_utf16_lossy(trim(&value)))) - } - REG_MULTI_SZ => { - let mut value = vec![0u16; len as usize / 2]; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - value.as_mut_ptr() as _, - &mut len, - ) - }; - - win32_error(result)?; - - Ok(Value::MultiString( - trim(&value) - .split(|c| *c == 0) - .map(String::from_utf16_lossy) - .collect(), - )) - } - REG_BINARY => { - let mut value = vec![0u8; len as usize]; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - value.as_mut_ptr() as _, - &mut len, - ) - }; - - win32_error(result)?; - Ok(Value::Bytes(value)) - } - _ => Err(invalid_data()), - } + let (ty, len) = unsafe { self.raw_get_info(name.as_raw())? }; + let mut data = Data::new(len)?; + unsafe { self.raw_get_bytes(name.as_raw(), &mut data)? }; + Ok(Value { data, ty }) } /// Gets the value for the name in the registry key. pub fn get_u32>(&self, name: T) -> Result { - if let Value::U32(value) = self.get_value(name)? { - Ok(value) - } else { - Err(invalid_data()) - } + Ok(self.get_u64(name)?.try_into()?) } /// Gets the value for the name in the registry key. pub fn get_u64>(&self, name: T) -> Result { - if let Value::U64(value) = self.get_value(name)? { - Ok(value) - } else { - Err(invalid_data()) - } + let value = &mut [0; 8]; + let (ty, value) = unsafe { self.raw_get_bytes(pcwstr(name).as_raw(), value)? }; + from_le_bytes(ty, value) } /// Gets the value for the name in the registry key. pub fn get_string>(&self, name: T) -> Result { - if let Value::String(value) = self.get_value(name)? { - Ok(value) - } else { - Err(invalid_data()) - } + self.get_value(name)?.try_into() } /// Gets the value for the name in the registry key. pub fn get_hstring>(&self, name: T) -> Result { let name = pcwstr(name); - let mut ty = 0; - let mut len = 0; - - let result = unsafe { - RegQueryValueExW(self.0, name.as_ptr(), null(), &mut ty, null_mut(), &mut len) - }; - - win32_error(result)?; + let (ty, len) = unsafe { self.raw_get_info(name.as_raw())? }; - if !matches!(ty, REG_SZ | REG_EXPAND_SZ) { + if !matches!(ty, Type::String | Type::ExpandString) { return Err(invalid_data()); } - let mut value = HStringBuilder::new(len as usize / 2)?; - - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - value.as_mut_ptr() as _, - &mut len, - ) - }; - - win32_error(result)?; + let mut value = HStringBuilder::new(len / 2)?; + unsafe { self.raw_get_bytes(name.as_raw(), value.as_bytes_mut())? }; value.trim_end(); Ok(value.into()) } /// Gets the value for the name in the registry key. - pub fn get_bytes>(&self, name: T) -> Result> { - let name = pcwstr(name); - let mut len = 0; + pub fn get_multi_string>(&self, name: T) -> Result> { + self.get_value(name)?.try_into() + } - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - null_mut(), - &mut len, - ) - }; + /// Sets the name and value in the registry key. + /// + /// This method avoids any allocations. + /// + /// # Safety + /// + /// The `PCWSTR` pointer needs to be valid for reads up until and including the next `\0`. + pub unsafe fn raw_set_bytes>( + &self, + name: N, + ty: Type, + value: &[u8], + ) -> Result<()> { + let result = RegSetValueExW( + self.0, + name.as_ref().as_ptr(), + 0, + ty.into(), + value.as_ptr(), + value.len().try_into()?, + ); - win32_error(result)?; - let mut value = vec![0u8; len as usize]; + win32_error(result) + } - let result = unsafe { - RegQueryValueExW( - self.0, - name.as_ptr(), - null(), - null_mut(), - value.as_mut_ptr() as _, - &mut len, - ) - }; + /// Gets the type and length for the name in the registry key. + /// + /// This method avoids any allocations. + /// + /// # Safety + /// + /// The `PCWSTR` pointer needs to be valid for reads up until and including the next `\0`. + pub unsafe fn raw_get_info>(&self, name: N) -> Result<(Type, usize)> { + let mut ty = 0; + let mut len = 0; - win32_error(result).map(|_| value) - } + let result = RegQueryValueExW( + self.0, + name.as_ref().as_ptr(), + null(), + &mut ty, + core::ptr::null_mut(), + &mut len, + ); - /// Gets the value for the name in the registry key. - pub fn get_multi_string>(&self, name: T) -> Result> { - if let Value::MultiString(value) = self.get_value(name)? { - Ok(value) - } else { - Err(invalid_data()) - } + win32_error(result)?; + Ok((ty.into(), len as usize)) } - unsafe fn set_value>( + /// Gets the value for the name in the registry key. + /// + /// This method avoids any allocations. + /// + /// # Safety + /// + /// The `PCWSTR` pointer needs to be valid for reads up until and including the next `\0`. + pub unsafe fn raw_get_bytes<'a, N: AsRef>( &self, - name: T, - ty: REG_VALUE_TYPE, - ptr: *const u8, - len: usize, - ) -> Result<()> { - let result = RegSetValueExW(self.0, pcwstr(name).as_ptr(), 0, ty, ptr, len.try_into()?); + name: N, + value: &'a mut [u8], + ) -> Result<(Type, &'a [u8])> { + let mut ty = 0; + let mut len = value.len().try_into()?; - win32_error(result) + let result = RegQueryValueExW( + self.0, + name.as_ref().as_ptr(), + null(), + &mut ty, + value.as_mut_ptr(), + &mut len, + ); + + win32_error(result)?; + Ok((ty.into(), value.get(0..len as usize).unwrap())) } } diff --git a/crates/libs/registry/src/key_iterator.rs b/crates/libs/registry/src/key_iterator.rs index 70d728e2b8..aca76c5ff1 100644 --- a/crates/libs/registry/src/key_iterator.rs +++ b/crates/libs/registry/src/key_iterator.rs @@ -1,5 +1,4 @@ use super::*; -use core::ptr::{null, null_mut}; /// An iterator of registry key names. pub struct KeyIterator<'a> { diff --git a/crates/libs/registry/src/lib.rs b/crates/libs/registry/src/lib.rs index a5e479e356..6d04f58022 100644 --- a/crates/libs/registry/src/lib.rs +++ b/crates/libs/registry/src/lib.rs @@ -9,6 +9,8 @@ Learn more about Rust for Windows here: >(value: T) -> Vec { - value - .as_ref() - .encode_utf16() - .chain(core::iter::once(0)) - .collect() -} - -fn trim(mut value: &[u16]) -> &[u16] { - while value.last() == Some(&0) { - value = &value[..value.len() - 1]; - } - value -} - fn win32_error(result: u32) -> Result<()> { if result == 0 { Ok(()) @@ -75,3 +68,11 @@ fn win32_error(result: u32) -> Result<()> { fn invalid_data() -> Error { Error::from_hresult(HRESULT::from_win32(ERROR_INVALID_DATA)) } + +fn from_le_bytes(ty: Type, from: &[u8]) -> Result { + match ty { + Type::U32 if from.len() == 4 => Ok(u32::from_le_bytes(from.try_into().unwrap()).into()), + Type::U64 if from.len() == 8 => Ok(u64::from_le_bytes(from.try_into().unwrap())), + _ => Err(invalid_data()), + } +} diff --git a/crates/libs/registry/src/pcwstr.rs b/crates/libs/registry/src/pcwstr.rs new file mode 100644 index 0000000000..10fb656e3f --- /dev/null +++ b/crates/libs/registry/src/pcwstr.rs @@ -0,0 +1,41 @@ +use super::*; + +pub struct OwnedPcwstr(Vec); + +pub fn pcwstr>(value: T) -> OwnedPcwstr { + OwnedPcwstr( + value + .as_ref() + .encode_utf16() + .chain(core::iter::once(0)) + .collect(), + ) +} + +pub fn multi_pcwstr>(value: &[T]) -> OwnedPcwstr { + OwnedPcwstr( + value + .iter() + .flat_map(|value| value.as_ref().encode_utf16().chain(core::iter::once(0))) + .chain(core::iter::once(0)) + .collect(), + ) +} + +impl OwnedPcwstr { + pub fn as_ptr(&self) -> *const u16 { + debug_assert!( + self.0.last() == Some(&0), + "`OwnedPcwstr` isn't null-terminated" + ); + self.0.as_ptr() + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.0.len() * 2) } + } + + pub fn as_raw(&self) -> PCWSTR { + PCWSTR(self.as_ptr()) + } +} diff --git a/crates/libs/registry/src/type.rs b/crates/libs/registry/src/type.rs index 050e59b11b..7f665db41b 100644 --- a/crates/libs/registry/src/type.rs +++ b/crates/libs/registry/src/type.rs @@ -1,5 +1,7 @@ +use super::*; + /// The possible types that a registry value could have. -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum Type { /// A 32-bit unsigned integer value. U32, @@ -10,12 +12,43 @@ pub enum Type { /// A string value. String, - /// An array u8 bytes. - Bytes, + /// A string value that may contain unexpanded environment variables. + ExpandString, /// An array of string values. MultiString, - /// An unknown or unsupported type. - Unknown(u32), + /// An array u8 bytes. + Bytes, + + /// An unknown type. + Other(u32), +} + +impl From for Type { + fn from(ty: u32) -> Self { + match ty { + REG_DWORD => Self::U32, + REG_QWORD => Self::U64, + REG_SZ => Self::String, + REG_EXPAND_SZ => Self::ExpandString, + REG_MULTI_SZ => Self::MultiString, + REG_BINARY => Self::Bytes, + rest => Self::Other(rest), + } + } +} + +impl From for u32 { + fn from(ty: Type) -> Self { + match ty { + Type::U32 => REG_DWORD, + Type::U64 => REG_QWORD, + Type::String => REG_SZ, + Type::ExpandString => REG_EXPAND_SZ, + Type::MultiString => REG_MULTI_SZ, + Type::Bytes => REG_BINARY, + Type::Other(other) => other, + } + } } diff --git a/crates/libs/registry/src/value.rs b/crates/libs/registry/src/value.rs index eb196d2463..8d200358c7 100644 --- a/crates/libs/registry/src/value.rs +++ b/crates/libs/registry/src/value.rs @@ -2,22 +2,122 @@ use super::*; /// A registry value. #[derive(Clone, PartialEq, Eq, Debug)] -pub enum Value { - /// A 32-bit unsigned integer value. - U32(u32), +pub struct Value { + pub(crate) data: Data, + pub(crate) ty: Type, +} + +impl Value { + /// Gets the type of the registry value. + pub fn ty(&self) -> Type { + self.ty + } + + /// Sets the type of the registry value. This does not change the value. + pub fn set_ty(&mut self, ty: Type) { + self.ty = ty; + } +} - /// A 64-bit unsigned integer value. - U64(u64), +impl core::ops::Deref for Value { + type Target = [u8]; - /// A string value. - String(String), + fn deref(&self) -> &[u8] { + &self.data + } +} - /// An array u8 bytes. - Bytes(Vec), +impl AsRef<[u8]> for Value { + fn as_ref(&self) -> &[u8] { + &self.data + } +} - /// An array of string values. - MultiString(Vec), +impl TryFrom for Value { + type Error = Error; + fn try_from(from: u32) -> Result { + Ok(Self { + data: from.to_le_bytes().try_into()?, + ty: Type::U32, + }) + } +} + +impl TryFrom for u32 { + type Error = Error; + fn try_from(from: Value) -> Result { + Ok(from_le_bytes(from.ty, &from)?.try_into()?) + } +} + +impl TryFrom for Value { + type Error = Error; + fn try_from(from: u64) -> Result { + Ok(Self { + data: from.to_le_bytes().try_into()?, + ty: Type::U64, + }) + } +} + +impl TryFrom for u64 { + type Error = Error; + fn try_from(from: Value) -> Result { + from_le_bytes(from.ty, &from) + } +} + +impl TryFrom for String { + type Error = Error; + fn try_from(from: Value) -> Result { + match from.ty { + Type::String | Type::ExpandString => Ok(Self::from_utf16(from.data.as_wide())?), + _ => Err(invalid_data()), + } + } +} + +impl TryFrom<&str> for Value { + type Error = Error; + fn try_from(from: &str) -> Result { + Ok(Self { + data: Data::from_slice(pcwstr(from).as_bytes())?, + ty: Type::String, + }) + } +} + +impl TryFrom for Vec { + type Error = Error; + fn try_from(from: Value) -> Result { + match from.ty { + Type::MultiString => Ok(from + .data + .as_wide() + .split(|c| *c == 0) + .map(String::from_utf16_lossy) + .collect()), + _ => Ok(vec![String::try_from(from)?]), + } + } +} + +impl TryFrom<&[u8]> for Value { + type Error = Error; + fn try_from(from: &[u8]) -> Result { + Ok(Self { + data: Data::from_slice(from)?, + ty: Type::Bytes, + }) + } +} - /// An unknown or unsupported type. - Unknown(u32), +impl TryFrom<[u8; N]> for Value { + type Error = Error; + fn try_from(from: [u8; N]) -> Result { + Ok(Self { + data: Data::from_slice(&from)?, + ty: Type::Bytes, + }) + } } diff --git a/crates/libs/registry/src/value_iterator.rs b/crates/libs/registry/src/value_iterator.rs index 429ca90af3..f66f9724c3 100644 --- a/crates/libs/registry/src/value_iterator.rs +++ b/crates/libs/registry/src/value_iterator.rs @@ -1,12 +1,11 @@ use super::*; -use core::ptr::null_mut; /// An iterator of registry values. pub struct ValueIterator<'a> { key: &'a Key, range: core::ops::Range, name: Vec, - value: ValueBytes, + data: Data, } impl<'a> ValueIterator<'a> { @@ -38,7 +37,7 @@ impl<'a> ValueIterator<'a> { key, range: 0..count as usize, name: vec![0; name_max_len as usize + 1], - value: ValueBytes::new(value_max_len as usize)?, + data: Data::new(value_max_len as usize)?, }) } } @@ -50,7 +49,7 @@ impl<'a> Iterator for ValueIterator<'a> { self.range.next().and_then(|index| { let mut ty = 0; let mut name_len = self.name.len() as u32; - let mut value_len = self.value.len() as u32; + let mut data_len = self.data.len() as u32; let result = unsafe { RegEnumValueW( @@ -60,8 +59,8 @@ impl<'a> Iterator for ValueIterator<'a> { &mut name_len, core::ptr::null(), &mut ty, - self.value.as_mut_ptr(), - &mut value_len, + self.data.as_mut_ptr(), + &mut data_len, ) }; @@ -69,89 +68,15 @@ impl<'a> Iterator for ValueIterator<'a> { debug_assert_eq!(result, ERROR_NO_MORE_ITEMS); None } else { - let value = match ty { - REG_DWORD if value_len == 4 => { - Value::U32(u32::from_le_bytes(self.value[0..4].try_into().unwrap())) - } - REG_QWORD if value_len == 8 => { - Value::U64(u64::from_le_bytes(self.value[0..8].try_into().unwrap())) - } - REG_BINARY => Value::Bytes(self.value[0..value_len as usize].to_vec()), - REG_SZ | REG_EXPAND_SZ => { - if value_len == 0 { - Value::String(String::new()) - } else { - let value = unsafe { - core::slice::from_raw_parts( - self.value.as_ptr() as *const u16, - value_len as usize / 2, - ) - }; - - Value::String(String::from_utf16_lossy(trim(value))) - } - } - REG_MULTI_SZ => { - if value_len == 0 { - Value::MultiString(vec![]) - } else { - let value = unsafe { - core::slice::from_raw_parts( - self.value.as_ptr() as *const u16, - value_len as usize / 2, - ) - }; - - Value::MultiString( - trim(value) - .split(|c| *c == 0) - .map(String::from_utf16_lossy) - .collect(), - ) - } - } - rest => Value::Unknown(rest), - }; - let name = String::from_utf16_lossy(&self.name[0..name_len as usize]); - Some((name, value)) + Some(( + name, + Value { + data: Data::from_slice(&self.data[0..data_len as usize]).unwrap(), + ty: ty.into(), + }, + )) } }) } } - -// Minimal `Vec` replacement providing `u16` alignment. -struct ValueBytes(*mut core::ffi::c_void, usize); - -impl ValueBytes { - fn new(len: usize) -> Result { - // This pointer will have at least 8 byte alignment. - let ptr = unsafe { HeapAlloc(GetProcessHeap(), 0, len) }; - - if ptr.is_null() { - Err(Error::from_hresult(HRESULT::from_win32(ERROR_OUTOFMEMORY))) - } else { - Ok(Self(ptr, len)) - } - } - - fn as_mut_ptr(&mut self) -> *mut u8 { - self.0 as *mut u8 - } -} - -impl Drop for ValueBytes { - fn drop(&mut self) { - unsafe { - HeapFree(GetProcessHeap(), 0, self.0); - }; - } -} - -impl core::ops::Deref for ValueBytes { - type Target = [u8]; - - fn deref(&self) -> &[u8] { - unsafe { core::slice::from_raw_parts(self.0 as *const u8, self.1) } - } -} diff --git a/crates/libs/strings/src/hstring.rs b/crates/libs/strings/src/hstring.rs index 6cd0a1ebc4..c6511ef27e 100644 --- a/crates/libs/strings/src/hstring.rs +++ b/crates/libs/strings/src/hstring.rs @@ -33,6 +33,11 @@ impl HSTRING { unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) } } + /// Get the string as 8-bit bytes. + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.len() * 2) } + } + /// Returns a raw pointer to the `HSTRING` buffer. pub fn as_ptr(&self) -> *const u16 { if let Some(header) = self.as_header() { diff --git a/crates/libs/strings/src/hstring_builder.rs b/crates/libs/strings/src/hstring_builder.rs index 23679bb36a..7b3d576cee 100644 --- a/crates/libs/strings/src/hstring_builder.rs +++ b/crates/libs/strings/src/hstring_builder.rs @@ -36,6 +36,17 @@ impl HStringBuilder { } } + /// Allows the `HSTRING` to be constructed from bytes. + pub fn as_bytes_mut(&mut self) -> &mut [u8] { + if let Some(header) = self.as_header() { + unsafe { + core::slice::from_raw_parts_mut(header.data as *mut _, header.len as usize * 2) + } + } else { + &mut [] + } + } + fn as_header(&self) -> Option<&HStringHeader> { unsafe { self.0.as_ref() } } diff --git a/crates/libs/strings/src/pcwstr.rs b/crates/libs/strings/src/pcwstr.rs index 687413d394..63196e42b2 100644 --- a/crates/libs/strings/src/pcwstr.rs +++ b/crates/libs/strings/src/pcwstr.rs @@ -83,3 +83,9 @@ impl PCWSTR { Decode(move || core::char::decode_utf16(self.as_wide().iter().cloned())) } } + +impl AsRef for PCWSTR { + fn as_ref(&self) -> &Self { + self + } +} diff --git a/crates/tests/registry/tests/bad_string.rs b/crates/tests/registry/tests/bad_string.rs index aeaeebafcf..2cd4cb3d0b 100644 --- a/crates/tests/registry/tests/bad_string.rs +++ b/crates/tests/registry/tests/bad_string.rs @@ -3,16 +3,17 @@ use windows_registry::*; #[test] fn bad_string() -> Result<()> { - let bad_string_bytes = vec![ + let test_key = "software\\windows-rs\\tests\\bad_string"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + // Test value taken from https://github.com/rust-lang/rustup/blob/master/tests/suite/cli_paths.rs + let bad_string_bytes = [ 0x00, 0xD8, // leading surrogate 0x01, 0x01, // bogus trailing surrogate 0x00, 0x00, // null ]; - let test_key = "software\\windows-rs\\tests\\bad_string"; - _ = CURRENT_USER.remove_tree(test_key); - let key = CURRENT_USER.create(test_key)?; - unsafe { RegSetValueExW( HKEY(key.as_raw()), @@ -21,20 +22,17 @@ fn bad_string() -> Result<()> { REG_SZ, Some(&bad_string_bytes), ) - .ok()? - }; + .ok()?; + } let ty = key.get_type("name")?; assert_eq!(ty, Type::String); - let value_as_string = key.get_string("name")?; - assert_eq!(value_as_string, "�ā"); - - let value_as_bytes = key.get_bytes("name")?; - assert_eq!(value_as_bytes, bad_string_bytes); - let value_as_hstring = key.get_hstring("name")?; assert_eq!(value_as_hstring.to_string_lossy(), "�ā"); + let value = key.get_value("name")?; + assert_eq!(*value, bad_string_bytes); + Ok(()) } diff --git a/crates/tests/registry/tests/bytes.rs b/crates/tests/registry/tests/bytes.rs new file mode 100644 index 0000000000..dd10c4b3d4 --- /dev/null +++ b/crates/tests/registry/tests/bytes.rs @@ -0,0 +1,24 @@ +use windows_registry::*; + +#[test] +fn bytes() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\bytes"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_bytes("bytes", Type::Bytes, &[1, 2, 3])?; + assert_eq!(key.get_type("bytes")?, Type::Bytes); + + let value = key.get_value("bytes")?; + assert_eq!(value.ty(), Type::Bytes); + assert_eq!(*value, [1, 2, 3]); + + key.set_bytes("other", Type::Other(1234), &[1, 2, 3, 4])?; + assert_eq!(key.get_type("other")?, Type::Other(1234)); + + let value = key.get_value("other")?; + assert_eq!(value.ty(), Type::Other(1234)); + assert_eq!(*value, [1, 2, 3, 4]); + + Ok(()) +} diff --git a/crates/tests/registry/tests/sys_interop.rs b/crates/tests/registry/tests/interop.rs similarity index 50% rename from crates/tests/registry/tests/sys_interop.rs rename to crates/tests/registry/tests/interop.rs index 52c10b1ae9..b9507567d3 100644 --- a/crates/tests/registry/tests/sys_interop.rs +++ b/crates/tests/registry/tests/interop.rs @@ -1,12 +1,13 @@ use windows_registry::*; -use windows_sys::Win32::System::Registry::*; #[test] fn sys_interop() -> Result<()> { + use windows_sys::Win32::System::Registry::*; + let test_key = "software\\windows-rs\\tests\\sys_interop"; _ = CURRENT_USER.remove_tree(test_key); - let key = CURRENT_USER.create(test_key)?; + key.set_u32("1", 1)?; key.set_u32("2", 2)?; key.set_u32("3", 3)?; @@ -37,3 +38,43 @@ fn sys_interop() -> Result<()> { assert_eq!(count, 3); Ok(()) } + +#[test] +fn windows_interop() -> Result<()> { + use windows::{core::*, Win32::System::Registry::*}; + + let test_key = "software\\windows-rs\\tests\\windows_interop"; + _ = CURRENT_USER.remove_tree(test_key); + + let key = CURRENT_USER.create(test_key)?; + key.set_u32("1", 1)?; + key.set_u32("2", 2)?; + key.set_u32("3", 3)?; + + let raw = HKEY(key.as_raw()); + std::mem::forget(key); + let owned = unsafe { Key::from_raw(raw.0) }; + + let mut count = 0; + + unsafe { + RegQueryInfoKeyW( + HKEY(owned.as_raw()), + PWSTR::null(), + None, + None, + None, + None, + None, + Some(&mut count), + None, + None, + None, + None, + ) + .ok()?; + }; + + assert_eq!(count, 3); + Ok(()) +} diff --git a/crates/tests/registry/tests/raw.rs b/crates/tests/registry/tests/raw.rs new file mode 100644 index 0000000000..843dbe6368 --- /dev/null +++ b/crates/tests/registry/tests/raw.rs @@ -0,0 +1,35 @@ +use windows_registry::*; +use windows_result::*; +use windows_strings::*; + +#[test] +fn raw() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\raw"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + unsafe { + key.raw_set_bytes(w!("raw"), Type::Other(1234), &[1, 2, 3])?; + + let (ty, len) = key.raw_get_info(w!("raw"))?; + assert_eq!(ty, Type::Other(1234)); + assert_eq!(len, 3); + + let mut bytes = [0; 3]; + key.raw_get_bytes(w!("raw"), &mut bytes)?; + assert_eq!(bytes, [1, 2, 3]); + + let mut larger = [0; 5]; + let (ty, slice) = key.raw_get_bytes(w!("raw"), &mut larger)?; + assert_eq!(ty, Type::Other(1234)); + assert_eq!(slice, [1, 2, 3]); + assert_eq!(larger, [1, 2, 3, 0, 0]); + + let mut bytes = [0; 2]; + let err = key.raw_get_bytes(w!("raw"), &mut bytes).unwrap_err(); + assert_eq!(err.code(), HRESULT(0x800700EAu32 as i32)); // HRESULT_FROM_WIN32(ERROR_INVALID_DATA) + assert_eq!(err.message(), "More data is available."); + } + + Ok(()) +} diff --git a/crates/tests/registry/tests/string.rs b/crates/tests/registry/tests/string.rs new file mode 100644 index 0000000000..41e507b6a8 --- /dev/null +++ b/crates/tests/registry/tests/string.rs @@ -0,0 +1,35 @@ +use windows_registry::*; +use windows_strings::*; + +#[test] +fn string() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\string"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_string("string", "value")?; + assert_eq!(key.get_type("string")?, Type::String); + assert_eq!(key.get_string("string")?, "value"); + assert_eq!(key.get_hstring("string")?, "value"); + assert_eq!(key.get_multi_string("string")?, ["value".to_string()]); + + key.set_hstring("hstring", h!("value"))?; + assert_eq!(key.get_type("hstring")?, Type::String); + assert_eq!(key.get_string("hstring")?, "value"); + assert_eq!(key.get_hstring("hstring")?, "value"); + assert_eq!(key.get_multi_string("hstring")?, ["value".to_string()]); + + key.set_expand_string("expand_string", "value")?; + assert_eq!(key.get_type("expand_string")?, Type::ExpandString); + assert_eq!(key.get_string("expand_string")?, "value"); + assert_eq!(key.get_hstring("expand_string")?, "value"); + assert_eq!(key.get_multi_string("expand_string")?, ["value"]); + + key.set_expand_hstring("expand_hstring", h!("value"))?; + assert_eq!(key.get_type("expand_hstring")?, Type::ExpandString); + assert_eq!(key.get_string("expand_hstring")?, "value"); + assert_eq!(key.get_hstring("expand_hstring")?, "value"); + assert_eq!(key.get_multi_string("expand_hstring")?, ["value"]); + + Ok(()) +} diff --git a/crates/tests/registry/tests/u32.rs b/crates/tests/registry/tests/u32.rs new file mode 100644 index 0000000000..b09186389e --- /dev/null +++ b/crates/tests/registry/tests/u32.rs @@ -0,0 +1,15 @@ +use windows_registry::*; + +#[test] +fn u32() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\u32"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_u32("u32", 123u32)?; + assert_eq!(key.get_type("u32")?, Type::U32); + assert_eq!(key.get_u32("u32")?, 123u32); + assert_eq!(key.get_u64("u32")?, 123u64); + + Ok(()) +} diff --git a/crates/tests/registry/tests/u64.rs b/crates/tests/registry/tests/u64.rs new file mode 100644 index 0000000000..c096b7aa00 --- /dev/null +++ b/crates/tests/registry/tests/u64.rs @@ -0,0 +1,15 @@ +use windows_registry::*; + +#[test] +fn u64() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\u64"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_u64("u64", 123u64)?; + assert_eq!(key.get_type("u64")?, Type::U64); + assert_eq!(key.get_u32("u64")?, 123u32); + assert_eq!(key.get_u64("u64")?, 123u64); + + Ok(()) +} diff --git a/crates/tests/registry/tests/value.rs b/crates/tests/registry/tests/value.rs new file mode 100644 index 0000000000..afb7456238 --- /dev/null +++ b/crates/tests/registry/tests/value.rs @@ -0,0 +1,45 @@ +use windows_registry::*; + +#[test] +fn value() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\value"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_value("u32", &Value::try_from(123u32)?)?; + assert_eq!(key.get_type("u32")?, Type::U32); + assert_eq!(key.get_value("u32")?, Value::try_from(123u32)?); + assert_eq!(key.get_u32("u32")?, 123u32); + assert_eq!(key.get_u64("u32")?, 123u64); + + key.set_value("u64", &Value::try_from(123u64)?)?; + assert_eq!(key.get_type("u64")?, Type::U64); + assert_eq!(key.get_value("u64")?, Value::try_from(123u64)?); + assert_eq!(key.get_u32("u64")?, 123u32); + assert_eq!(key.get_u64("u64")?, 123u64); + + key.set_value("string", &Value::try_from("string")?)?; + assert_eq!(key.get_type("string")?, Type::String); + assert_eq!(key.get_value("string")?, Value::try_from("string")?); + assert_eq!(key.get_string("string")?, "string"); + + let mut value = Value::try_from("expand")?; + value.set_ty(Type::ExpandString); + assert_eq!(value.ty(), Type::ExpandString); + key.set_value("expand", &value)?; + assert_eq!(key.get_type("expand")?, Type::ExpandString); + assert_eq!(key.get_value("expand")?, value); + assert_eq!(key.get_string("expand")?, "expand"); + + key.set_value("bytes", &Value::try_from([1u8, 2u8, 3u8])?)?; + assert_eq!(key.get_type("bytes")?, Type::Bytes); + assert_eq!(key.get_value("bytes")?, Value::try_from([1, 2, 3])?); + + let mut value = Value::try_from([1u8, 2u8, 3u8, 4u8].as_slice())?; + value.set_ty(Type::Other(1234)); + key.set_value("slice", &value)?; + assert_eq!(key.get_type("slice")?, Type::Other(1234)); + assert_eq!(key.get_value("slice")?, value); + + Ok(()) +} diff --git a/crates/tests/registry/tests/values.rs b/crates/tests/registry/tests/values.rs index d78256d3b4..e846138abe 100644 --- a/crates/tests/registry/tests/values.rs +++ b/crates/tests/registry/tests/values.rs @@ -1,5 +1,4 @@ use windows_registry::*; -use windows_result::*; #[test] fn values() -> Result<()> { @@ -9,58 +8,17 @@ fn values() -> Result<()> { let key = CURRENT_USER.create(test_key)?; key.set_u32("u32", 123)?; key.set_u64("u64", 456)?; - key.set_string("string", "hello")?; - key.set_bytes("bytes", &[1u8, 2u8, 2u8])?; - key.set_multi_string("multi", &["hello", "world"])?; - - assert_eq!(key.get_u32("u32")?, 123u32); - assert_eq!(key.get_u64("u64")?, 456u64); - assert_eq!(key.get_string("string")?, "hello".to_string()); - assert_eq!(key.get_bytes("bytes")?, vec![1u8, 2u8, 2u8]); - assert_eq!( - key.get_multi_string("multi")?, - vec!["hello".to_string(), "world".to_string()] - ); - - let err = key.get_u32("string").unwrap_err(); - assert_eq!(err.code(), HRESULT(0x8007000Du32 as i32)); // HRESULT_FROM_WIN32(ERROR_INVALID_DATA) - assert_eq!(err.message(), "The data is invalid."); - - assert_eq!(key.get_value("u32")?, Value::U32(123)); - assert_eq!(key.get_value("u64")?, Value::U64(456)); - assert_eq!(key.get_value("string")?, Value::String("hello".to_string())); - assert_eq!(key.get_value("bytes")?, Value::Bytes(vec![1u8, 2u8, 2u8])); - assert_eq!( - key.get_value("multi")?, - Value::MultiString(vec!["hello".to_string(), "world".to_string()]) - ); + key.set_string("string", "hello world")?; let names: Vec<(String, Value)> = key.values()?.collect(); + assert_eq!(names.len(), 3); assert_eq!( names, [ - ("u32".to_string(), Value::U32(123)), - ("u64".to_string(), Value::U64(456)), - ("string".to_string(), Value::String("hello".to_string())), - ("bytes".to_string(), Value::Bytes(vec![1u8, 2u8, 2u8])), - ( - "multi".to_string(), - Value::MultiString(vec!["hello".to_string(), "world".to_string()]) - ), - ] - ); - - key.remove_value("string")?; - key.remove_value("multi")?; - let names: Vec<_> = key.values()?.collect(); - - assert_eq!( - names, - [ - ("u32".to_string(), Value::U32(123)), - ("u64".to_string(), Value::U64(456)), - ("bytes".to_string(), Value::Bytes(vec![1u8, 2u8, 2u8])), + ("u32".to_string(), Value::try_from(123u32)?), + ("u64".to_string(), Value::try_from(456u64)?), + ("string".to_string(), Value::try_from("hello world")?), ] ); diff --git a/crates/tests/registry/tests/windows_interop.rs b/crates/tests/registry/tests/windows_interop.rs deleted file mode 100644 index 2d0e3628ea..0000000000 --- a/crates/tests/registry/tests/windows_interop.rs +++ /dev/null @@ -1,40 +0,0 @@ -use windows::{core::*, Win32::System::Registry::*}; -use windows_registry::*; - -#[test] -fn windows_interop() -> Result<()> { - let test_key = "software\\windows-rs\\tests\\windows_interop"; - _ = CURRENT_USER.remove_tree(test_key); - - let key = CURRENT_USER.create(test_key)?; - key.set_u32("1", 1)?; - key.set_u32("2", 2)?; - key.set_u32("3", 3)?; - - let raw = HKEY(key.as_raw()); - std::mem::forget(key); - let owned = unsafe { Key::from_raw(raw.0) }; - - let mut count = 0; - - unsafe { - RegQueryInfoKeyW( - HKEY(owned.as_raw()), - PWSTR::null(), - None, - None, - None, - None, - None, - Some(&mut count), - None, - None, - None, - None, - ) - .ok()?; - }; - - assert_eq!(count, 3); - Ok(()) -}