diff --git a/src/lib.rs b/src/lib.rs index afd1996..a35ff93 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ pub use whisper_ctx::WhisperContextParameters; use whisper_ctx::WhisperInnerContext; pub use whisper_ctx_wrapper::WhisperContext; pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; -pub use whisper_params::{FullParams, SamplingStrategy}; +pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData}; #[cfg(feature = "raw-api")] pub use whisper_rs_sys; pub use whisper_state::WhisperState; diff --git a/src/whisper_params.rs b/src/whisper_params.rs index dbf0bb5..f0afb00 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -21,6 +21,16 @@ impl Default for SamplingStrategy { } } +/** Owned copy of one transcription segment, passed by value to the closures registered with `set_segment_callback_safe` and `set_segment_callback_safe_lossy`. */ #[derive(Debug, Clone)] +pub struct SegmentCallbackData { + /** Index of the segment, as passed to the whisper segment getters. */ pub segment: i32, + /** Result of `whisper_full_get_segment_t0_from_state` for this segment (units are not visible here; presumably whisper's 10 ms ticks, TODO confirm). */ pub start_timestamp: i64, + /** Result of `whisper_full_get_segment_t1_from_state` for this segment (same units as `start_timestamp`). */ pub end_timestamp: i64, + /** Segment text converted to an owned `String`. */ pub text: String, +} + +/** Boxed segment-callback closure type used by the safe segment-callback setters. */ type SegmentCallbackFn = Box; + pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, @@ -28,6 +38,7 @@ pub struct FullParams<'a, 'b> { grammar: Option>, progess_callback_safe: Option>, abort_callback_safe: Option bool>>, + segment_calllback_safe: Option, } impl<'a, 'b> FullParams<'a, 'b> { @@ -64,6 +75,7 @@ impl<'a, 'b> FullParams<'a, 'b> { grammar: None, progess_callback_safe: None, abort_callback_safe: None, + segment_calllback_safe: None, } } @@ -392,6 +404,140 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.new_segment_callback_user_data = user_data; } + /// Set the callback for segment updates. + /// + /// Provides a limited segment_callback to ensure safety. + /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state` + /// + /// Defaults to None. 
+ pub fn set_segment_callback_safe(&mut self, closure: O) + where + F: FnMut(SegmentCallbackData) + 'static, + O: Into>, + { + use std::ffi::{c_void, CStr}; + use whisper_rs_sys::{whisper_context, whisper_state}; + + extern "C" fn trampoline( + _: *mut whisper_context, + state: *mut whisper_state, + n_new: i32, + user_data: *mut c_void, + ) where + F: FnMut(SegmentCallbackData) + 'static, + { + unsafe { + let user_data = &mut *(user_data as *mut SegmentCallbackFn); + let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); + let s0 = n_segments - n_new; + //let user_data = user_data as *mut Box; + + for i in s0..n_segments { + let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); + let text = CStr::from_ptr(text); + + let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); + let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); + + match text.to_str() { + Ok(n) => user_data(SegmentCallbackData { + segment: i, + start_timestamp: t0, + end_timestamp: t1, + text: n.to_string(), + }), + Err(_) => {} + } + } + } + } + + match closure.into() { + Some(closure) => { + // Stable address + let closure = Box::new(closure) as SegmentCallbackFn; + // Thin pointer + let closure = Box::new(closure); + // Raw pointer + let closure = Box::into_raw(closure); + + self.fp.new_segment_callback_user_data = closure as *mut c_void; + self.fp.new_segment_callback = Some(trampoline::); + self.segment_calllback_safe = None; + } + None => { + self.segment_calllback_safe = None; + self.fp.new_segment_callback = None; + self.fp.new_segment_callback_user_data = std::ptr::null_mut::(); + } + } + } + + /// Set the callback for segment updates. + /// + /// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters. + /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`. + /// + /// Defaults to None. 
+ pub fn set_segment_callback_safe_lossy(&mut self, closure: O) + where + F: FnMut(SegmentCallbackData) + 'static, + O: Into>, + { + use std::ffi::{c_void, CStr}; + use whisper_rs_sys::{whisper_context, whisper_state}; + + extern "C" fn trampoline( + _: *mut whisper_context, + state: *mut whisper_state, + n_new: i32, + user_data: *mut c_void, + ) where + F: FnMut(SegmentCallbackData) + 'static, + { + unsafe { + let user_data = &mut *(user_data as *mut SegmentCallbackFn); + let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); + let s0 = n_segments - n_new; + //let user_data = user_data as *mut Box; + + for i in s0..n_segments { + let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); + let text = CStr::from_ptr(text); + + let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); + let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); + user_data(SegmentCallbackData { + segment: i, + start_timestamp: t0, + end_timestamp: t1, + text: text.to_string_lossy().to_string(), + }); + } + } + } + + match closure.into() { + Some(closure) => { + // Stable address + let closure = Box::new(closure) as SegmentCallbackFn; + // Thin pointer + let closure = Box::new(closure); + // Raw pointer + let closure = Box::into_raw(closure); + + self.fp.new_segment_callback_user_data = closure as *mut c_void; + self.fp.new_segment_callback = Some(trampoline::); + self.segment_calllback_safe = None; + } + None => { + self.segment_calllback_safe = None; + self.fp.new_segment_callback = None; + self.fp.new_segment_callback_user_data = std::ptr::null_mut::(); + } + } + } + /// Set the callback for progress updates. /// /// Note that is still a C callback.