From 16bb40647368b01bab0fa4b76bf759f6d726d928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 12:24:44 +0100 Subject: [PATCH] improving code comments --- .../include/silero_vad/silero_vad_node.hpp | 63 ++++++++++++--- whisper_ros/include/silero_vad/timestamp.hpp | 21 +++++ .../include/silero_vad/vad_iterator.hpp | 72 ++++++++++++++++- whisper_ros/include/whisper_ros/whisper.hpp | 49 ++++++++++- .../include/whisper_ros/whisper_base_node.hpp | 81 +++++++++++++++++-- .../include/whisper_ros/whisper_node.hpp | 65 ++++++++++++++- .../whisper_ros/whisper_server_node.hpp | 62 +++++++++++++- 7 files changed, 391 insertions(+), 22 deletions(-) diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index ba10e97..cfc8c5c 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef SILERO_VAD_SILERO_VAD_NODE_HPP -#define SILERO_VAD_SILERO_VAD_NODE_HPP +#ifndef SILERO_VAD__SILERO_VAD_NODE_HPP +#define SILERO_VAD__SILERO_VAD_NODE_HPP #include #include @@ -36,50 +36,95 @@ namespace silero_vad { +/// @class SileroVadNode +/// @brief A ROS 2 lifecycle node for performing voice activity detection. class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { public: + /// @brief Constructs a new SileroVadNode object. SileroVadNode(); + /// @brief Callback for configuring the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of configuration. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_configure(const rclcpp_lifecycle::State &); + on_configure(const rclcpp_lifecycle::State &state); + + /// @brief Callback for activating the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of activation. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_activate(const rclcpp_lifecycle::State &); + on_activate(const rclcpp_lifecycle::State &state); + + /// @brief Callback for deactivating the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of deactivation. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_deactivate(const rclcpp_lifecycle::State &); + on_deactivate(const rclcpp_lifecycle::State &state); + + /// @brief Callback for cleaning up the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of cleanup. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_cleanup(const rclcpp_lifecycle::State &); + on_cleanup(const rclcpp_lifecycle::State &state); + + /// @brief Callback for shutting down the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of shutdown. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_shutdown(const rclcpp_lifecycle::State &); + on_shutdown(const rclcpp_lifecycle::State &state); protected: + /// Indicates if VAD is enabled. std::atomic enabled; + /// Indicates if VAD is in listening mode. std::atomic listening; + /// Indicates if audio data should be published. std::atomic publish; + /// Buffer for storing audio data. std::vector data; + /// Pointer to the VAD iterator. std::unique_ptr vad_iterator; private: + /// Buffer for storing previous audio data. std::vector prev_data; + /// Path to the VAD model. std::string model_path_; + /// Sampling rate of the audio data. int sample_rate_; + /// Frame size in milliseconds. int frame_size_ms_; + /// Threshold for VAD decision-making. float threshold_; + /// Minimum silence duration in milliseconds. int min_silence_ms_; + /// Padding duration for detected speech in milliseconds. int speech_pad_ms_; + /// Publisher for VAD output. rclcpp::Publisher::SharedPtr publisher_; + /// Subscription for audio input. rclcpp::Subscription::SharedPtr subscription_; - + /// Service for enabling/disabling VAD. rclcpp::Service::SharedPtr enable_srv_; + /// @brief Callback for handling incoming audio data. + /// @param msg The audio message containing the audio data. void audio_callback(const audio_common_msgs::msg::AudioStamped::SharedPtr msg); + /// @brief Callback for enabling/disabling the VAD. + /// @param request The service request containing the desired enable state. + /// @param response The service response indicating success or failure. void enable_cb(const std::shared_ptr request, std::shared_ptr response); + /// @brief Converts audio data to a float vector normalized to [-1.0, 1.0]. + /// @tparam T The input audio data type. + /// @param input The input audio data. + /// @return A vector of normalized float audio data. template std::vector convert_to_float(const std::vector &input) { static_assert(std::is_integral::value, @@ -123,4 +168,4 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { } // namespace silero_vad -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/silero_vad/timestamp.hpp b/whisper_ros/include/silero_vad/timestamp.hpp index cd41b08..c6e7f31 100644 --- a/whisper_ros/include/silero_vad/timestamp.hpp +++ b/whisper_ros/include/silero_vad/timestamp.hpp @@ -27,17 +27,38 @@ namespace silero_vad { +/// @class Timestamp +/// @brief Represents a time interval with speech probability. class Timestamp { public: + /// The start time of the interval, in milliseconds. int start; + + /// The end time of the interval, in milliseconds. int end; + + /// The probability of speech detected in the interval, ranging from 0 to 1. float speech_prob; + /// @brief Constructs a `Timestamp` object. + /// @param start The start time of the interval (default: -1). + /// @param end The end time of the interval (default: -1). + /// @param speech_prob The speech probability (default: 0). Timestamp(int start = -1, int end = -1, float speech_prob = 0); + /// @brief Assigns the values of another `Timestamp` to this instance. + /// @param other The `Timestamp` to copy from. + /// @return A reference to this `Timestamp`. Timestamp &operator=(const Timestamp &other); + + /// @brief Compares two `Timestamp` objects for equality. + /// @param other The `Timestamp` to compare with. + /// @return `true` if the start and end times are equal; `false` otherwise. bool operator==(const Timestamp &other) const; + /// @brief Converts the `Timestamp` object to a string representation. + /// @return A string representing the `Timestamp` in the format + /// `{start:...,end:...,prob:...}`. std::string to_string() const; }; diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp index a8345d4..6198921 100644 --- a/whisper_ros/include/silero_vad/vad_iterator.hpp +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -33,58 +33,122 @@ namespace silero_vad { +/// @class VadIterator +/// @brief Implements a Voice Activity Detection (VAD) iterator using an ONNX +/// model. +/// +/// This class provides methods to load a pre-trained ONNX VAD model, process +/// audio data, and predict the presence of speech. It manages the model state +/// and handles input/output tensors for inference. class VadIterator { - public: + /// @brief Constructs a VadIterator object. + /// + /// @param model_path Path to the ONNX model file. + /// @param sample_rate The audio sample rate in Hz (default: 16000). + /// @param frame_size_ms Size of the audio frame in milliseconds (default: + /// 32). + /// @param threshold The threshold for speech detection (default: 0.5). + /// @param min_silence_ms Minimum silence duration in milliseconds to mark the + /// end of speech (default: 100). + /// @param speech_pad_ms Additional padding in milliseconds added to speech + /// segments (default: 30). VadIterator(const std::string &model_path, int sample_rate = 16000, int frame_size_ms = 32, float threshold = 0.5f, int min_silence_ms = 100, int speech_pad_ms = 30); + /// @brief Resets the internal state of the model. + /// + /// Clears the state, context, and resets internal flags related to speech + /// detection. void reset_states(); + + /// @brief Processes audio data and predicts speech segments. + /// + /// @param data A vector of audio samples (single-channel, float values). + /// @return A Timestamp object containing start and end times of detected + /// speech, or -1 for inactive values. Timestamp predict(const std::vector &data); private: + /// ONNX Runtime environment. Ort::Env env; + /// ONNX session options. Ort::SessionOptions session_options; + /// ONNX session for running inference. std::shared_ptr session; + /// Memory allocator for ONNX runtime. Ort::AllocatorWithDefaultOptions allocator; + /// Memory info for tensor allocation. Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); - // Model configuration + /// Detection threshold for speech probability. float threshold; + /// Audio sample rate in Hz. int sample_rate; + /// Samples per millisecond. int sr_per_ms; + /// Number of samples in a single frame. int64_t window_size_samples; + /// Padding in samples added to speech segments. int speech_pad_samples; + /// Minimum silence duration in samples to mark the end of speech. unsigned int min_silence_samples; + /// Size of the context buffer. int context_size; - // Model state + /// Indicates whether speech has been detected. bool triggered = false; + /// Temporary end position during silence detection. unsigned int temp_end = 0; + /// Current sample position in the input stream. unsigned int current_sample = 0; + /// End sample of the last speech segment. int prev_end = 0; + /// Start sample of the next speech segment. int next_start = 0; + /// ONNX model input tensors. std::vector ort_inputs; + + /// Names of the input nodes in the ONNX model. std::vector input_node_names = {"input", "state", "sr"}; + /// Input buffer for audio data and context. std::vector input; + + /// Context buffer storing past audio samples. std::vector context; + + /// Internal state of the model. std::vector state; + + /// Sample rate tensor. std::vector sr; + /// Dimensions for the input tensor. int64_t input_node_dims[2] = {}; + + /// Dimensions for the state tensor. const int64_t state_node_dims[3] = {2, 1, 128}; + + /// Dimensions for the sample rate tensor. const int64_t sr_node_dims[1] = {1}; + /// ONNX model output tensors. std::vector ort_outputs; + + /// Names of the output nodes in the ONNX model. std::vector output_node_names = {"output", "stateN"}; + /// @brief Initializes the ONNX model session. + /// + /// @param model_path Path to the ONNX model file. + /// @throws std::runtime_error If the ONNX session initialization fails. void init_onnx_model(const std::string &model_path); }; } // namespace silero_vad -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 7e2ef40..dadae67 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -29,40 +29,87 @@ #include "grammar-parser.h" #include "whisper.h" +/// Represents the result of a transcription operation. struct TranscriptionOutput { + /// The transcribed text. std::string text; + + /// The confidence probability of the transcription. float prob; }; namespace whisper_ros { +/// Class for performing speech-to-text transcription using the Whisper model. class Whisper { public: + /// Constructs a Whisper object with the specified model and parameters. + /// @param model The path to the Whisper model file. + /// @param openvino_encode_device The OpenVINO device used for encoding. + /// @param n_processors Number of processors to use for parallel processing. + /// @param cparams Whisper context parameters. + /// @param wparams Whisper full parameters. Whisper(const std::string &model, const std::string &openvino_encode_device, int n_processors, const struct whisper_context_params &cparams, const struct whisper_full_params &wparams); + + /// Destructor to clean up resources used by the Whisper object. ~Whisper(); + /// Transcribes the given audio data. + /// @param pcmf32 A vector of audio samples in 32-bit float format. + /// @return A TranscriptionOutput structure containing the transcription text + /// and confidence probability. struct TranscriptionOutput transcribe(const std::vector &pcmf32); + + /// Trims leading and trailing whitespace from the input string. + /// @param s The string to trim. + /// @return The trimmed string. std::string trim(const std::string &s); + + /// Converts a timestamp to a string format. + /// @param t The timestamp in 10 ms units. + /// @param comma If true, use a comma as the decimal separator; otherwise, use + /// a period. + /// @return The formatted timestamp as a string. std::string timestamp_to_str(int64_t t, bool comma = false); + /// Sets a grammar for transcription with a starting rule and penalty. + /// @param grammar The grammar rules as a string. + /// @param start_rule The starting rule for the grammar. + /// @param grammar_penalty A penalty factor for grammar violations. + /// @return True if the grammar is set successfully; false otherwise. bool set_grammar(const std::string grammar, const std::string start_rule, float grammar_penalty); + + /// Resets the grammar to its default state. void reset_grammar(); + + /// Sets an initial prompt for transcription. + /// @param prompt The initial prompt text. void set_init_prompt(const std::string prompt); + + /// Resets the initial prompt to its default state. void reset_init_prompt(); protected: + /// Number of processors used for parallel processing. int n_processors; + + /// Parameters used for full transcription tasks. struct whisper_full_params wparams; + /// The Whisper context. struct whisper_context *ctx; + + /// Parsed grammar state. grammar_parser::parse_state grammar_parsed; + + /// Grammar rules derived from the parsed grammar. std::vector grammar_rules; }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper_base_node.hpp b/whisper_ros/include/whisper_ros/whisper_base_node.hpp index aab1c12..03399b5 100644 --- a/whisper_ros/include/whisper_ros/whisper_base_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_base_node.hpp @@ -35,39 +35,108 @@ namespace whisper_ros { using CallbackReturn = rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn; +/** + * @class WhisperBaseNode + * @brief A ROS 2 Lifecycle Node for interfacing with the Whisper speech-to-text + * system. + */ class WhisperBaseNode : public rclcpp_lifecycle::LifecycleNode { public: + /** + * @brief Constructs a new WhisperBaseNode instance. + */ WhisperBaseNode(); + /** + * @brief Configures the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_configure(const rclcpp_lifecycle::State &); + on_configure(const rclcpp_lifecycle::State &state); + + /** + * @brief Activates the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_activate(const rclcpp_lifecycle::State &); + on_activate(const rclcpp_lifecycle::State &state); + + /** + * @brief Deactivates the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_deactivate(const rclcpp_lifecycle::State &); + on_deactivate(const rclcpp_lifecycle::State &state); + + /** + * @brief Cleans up resources during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_cleanup(const rclcpp_lifecycle::State &); + on_cleanup(const rclcpp_lifecycle::State &state); + + /** + * @brief Shuts down the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_shutdown(const rclcpp_lifecycle::State &); + on_shutdown(const rclcpp_lifecycle::State &state); + /** + * @brief Activates ROS 2 interfaces specific to this node. To be implemented + * in derived classes. + */ virtual void activate_ros_interfaces(){}; + + /** + * @brief Deactivates ROS 2 interfaces specific to this node. To be + * implemented in derived classes. + */ virtual void deactivate_ros_interfaces(){}; protected: + /// @brief The language for transcription (e.g., "en"). std::string language; + + /// @brief Shared pointer to the Whisper speech-to-text processor. std::shared_ptr whisper; + /** + * @brief Transcribes a given audio input to text. + * + * @param audio A vector of floating-point audio samples. + * @return A Transcription message containing the text and metadata. + */ whisper_msgs::msg::Transcription transcribe(const std::vector &audio); private: + /// @brief The Whisper model to use. std::string model; + + /// @brief Device identifier for OpenVINO encoding (e.g., "CPU"). std::string openvino_encode_device; + + /// @brief Number of processors to use. int n_processors; + + /// @brief Whisper context parameters. struct whisper_context_params cparams = whisper_context_default_params(); + + /// @brief Whisper full transcription parameters. struct whisper_full_params wparams; }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif // WHISPER_ROS__WHISPER_BASE_NODE_HPP diff --git a/whisper_ros/include/whisper_ros/whisper_node.hpp b/whisper_ros/include/whisper_ros/whisper_node.hpp index d6ea89b..73f53b8 100644 --- a/whisper_ros/include/whisper_ros/whisper_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_node.hpp @@ -35,36 +35,99 @@ namespace whisper_ros { +/** + * @class WhisperNode + * @brief A ROS 2 node that extends WhisperBaseNode to handle + * transcription-related functionality. + */ class WhisperNode : public WhisperBaseNode { public: + /** + * @brief Constructs a WhisperNode instance. + */ WhisperNode(); + /** + * @brief Activates ROS 2 interfaces such as publishers, subscriptions, and + * services. + */ void activate_ros_interfaces(); + + /** + * @brief Deactivates ROS 2 interfaces by resetting and clearing publishers, + * subscriptions, and services. + */ void deactivate_ros_interfaces(); private: + /// Publisher for transcription messages. rclcpp::Publisher::SharedPtr publisher_; + + /// Subscription for voice activity detection (VAD) data. rclcpp::Subscription::SharedPtr subscription_; + /// Service for setting grammar configuration. rclcpp::Service::SharedPtr set_grammar_service_; + + /// Service for resetting grammar configuration. rclcpp::Service::SharedPtr reset_grammar_service_; + + /// Service for setting the initial transcription prompt. rclcpp::Service::SharedPtr set_init_prompt_service_; + + /// Service for resetting the initial transcription prompt. rclcpp::Service::SharedPtr reset_init_prompt_service_; + /** + * @brief Callback for processing voice activity detection (VAD) messages. + * + * @param msg A shared pointer to the received Float32MultiArray message. + */ void vad_callback(const std_msgs::msg::Float32MultiArray::SharedPtr msg); + + /** + * @brief Callback for the SetGrammar service. + * + * @param request Shared pointer to the request containing grammar + * configuration. + * @param response Shared pointer to the response indicating success or + * failure. + */ void set_grammar_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the ResetGrammar service. + * + * @param request Shared pointer to the Empty request (unused). + * @param response Shared pointer to the Empty response (unused). + */ void reset_grammar_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the SetInitPrompt service. + * + * @param request Shared pointer to the request containing the initial prompt + * text. + * @param response Shared pointer to the response indicating success. + */ void set_init_prompt_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the ResetInitPrompt service. + * + * @param request Shared pointer to the Empty request (unused). + * @param response Shared pointer to the Empty response (unused). + */ void reset_init_prompt_service_callback( const std::shared_ptr request, std::shared_ptr response); @@ -72,4 +135,4 @@ class WhisperNode : public WhisperBaseNode { } // namespace whisper_ros -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper_server_node.hpp b/whisper_ros/include/whisper_ros/whisper_server_node.hpp index 555d62d..6adbae7 100644 --- a/whisper_ros/include/whisper_ros/whisper_server_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_server_node.hpp @@ -37,41 +37,101 @@ namespace whisper_ros { +/** + * @class WhisperServerNode + * @brief A node providing speech-to-text (STT) functionality using Whisper and + * Silero VAD. + */ class WhisperServerNode : public WhisperBaseNode { using STT = whisper_msgs::action::STT; using GoalHandleSTT = rclcpp_action::ServerGoalHandle; public: + /** + * @brief Constructor for the WhisperServerNode class. + */ WhisperServerNode(); + /** + * @brief Activates ROS interfaces including subscriptions, action servers, + * and clients. + */ void activate_ros_interfaces(); + + /** + * @brief Deactivates ROS interfaces by releasing subscriptions, action + * servers, and clients. + */ void deactivate_ros_interfaces(); protected: + /** + * @brief Enables or disables Silero VAD. + * @param enable If true, enables Silero VAD; otherwise, disables it. + */ void enable_silero(bool enable); private: + /// The transcription message containing the converted text. whisper_msgs::msg::Transcription transcription_msg; + + /// Mutex for synchronizing access to the transcription message. std::mutex transcription_mutex; + + /// Condition variable for waiting on transcription updates. std::condition_variable transcription_cond; + /// Handle for the active STT goal. std::shared_ptr goal_handle_; + + /// Subscription to receive VAD data. rclcpp::Subscription::SharedPtr subscription_; + + /// Client to enable or disable Silero VAD. rclcpp::Client::SharedPtr enable_silero_client_; + + /// Action server for handling STT requests. rclcpp_action::Server::SharedPtr action_server_; + /** + * @brief Callback for processing VAD data. + * @param msg The incoming VAD data message. + */ void vad_callback(const std_msgs::msg::Float32MultiArray::SharedPtr msg); + + /** + * @brief Executes the STT goal. + * @param goal_handle The goal handle representing the active STT request. + */ void execute(const std::shared_ptr goal_handle); + + /** + * @brief Handles an incoming goal request. + * @param uuid The unique identifier of the goal. + * @param goal The goal details. + * @return Response indicating whether the goal is accepted or rejected. + */ rclcpp_action::GoalResponse handle_goal(const rclcpp_action::GoalUUID &uuid, std::shared_ptr goal); + + /** + * @brief Handles a cancellation request for an active goal. + * @param goal_handle The handle of the goal to be canceled. + * @return Response indicating whether the cancellation is accepted. + */ rclcpp_action::CancelResponse handle_cancel(const std::shared_ptr goal_handle); + + /** + * @brief Handles acceptance of a goal and starts execution. + * @param goal_handle The handle of the accepted goal. + */ void handle_accepted(const std::shared_ptr goal_handle); }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif