Skip to content

Commit

Permalink
improving code comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Dec 27, 2024
1 parent a71f682 commit 16bb406
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 22 deletions.
63 changes: 54 additions & 9 deletions whisper_ros/include/silero_vad/silero_vad_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef SILERO_VAD_SILERO_VAD_NODE_HPP
#define SILERO_VAD_SILERO_VAD_NODE_HPP
#ifndef SILERO_VAD__SILERO_VAD_NODE_HPP
#define SILERO_VAD__SILERO_VAD_NODE_HPP

#include <atomic>
#include <memory>
Expand All @@ -36,50 +36,95 @@

namespace silero_vad {

/// @class SileroVadNode
/// @brief A ROS 2 lifecycle node for performing voice activity detection.
class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {

public:
/// @brief Constructs a new SileroVadNode object.
SileroVadNode();

/// @brief Callback for configuring the lifecycle node.
/// @param state The current state of the node.
/// @return Success or failure of configuration.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_configure(const rclcpp_lifecycle::State &);
on_configure(const rclcpp_lifecycle::State &state);

/// @brief Callback for activating the lifecycle node.
/// @param state The current state of the node.
/// @return Success or failure of activation.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_activate(const rclcpp_lifecycle::State &);
on_activate(const rclcpp_lifecycle::State &state);

/// @brief Callback for deactivating the lifecycle node.
/// @param state The current state of the node.
/// @return Success or failure of deactivation.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_deactivate(const rclcpp_lifecycle::State &);
on_deactivate(const rclcpp_lifecycle::State &state);

/// @brief Callback for cleaning up the lifecycle node.
/// @param state The current state of the node.
/// @return Success or failure of cleanup.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_cleanup(const rclcpp_lifecycle::State &);
on_cleanup(const rclcpp_lifecycle::State &state);

/// @brief Callback for shutting down the lifecycle node.
/// @param state The current state of the node.
/// @return Success or failure of shutdown.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_shutdown(const rclcpp_lifecycle::State &);
on_shutdown(const rclcpp_lifecycle::State &state);

protected:
/// Indicates if VAD is enabled.
std::atomic<bool> enabled;
/// Indicates if VAD is in listening mode.
std::atomic<bool> listening;
/// Indicates if audio data should be published.
std::atomic<bool> publish;
/// Buffer for storing audio data.
std::vector<float> data;
/// Pointer to the VAD iterator.
std::unique_ptr<VadIterator> vad_iterator;

private:
/// Buffer for storing previous audio data.
std::vector<float> prev_data;
/// Path to the VAD model.
std::string model_path_;
/// Sampling rate of the audio data.
int sample_rate_;
/// Frame size in milliseconds.
int frame_size_ms_;
/// Threshold for VAD decision-making.
float threshold_;
/// Minimum silence duration in milliseconds.
int min_silence_ms_;
/// Padding duration for detected speech in milliseconds.
int speech_pad_ms_;

/// Publisher for VAD output.
rclcpp::Publisher<std_msgs::msg::Float32MultiArray>::SharedPtr publisher_;
/// Subscription for audio input.
rclcpp::Subscription<audio_common_msgs::msg::AudioStamped>::SharedPtr
subscription_;

/// Service for enabling/disabling VAD.
rclcpp::Service<std_srvs::srv::SetBool>::SharedPtr enable_srv_;

/// @brief Callback for handling incoming audio data.
/// @param msg The audio message containing the audio data.
void
audio_callback(const audio_common_msgs::msg::AudioStamped::SharedPtr msg);

/// @brief Callback for enabling/disabling the VAD.
/// @param request The service request containing the desired enable state.
/// @param response The service response indicating success or failure.
void enable_cb(const std::shared_ptr<std_srvs::srv::SetBool::Request> request,
std::shared_ptr<std_srvs::srv::SetBool::Response> response);

/// @brief Converts audio data to a float vector normalized to [-1.0, 1.0].
/// @tparam T The input audio data type.
/// @param input The input audio data.
/// @return A vector of normalized float audio data.
template <typename T>
std::vector<float> convert_to_float(const std::vector<T> &input) {
static_assert(std::is_integral<T>::value,
Expand Down Expand Up @@ -123,4 +168,4 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {

} // namespace silero_vad

#endif
#endif
21 changes: 21 additions & 0 deletions whisper_ros/include/silero_vad/timestamp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,38 @@

namespace silero_vad {

/// @class Timestamp
/// @brief Represents a time interval with speech probability.
///
/// All data members are public. The member functions are only declared here
/// (no bodies), so their exact behavior must be checked in the implementation
/// file.
///
/// NOTE(review): the units and value ranges stated on the members below are
/// not enforced by anything visible in this declaration (plain `int` /
/// `float`); confirm them against the code that fills in a `Timestamp`.
class Timestamp {
public:
/// The start time of the interval, in milliseconds.
int start;

/// The end time of the interval, in milliseconds.
int end;

/// The probability of speech detected in the interval, ranging from 0 to 1.
float speech_prob;

/// @brief Constructs a `Timestamp` object.
///
/// Every parameter has a default, so this also serves as the default
/// constructor. It is not marked `explicit`, which means an `int` converts
/// implicitly to a `Timestamp` (as the `start` argument).
///
/// NOTE(review): the -1 defaults presumably mean "not set"; confirm against
/// the code that consumes `start` / `end`.
///
/// @param start The start time of the interval (default: -1).
/// @param end The end time of the interval (default: -1).
/// @param speech_prob The speech probability (default: 0).
Timestamp(int start = -1, int end = -1, float speech_prob = 0);

/// @brief Assigns the values of another `Timestamp` to this instance.
///
/// NOTE(review): copy assignment is user-declared while no copy constructor
/// is declared in this class, so the copy constructor is the
/// compiler-generated member-wise one.
///
/// @param other The `Timestamp` to copy from.
/// @return A reference to this `Timestamp`.
Timestamp &operator=(const Timestamp &other);

/// @brief Compares two `Timestamp` objects for equality.
///
/// NOTE(review): as described, `speech_prob` does not take part in the
/// comparison; confirm against the definition.
///
/// @param other The `Timestamp` to compare with.
/// @return `true` if the start and end times are equal; `false` otherwise.
bool operator==(const Timestamp &other) const;

/// @brief Converts the `Timestamp` object to a string representation.
/// @return A string representing the `Timestamp` in the format
/// `{start:...,end:...,prob:...}`.
std::string to_string() const;
};

Expand Down
72 changes: 68 additions & 4 deletions whisper_ros/include/silero_vad/vad_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,58 +33,122 @@

namespace silero_vad {

/// @class VadIterator
/// @brief Implements a Voice Activity Detection (VAD) iterator using an ONNX
/// model.
///
/// This class provides methods to load a pre-trained ONNX VAD model, process
/// audio data, and predict the presence of speech. It manages the model state
/// and handles input/output tensors for inference.
///
/// @note Non-static data members are initialized in the order in which they
/// are declared below, independent of the order written in a constructor's
/// initializer list. Keep that in mind before reordering the members.
/// @note The `const` array members (`state_node_dims`, `sr_node_dims`) cause
/// the implicitly-declared copy-assignment operator to be defined as deleted.
class VadIterator {

public:
/// @brief Constructs a VadIterator object.
///
/// The constructor is not marked `explicit`; because every parameter after
/// `model_path` has a default, a `std::string` converts implicitly to a
/// `VadIterator`.
///
/// @param model_path Path to the ONNX model file.
/// @param sample_rate The audio sample rate in Hz (default: 16000).
/// @param frame_size_ms Size of the audio frame in milliseconds (default:
/// 32).
/// @param threshold The threshold for speech detection (default: 0.5).
/// @param min_silence_ms Minimum silence duration in milliseconds to mark the
/// end of speech (default: 100).
/// @param speech_pad_ms Additional padding in milliseconds added to speech
/// segments (default: 30).
VadIterator(const std::string &model_path, int sample_rate = 16000,
int frame_size_ms = 32, float threshold = 0.5f,
int min_silence_ms = 100, int speech_pad_ms = 30);

/// @brief Resets the internal state of the model.
///
/// Clears the state, context, and resets internal flags related to speech
/// detection.
void reset_states();

/// @brief Processes audio data and predicts speech segments.
///
/// @param data A vector of audio samples (single-channel, float values).
/// @return A Timestamp with the start and end of the detected speech; a field
/// that is not active is reported as -1.
/// NOTE(review): confirm the unit of the returned start/end values against
/// the implementation (the position counters below are kept in samples).
Timestamp predict(const std::vector<float> &data);

private:
/// ONNX Runtime environment.
Ort::Env env;
/// ONNX session options.
Ort::SessionOptions session_options;
/// ONNX session for running inference.
/// Held through a `std::shared_ptr` and therefore null until assigned;
/// presumably created by `init_onnx_model` -- confirm in the implementation.
std::shared_ptr<Ort::Session> session;
/// Memory allocator for ONNX runtime.
Ort::AllocatorWithDefaultOptions allocator;
/// Memory info for tensor allocation.
/// Created for CPU memory (`OrtArenaAllocator`, `OrtMemTypeCPU`).
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

// Model configuration
// These members have no default member initializers; they must be set by the
// constructor before use.
// NOTE(review): `min_silence_samples` is unsigned while `speech_pad_samples`
// and `context_size` are signed -- watch for implicit conversions wherever
// they are combined.
/// Detection threshold for speech probability.
float threshold;
/// Audio sample rate in Hz.
int sample_rate;
/// Samples per millisecond.
int sr_per_ms;
/// Number of samples in a single frame.
int64_t window_size_samples;
/// Padding in samples added to speech segments.
int speech_pad_samples;
/// Minimum silence duration in samples to mark the end of speech.
unsigned int min_silence_samples;
/// Size of the context buffer.
int context_size;

// Model state
// NOTE(review): `temp_end` and `current_sample` are unsigned, `prev_end` and
// `next_start` are signed; mixed-sign arithmetic between them converts the
// signed operand to unsigned.
/// Indicates whether speech has been detected.
bool triggered = false;
/// Temporary end position during silence detection.
unsigned int temp_end = 0;
/// Current sample position in the input stream.
unsigned int current_sample = 0;
/// End sample of the last speech segment.
int prev_end = 0;
/// Start sample of the next speech segment.
int next_start = 0;

/// ONNX model input tensors.
std::vector<Ort::Value> ort_inputs;

/// Names of the input nodes in the ONNX model.
/// The pointers refer to string literals, so they stay valid for the whole
/// program run.
std::vector<const char *> input_node_names = {"input", "state", "sr"};

/// Input buffer for audio data and context.
std::vector<float> input;

/// Context buffer storing past audio samples.
std::vector<float> context;

/// Internal state of the model.
/// NOTE(review): presumably sized to match `state_node_dims` (2 * 1 * 128
/// floats) -- confirm in the constructor.
std::vector<float> state;

/// Sample rate tensor.
std::vector<int64_t> sr;

/// Dimensions for the input tensor.
/// Zero-initialized here; the real values must be written before the input
/// tensor is created.
int64_t input_node_dims[2] = {};

/// Dimensions for the state tensor.
/// Fixed shape 2 x 1 x 128.
const int64_t state_node_dims[3] = {2, 1, 128};

/// Dimensions for the sample rate tensor.
/// A single element.
const int64_t sr_node_dims[1] = {1};

/// ONNX model output tensors.
std::vector<Ort::Value> ort_outputs;

/// Names of the output nodes in the ONNX model.
/// The pointers refer to string literals, so they stay valid for the whole
/// program run.
std::vector<const char *> output_node_names = {"output", "stateN"};

/// @brief Initializes the ONNX model session.
///
/// @param model_path Path to the ONNX model file.
/// @throws std::runtime_error If the ONNX session initialization fails.
void init_onnx_model(const std::string &model_path);
};

} // namespace silero_vad

#endif
#endif
49 changes: 48 additions & 1 deletion whisper_ros/include/whisper_ros/whisper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,87 @@
#include "grammar-parser.h"
#include "whisper.h"

/// Represents the result of a transcription operation.
///
/// `prob` carries a default member initializer so that a default-initialized
/// object (`TranscriptionOutput out;`) is fully defined: empty text and a
/// probability of zero. Without it, `prob` would hold an indeterminate value
/// until explicitly assigned. The struct remains an aggregate (C++14 and
/// later), so brace initialization such as `{"text", 0.9f}` keeps working.
struct TranscriptionOutput {
  /// The transcribed text.
  std::string text;

  /// The confidence probability of the transcription.
  float prob = 0.0f;
};

namespace whisper_ros {

/// Class for performing speech-to-text transcription using the Whisper model.
///
/// NOTE(review): the class declares a destructor and holds a raw
/// `whisper_context *`, but declares no copy constructor or copy assignment.
/// The compiler-generated copies would duplicate the raw pointer; if the
/// destructor releases `ctx`, copying a `Whisper` leads to a double release.
/// Confirm ownership in the implementation and consider deleting the copy
/// operations.
class Whisper {

public:
/// Constructs a Whisper object with the specified model and parameters.
/// @param model The path to the Whisper model file.
/// @param openvino_encode_device The OpenVINO device used for encoding.
/// @param n_processors Number of processors to use for parallel processing.
/// @param cparams Whisper context parameters.
/// @param wparams Whisper full parameters.
Whisper(const std::string &model, const std::string &openvino_encode_device,
int n_processors, const struct whisper_context_params &cparams,
const struct whisper_full_params &wparams);

/// Destructor to clean up resources used by the Whisper object.
~Whisper();

/// Transcribes the given audio data.
/// Not declared `const`.
/// @param pcmf32 A vector of audio samples in 32-bit float format.
/// @return A TranscriptionOutput structure containing the transcription text
/// and confidence probability.
struct TranscriptionOutput transcribe(const std::vector<float> &pcmf32);

/// Trims leading and trailing whitespace from the input string.
/// NOTE(review): declared as a non-static, non-const member although its
/// input arrives through the parameter; it could likely be `static` --
/// confirm it does not touch member state.
/// @param s The string to trim.
/// @return The trimmed string.
std::string trim(const std::string &s);

/// Converts a timestamp to a string format.
/// NOTE(review): declared as a non-static, non-const member although its
/// inputs arrive through the parameters; it could likely be `static` --
/// confirm it does not touch member state.
/// @param t The timestamp in 10 ms units.
/// @param comma If true, use a comma as the decimal separator; otherwise, use
/// a period.
/// @return The formatted timestamp as a string.
std::string timestamp_to_str(int64_t t, bool comma = false);

/// Sets a grammar for transcription with a starting rule and penalty.
/// `grammar` and `start_rule` are taken by `const` value, so each call
/// copies both argument strings.
/// @param grammar The grammar rules as a string.
/// @param start_rule The starting rule for the grammar.
/// @param grammar_penalty A penalty factor for grammar violations.
/// @return True if the grammar is set successfully; false otherwise.
bool set_grammar(const std::string grammar, const std::string start_rule,
float grammar_penalty);

/// Resets the grammar to its default state.
void reset_grammar();

/// Sets an initial prompt for transcription.
/// `prompt` is taken by `const` value, so it is a copy that is destroyed when
/// the call returns.
/// NOTE(review): `whisper_full_params::initial_prompt` is a `const char *`;
/// if the implementation stores `prompt.c_str()` there, that pointer dangles
/// after this function returns -- verify in the implementation.
/// @param prompt The initial prompt text.
void set_init_prompt(const std::string prompt);

/// Resets the initial prompt to its default state.
void reset_init_prompt();

protected:
// The members below are `protected`, so derived classes can read and modify
// them directly.

/// Number of processors used for parallel processing.
int n_processors;

/// Parameters used for full transcription tasks.
struct whisper_full_params wparams;

/// The Whisper context.
/// Raw pointer with no default initializer.
/// NOTE(review): presumably created in the constructor and released in the
/// destructor -- confirm.
struct whisper_context *ctx;

/// Parsed grammar state.
grammar_parser::parse_state grammar_parsed;

/// Grammar rules derived from the parsed grammar.
/// Stores non-owning pointers.
/// NOTE(review): presumably they point into `grammar_parsed` and are only
/// valid while it is unchanged -- confirm.
std::vector<const whisper_grammar_element *> grammar_rules;
};

} // namespace whisper_ros

#endif
#endif
Loading

0 comments on commit 16bb406

Please sign in to comment.