From b3b3454055d104bf0ac18b2e01c4c0d83400ae10 Mon Sep 17 00:00:00 2001 From: 3 Date: Fri, 27 Oct 2023 21:38:11 -0700 Subject: [PATCH 1/7] add segment callback safe --- src/whisper_params.rs | 51 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 73e57be..f4ee545 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -25,6 +25,7 @@ pub struct FullParams<'a, 'b> { phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, progess_callback_safe: Option>, + segment_calllback_safe: Option> } impl<'a, 'b> FullParams<'a, 'b> { @@ -59,6 +60,7 @@ impl<'a, 'b> FullParams<'a, 'b> { phantom_lang: PhantomData, phantom_tokens: PhantomData, progess_callback_safe: None, + segment_calllback_safe: None, } } @@ -373,6 +375,54 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.new_segment_callback_user_data = user_data; } + pub fn set_segment_callback_safe(&mut self, closure: O) + where + F: FnMut(String) + 'static, + O: Into>, + { + use whisper_rs_sys::{whisper_context, whisper_state}; + use std::ffi::{c_void, CStr}; + + unsafe extern "C" fn trampoline( + _: *mut whisper_context, + state: *mut whisper_state, + n_new: i32, + user_data: *mut c_void, + ) + where + F: FnMut(String) + 'static + { + let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); + let s0 = n_segments - n_new; + let user_data = &mut *(user_data as *mut F); + + for i in s0..n_segments { + let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); + let text = CStr::from_ptr(text); + + match text.to_str() { + Ok(n) => { + user_data(n.to_string()); + }, + Err(_) => {} + } + } + } + + match closure.into() { + Some(mut closure) => { + self.fp.new_segment_callback_user_data = &mut closure as *mut F as *mut c_void; + self.fp.new_segment_callback = Some(trampoline::); + self.segment_calllback_safe = Some(Box::new(closure)); + }, + None => { + self.segment_calllback_safe = None; + self.fp.new_segment_callback = None; + self.fp.new_segment_callback_user_data = std::ptr::null_mut::(); + } + } + } + /// Set the callback for progress updates. /// /// Note that is still a C callback. @@ -433,6 +483,7 @@ impl<'a, 'b> FullParams<'a, 'b> { } } + /// Set the user data to be passed to the progress callback. /// /// # Safety From 890c31749058ff5964cfe70dc4bc84cba6d2c1c3 Mon Sep 17 00:00:00 2001 From: Astrid Date: Fri, 27 Oct 2023 21:54:41 -0700 Subject: [PATCH 2/7] add SegmentCallbackData --- src/whisper_params.rs | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index f4ee545..99e4f23 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -20,12 +20,20 @@ impl Default for SamplingStrategy { } } +#[derive(Debug, Clone)] +pub struct SegmentCallbackData { + segment: i32, + start_timestamp: i64, + end_timestamp: i64, + text: String, +} + pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, progess_callback_safe: Option>, - segment_calllback_safe: Option> + segment_calllback_safe: Option> } impl<'a, 'b> FullParams<'a, 'b> { @@ -375,9 +383,12 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.new_segment_callback_user_data = user_data; } + /// Set the callback for segment updates. + /// + /// Provides a limited segment_callback to ensure safety pub fn set_segment_callback_safe(&mut self, closure: O) where - F: FnMut(String) + 'static, + F: FnMut(SegmentCallbackData) + 'static, O: Into>, { use whisper_rs_sys::{whisper_context, whisper_state}; @@ -390,19 +401,30 @@ impl<'a, 'b> FullParams<'a, 'b> { user_data: *mut c_void, ) where - F: FnMut(String) + 'static + F: FnMut(SegmentCallbackData) + 'static { let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); let s0 = n_segments - n_new; let user_data = &mut *(user_data as *mut F); + let mut t0: i64; + let mut t1: i64; + for i in s0..n_segments { let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); let text = CStr::from_ptr(text); + t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); + t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); + match text.to_str() { Ok(n) => { - user_data(n.to_string()); + user_data(SegmentCallbackData { + segment: i, + start_timestamp: t0, + end_timestamp: t1, + text: n.to_string() + }); }, Err(_) => {} } From 35fd27f9dc2f0da092b8f060a49d5bc226d60c4d Mon Sep 17 00:00:00 2001 From: Astrid Date: Fri, 27 Oct 2023 21:58:22 -0700 Subject: [PATCH 3/7] Expose SegmentCallbackData --- src/lib.rs | 2 +- src/whisper_params.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8c3ac3c..bd8319d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ pub use error::WhisperError; pub use standalone::*; pub use utilities::*; pub use whisper_ctx::WhisperContext; -pub use whisper_params::{FullParams, SamplingStrategy}; +pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData}; pub use whisper_state::WhisperState; pub type WhisperSysContext = whisper_rs_sys::whisper_context; diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 99e4f23..6125e55 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -22,10 +22,10 @@ impl Default for SamplingStrategy { #[derive(Debug, Clone)] pub struct SegmentCallbackData { - segment: i32, - start_timestamp: i64, - end_timestamp: i64, - text: String, + pub segment: i32, + pub start_timestamp: i64, + pub end_timestamp: i64, + pub text: String, } pub struct FullParams<'a, 'b> { From b7387a8473caa6dd1fbc3a5b6f433617a46ce873 Mon Sep 17 00:00:00 2001 From: Astrid Date: Fri, 27 Oct 2023 22:42:37 -0700 Subject: [PATCH 4/7] document better --- src/whisper_params.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 6125e55..59f5b30 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -407,15 +407,12 @@ impl<'a, 'b> FullParams<'a, 'b> { let s0 = n_segments - n_new; let user_data = &mut *(user_data as *mut F); - let mut t0: i64; - let mut t1: i64; - for i in s0..n_segments { let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); let text = CStr::from_ptr(text); - t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); - t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); + let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); + let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); match text.to_str() { Ok(n) => { From 2e8a1847c70aa4c29e67922bb7454634e6b74421 Mon Sep 17 00:00:00 2001 From: Astrid Date: Fri, 27 Oct 2023 22:47:46 -0700 Subject: [PATCH 5/7] run cargo fmt --- src/whisper_params.rs | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 59f5b30..a34dca7 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -21,7 +21,7 @@ impl Default for SamplingStrategy { } #[derive(Debug, Clone)] -pub struct SegmentCallbackData { +pub struct SegmentCallbackData { pub segment: i32, pub start_timestamp: i64, pub end_timestamp: i64, @@ -33,7 +33,7 @@ pub struct FullParams<'a, 'b> { phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, progess_callback_safe: Option>, - segment_calllback_safe: Option> + segment_calllback_safe: Option>, } impl<'a, 'b> FullParams<'a, 'b> { @@ -384,24 +384,23 @@ impl<'a, 'b> FullParams<'a, 'b> { } /// Set the callback for segment updates. - /// + /// /// Provides a limited segment_callback to ensure safety pub fn set_segment_callback_safe(&mut self, closure: O) where F: FnMut(SegmentCallbackData) + 'static, O: Into>, { - use whisper_rs_sys::{whisper_context, whisper_state}; use std::ffi::{c_void, CStr}; + use whisper_rs_sys::{whisper_context, whisper_state}; unsafe extern "C" fn trampoline( _: *mut whisper_context, state: *mut whisper_state, n_new: i32, user_data: *mut c_void, - ) - where - F: FnMut(SegmentCallbackData) + 'static + ) where + F: FnMut(SegmentCallbackData) + 'static, { let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); let s0 = n_segments - n_new; @@ -416,13 +415,13 @@ impl<'a, 'b> FullParams<'a, 'b> { match text.to_str() { Ok(n) => { - user_data(SegmentCallbackData { - segment: i, - start_timestamp: t0, - end_timestamp: t1, - text: n.to_string() + user_data(SegmentCallbackData { + segment: i, + start_timestamp: t0, + end_timestamp: t1, + text: n.to_string(), }); - }, + } Err(_) => {} } } @@ -430,10 +429,10 @@ impl<'a, 'b> FullParams<'a, 'b> { match closure.into() { Some(mut closure) => { - self.fp.new_segment_callback_user_data = &mut closure as *mut F as *mut c_void; + self.fp.new_segment_callback_user_data = &mut closure as *mut F as *mut c_void; self.fp.new_segment_callback = Some(trampoline::); self.segment_calllback_safe = Some(Box::new(closure)); - }, + } None => { self.segment_calllback_safe = None; self.fp.new_segment_callback = None; @@ -502,7 +501,6 @@ impl<'a, 'b> FullParams<'a, 'b> { } } - /// Set the user data to be passed to the progress callback. /// /// # Safety From 9a1b043ec5eef249272c06339663065ce54216d9 Mon Sep 17 00:00:00 2001 From: Astrid Date: Fri, 27 Oct 2023 22:54:09 -0700 Subject: [PATCH 6/7] better document set_segment_callback_safe --- src/whisper_params.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index a34dca7..20dc2df 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -385,7 +385,10 @@ impl<'a, 'b> FullParams<'a, 'b> { /// Set the callback for segment updates. /// - /// Provides a limited segment_callback to ensure safety + /// Provides a limited segment_callback to ensure safety. + /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state` + /// + /// Defaults to None. pub fn set_segment_callback_safe(&mut self, closure: O) where F: FnMut(SegmentCallbackData) + 'static, From 969422c4a29359afd667c76f9fdbd8af37f38155 Mon Sep 17 00:00:00 2001 From: astrid Date: Sun, 29 Oct 2023 14:39:27 -0700 Subject: [PATCH 7/7] oops! No longer segfaults when using 'move' --- src/whisper_params.rs | 64 +++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 20dc2df..d258d83 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -28,12 +28,14 @@ pub struct SegmentCallbackData { pub text: String, } +type SegmentCallbackFn = Box; + pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, progess_callback_safe: Option>, - segment_calllback_safe: Option>, + segment_calllback_safe: Option, } impl<'a, 'b> FullParams<'a, 'b> { @@ -397,7 +399,7 @@ impl<'a, 'b> FullParams<'a, 'b> { use std::ffi::{c_void, CStr}; use whisper_rs_sys::{whisper_context, whisper_state}; - unsafe extern "C" fn trampoline( + extern "C" fn trampoline( _: *mut whisper_context, state: *mut whisper_state, n_new: i32, @@ -405,36 +407,46 @@ impl<'a, 'b> FullParams<'a, 'b> { ) where F: FnMut(SegmentCallbackData) + 'static, { - let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); - let s0 = n_segments - n_new; - let user_data = &mut *(user_data as *mut F); - - for i in s0..n_segments { - let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); - let text = CStr::from_ptr(text); - - let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); - let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); - - match text.to_str() { - Ok(n) => { - user_data(SegmentCallbackData { - segment: i, - start_timestamp: t0, - end_timestamp: t1, - text: n.to_string(), - }); + unsafe { + let user_data = &mut *(user_data as *mut SegmentCallbackFn); + let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); + let s0 = n_segments - n_new; + //let user_data = user_data as *mut Box; + + for i in s0..n_segments { + let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); + let text = CStr::from_ptr(text); + + let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); + let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); + + match text.to_str() { + Ok(n) => { + user_data(SegmentCallbackData { + segment: i, + start_timestamp: t0, + end_timestamp: t1, + text: n.to_string(), + }) + } + Err(_) => {} } - Err(_) => {} } } } match closure.into() { - Some(mut closure) => { - self.fp.new_segment_callback_user_data = &mut closure as *mut F as *mut c_void; - self.fp.new_segment_callback = Some(trampoline::); - self.segment_calllback_safe = Some(Box::new(closure)); + Some(closure) => { + // Stable address + let closure = Box::new(closure) as SegmentCallbackFn; + // Thin pointer + let closure = Box::new(closure); + // Raw pointer + let closure = Box::into_raw(closure); + + self.fp.new_segment_callback_user_data = closure as *mut c_void; + self.fp.new_segment_callback = Some(trampoline::); + self.segment_calllback_safe = None; } None => { self.segment_calllback_safe = None;