diff --git a/crates/mun_abi/src/autogen.rs b/crates/mun_abi/src/autogen.rs index 886c87586..090fb6fc0 100644 --- a/crates/mun_abi/src/autogen.rs +++ b/crates/mun_abi/src/autogen.rs @@ -119,16 +119,22 @@ impl DispatchTable { } } - pub unsafe fn set_ptr_unchecked(&mut self, idx: u32, ptr: *const c_void) { - *self.fn_ptrs.offset(idx as isize) = ptr; - } - - pub fn set_ptr(&mut self, idx: u32, ptr: *const c_void) -> bool { + /// Returns a mutable reference to a function pointer, without doing bounds checking. + /// + /// This is generally not recommended, use with caution! Calling this method with an + /// out-of-bounds index is _undefined behavior_ even if the resulting reference is not used. + /// For a safe alternative see [get_ptr_mut](#method.get_ptr_mut). + pub unsafe fn get_ptr_unchecked_mut(&self, idx: u32) -> &mut *const c_void { + &mut *self.fn_ptrs.offset(idx as isize) + } + + /// Returns a mutable reference to a function pointer at the given index, or `None` if out of + /// bounds. + pub fn get_ptr_mut(&self, idx: u32) -> Option<&mut *const c_void> { if idx < self.num_entries { - unsafe { self.set_ptr_unchecked(idx, ptr) }; - true + Some(unsafe { self.get_ptr_unchecked_mut(idx) }) } else { - false + None } } } @@ -456,7 +462,7 @@ mod tests { } #[test] - fn test_dispatch_table_set_ptr_unchecked() { + fn test_dispatch_table_get_ptr_unchecked_mut() { let type_name = CString::new(FAKE_TYPE_NAME).expect("Invalid fake type name."); let type_info = fake_type_info(&type_name); @@ -467,16 +473,15 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); - assert_eq!(unsafe { dispatch_table.get_ptr_unchecked(0) }, fn_ptrs[0]); - - let ptr = 0xffffffff as *const c_void; - unsafe { dispatch_table.set_ptr_unchecked(0, ptr) }; - assert_eq!(unsafe { dispatch_table.get_ptr_unchecked(0) }, ptr); + let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + assert_eq!( + unsafe { dispatch_table.get_ptr_unchecked_mut(0) }, + &mut fn_ptrs[0] + ); } #[test] - fn test_dispatch_table_set_ptr_none() { + fn test_dispatch_table_get_ptr_mut_none() { let type_name = CString::new(FAKE_TYPE_NAME).expect("Invalid fake type name."); let type_info = fake_type_info(&type_name); @@ -487,15 +492,12 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); - assert_eq!(dispatch_table.get_ptr(1), None); - - let ptr = 0xffffffff as *const c_void; - assert_eq!(dispatch_table.set_ptr(1, ptr), false); + let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + assert_eq!(dispatch_table.get_ptr_mut(1), None); } #[test] - fn test_dispatch_table_set_ptr_some() { + fn test_dispatch_table_get_ptr_mut_some() { let type_name = CString::new(FAKE_TYPE_NAME).expect("Invalid fake type name."); let type_info = fake_type_info(&type_name); @@ -506,12 +508,8 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); - assert_eq!(dispatch_table.get_ptr(0), Some(fn_ptrs[0])); - - let ptr = 0xffffffff as *const c_void; - assert_eq!(dispatch_table.set_ptr(0, ptr), true); - assert_eq!(dispatch_table.get_ptr(0), Some(ptr)); + let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + assert_eq!(dispatch_table.get_ptr_mut(0), Some(&mut fn_ptrs[0])); } fn fake_assembly_info(