From 8a40875daaa7021fc64d51309062271856aee310 Mon Sep 17 00:00:00 2001 From: cat_or_not <41955154+catornot@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:50:15 -0400 Subject: [PATCH] probably safe threads impl --- src/bindings/squirrelclasstypes.rs | 12 +++ src/bindings/squirrelfunctions.rs | 12 +++ src/high/squirrel.rs | 144 +++++++++++++++++++++++++++++ 3 files changed, 168 insertions(+) diff --git a/src/bindings/squirrelclasstypes.rs b/src/bindings/squirrelclasstypes.rs index 733bb91..7e9a05d 100644 --- a/src/bindings/squirrelclasstypes.rs +++ b/src/bindings/squirrelclasstypes.rs @@ -403,3 +403,15 @@ pub type RegisterSquirrelFuncType_External = unsafe extern "C" fn( ) -> i64; pub type sq_createscriptinstanceType = unsafe extern "C" fn(ent: *mut ::std::os::raw::c_void) -> *mut SQObject; +pub type sq_suspendthreadType = unsafe extern "C" fn( + thread_sqvm: *mut HSquirrelVM, + *const *mut HSquirrelVM, + usize, + *mut HSquirrelVM, +) -> SQRESULT; +pub type sq_threadwakeupType = unsafe extern "C" fn( + thread_sqvm: *mut HSquirrelVM, + i32, + *const ::std::os::raw::c_void, + *mut HSquirrelVM, +) -> SQRESULT; diff --git a/src/bindings/squirrelfunctions.rs b/src/bindings/squirrelfunctions.rs index 7071674..7d913b5 100644 --- a/src/bindings/squirrelfunctions.rs +++ b/src/bindings/squirrelfunctions.rs @@ -54,6 +54,9 @@ offset_functions! { sq_pushnewstructinstance = sq_pushnewstructinstanceType where offset(0x53e0); sq_sealstructslot = sq_sealstructslotType where offset(0x5510); + + sq_suspendthread = sq_suspendthreadType where offset(0x434f0); + sq_threadwakeup = sq_threadwakeupType where offset(0x8780); } } @@ -110,6 +113,9 @@ offset_functions! { sq_pushnewstructinstance = sq_pushnewstructinstanceType where offset(0x5400); sq_sealstructslot = sq_sealstructslotType where offset(0x5530); + + sq_suspendthread = sq_suspendthreadType where offset(0x43550); + sq_threadwakeup = sq_threadwakeupType where offset(0x87b0); } } @@ -155,6 +161,8 @@ pub struct SquirrelFunctions { pub sq_create_script_instance: sq_createscriptinstanceType, pub sq_pushnewstructinstance: sq_pushnewstructinstanceType, pub sq_sealstructslot: sq_sealstructslotType, + pub sq_suspendthread: sq_suspendthreadType, + pub sq_threadwakeup: sq_threadwakeupType, } impl From<&ClientSQFunctions> for SquirrelFunctions { @@ -199,6 +207,8 @@ impl From<&ClientSQFunctions> for SquirrelFunctions { sq_create_script_instance: val.sq_create_script_instance, sq_pushnewstructinstance: val.sq_pushnewstructinstance, sq_sealstructslot: val.sq_sealstructslot, + sq_suspendthread: val.sq_suspendthread, + sq_threadwakeup: val.sq_threadwakeup, } } } @@ -245,6 +255,8 @@ impl From<&ServerSQFunctions> for SquirrelFunctions { sq_create_script_instance: val.sq_create_script_instance, sq_pushnewstructinstance: val.sq_pushnewstructinstance, sq_sealstructslot: val.sq_sealstructslot, + sq_suspendthread: val.sq_suspendthread, + sq_threadwakeup: val.sq_threadwakeup, } } } diff --git a/src/high/squirrel.rs b/src/high/squirrel.rs index 6df6e3a..58f4ed4 100644 --- a/src/high/squirrel.rs +++ b/src/high/squirrel.rs @@ -411,6 +411,150 @@ impl<'a, T> DerefMut for UserDataRef<'a, T> { } } +/// suspends a thread when returned from a native sqfunction +pub struct SuspendThread { + phantom: PhantomData, +} + +impl SQVMName for SuspendThread { + fn get_sqvm_name() -> String { + T::get_sqvm_name() + } +} + +impl PushToSquirrelVm for SuspendThread { + fn push_to_sqvm(self, sqvm: NonNull, sqfunctions: &SquirrelFunctions) { + unsafe { (sqfunctions.sq_suspendthread)(sqvm.as_ptr(), &sqvm.as_ptr(), 5, sqvm.as_ptr()) }; + } +} + +impl SuspendThread { + const fn new() -> Self { + Self { + phantom: PhantomData, + } + } + + fn is_thread_and_throw_error( + thread_sqvm: NonNull, + sqfunctions: &SquirrelFunctions, + ) -> bool { + use super::squirrel_traits::ReturnToVm; + let mut is_thread = true; + + // idk if this is how to check it + if 2 < unsafe { thread_sqvm.as_ref()._suspended } { + Err::("Cannot suspend thread from within code function calls".to_string()) + .return_to_vm(thread_sqvm, sqfunctions); + is_thread = false + } + + is_thread + } + + /// Spawns a native thread. + /// When completed it resumes the thread + /// + /// # SAFETY + /// this thread cannot live long the parent sqvm + #[cfg(feature = "async_engine")] + pub fn new_with_thread(thread_sqvm: NonNull, mut thread_func: F) -> Self + where + F: FnMut() -> T + Send + 'static, + T: Send + Sync + 'static, + { + use crate::high::engine_sync::{async_execute, AsyncEngineMessage}; + let thread_sqvm = unsafe { UnsafeHandle::new(thread_sqvm) }; + + if !Self::is_thread_and_throw_error(thread_sqvm, SQFUNCTIONS.from_sqvm(thread_sqvm)) { + return Self::new(); + } + + std::thread::spawn(move || { + let result = thread_func(); + + // TODO: check if the sqvm has expired + async_execute(AsyncEngineMessage::run_func(move |_| { + let thread_sqvm = thread_sqvm.take(); + let sq_functions = SQFUNCTIONS.from_sqvm(thread_sqvm); + + result.push_to_sqvm(thread_sqvm, sq_functions); + unsafe { resume_thread(thread_sqvm, sq_functions) }; + })) + }); + + Self::new() + } + + /// calls a function to store the ref to wake up this thread sqvm + /// + /// the stored [`ThreadWakeUp`] cannot outlive the parent sqvm + pub fn new_with_store(thread_sqvm: NonNull, mut store_func: F) -> Self + where + F: FnMut(ThreadWakeUp), + { + if !Self::is_thread_and_throw_error(thread_sqvm, SQFUNCTIONS.from_sqvm(thread_sqvm)) { + return Self::new(); + } + + store_func(ThreadWakeUp { + thread_sqvm, + phantom: PhantomData::, + }); + + Self::new() + } + + /// returns a [`ThreadWakeUp`] for this thread + /// + /// # Failure + /// + /// fails to return [`ThreadWakeUp`] if the sqvm is not a thread + pub fn new_both(thread_sqvm: NonNull) -> (Self, Option>) { + if !Self::is_thread_and_throw_error(thread_sqvm, SQFUNCTIONS.from_sqvm(thread_sqvm)) { + return (Self::new(), None); + } + + ( + Self::new(), + Some(ThreadWakeUp { + thread_sqvm, + phantom: PhantomData::, + }), + ) + } +} + +/// stores the thread sqvm to wake up and the return type in the generic +pub struct ThreadWakeUp { + thread_sqvm: NonNull, + phantom: PhantomData, +} + +impl ThreadWakeUp { + /// resumes the sqvm thread + /// + /// ub if it outlived the parent sqvm + pub fn resume(self, data: T) { + let sq_functions = SQFUNCTIONS.from_sqvm(self.thread_sqvm); + data.push_to_sqvm(self.thread_sqvm, sq_functions); + unsafe { resume_thread(self.thread_sqvm, sq_functions) }; + } +} + +/// # SAFETY +/// has to be valid and cannot live long the parent sqvm +unsafe fn resume_thread(thread_sqvm: NonNull, sqfunctions: &SquirrelFunctions) { + unsafe { + _ = (sqfunctions.sq_threadwakeup)( + thread_sqvm.as_ptr(), + 5, + std::ptr::null(), + thread_sqvm.as_ptr(), + ) + } +} + /// Adds a sqfunction to the registration list /// /// The sqfunction will be registered when its vm is loaded