From 163d9a24939098136177775e49b0e699870bb477 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 15:26:53 +0100 Subject: [PATCH 01/18] initial changes for silero-vad-cpp --- Dockerfile | 8 +- README.md | 1 - onnxruntime_vendor/CMakeLists.txt | 37 +++ onnxruntime_vendor/package.xml | 18 ++ requirements.txt | 4 - whisper_bringup/launch/silero-vad.launch.py | 98 +++++++ whisper_bringup/launch/whisper.launch.py | 76 ++++-- whisper_ros/CMakeLists.txt | 37 ++- .../include/silero_vad/silero_vad_node.hpp | 126 +++++++++ whisper_ros/include/silero_vad/timestamp.hpp | 46 ++++ .../include/silero_vad/vad_iterator.hpp | 94 +++++++ whisper_ros/package.xml | 2 + .../src/silero_vad/silero_vad_node.cpp | 250 ++++++++++++++++++ whisper_ros/src/silero_vad/timestamp.cpp | 46 ++++ whisper_ros/src/silero_vad/vad_iterator.cpp | 144 ++++++++++ whisper_ros/src/silero_vad_main.cpp | 40 +++ whisper_ros/src/whisper_main.cpp | 2 - .../src/whisper_ros/whisper_base_node.cpp | 4 +- whisper_ros/src/whisper_ros/whisper_node.cpp | 4 +- .../src/whisper_ros/whisper_server_node.cpp | 1 - whisper_ros/src/whisper_server_main.cpp | 2 - 21 files changed, 990 insertions(+), 50 deletions(-) create mode 100644 onnxruntime_vendor/CMakeLists.txt create mode 100644 onnxruntime_vendor/package.xml delete mode 100644 requirements.txt create mode 100644 whisper_bringup/launch/silero-vad.launch.py create mode 100644 whisper_ros/include/silero_vad/silero_vad_node.hpp create mode 100644 whisper_ros/include/silero_vad/timestamp.hpp create mode 100644 whisper_ros/include/silero_vad/vad_iterator.hpp create mode 100644 whisper_ros/src/silero_vad/silero_vad_node.cpp create mode 100644 whisper_ros/src/silero_vad/timestamp.cpp create mode 100644 whisper_ros/src/silero_vad/vad_iterator.cpp create mode 100644 whisper_ros/src/silero_vad_main.cpp diff --git a/Dockerfile b/Dockerfile index 9b5ae54..2c9d6de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,16 +11,12 @@ 
RUN apt-get update \ && apt-get -y --quiet --no-install-recommends install \ gcc \ git \ - wget \ - portaudio19-dev \ - python3 \ - python3-pip + curl WORKDIR /root/ros2_ws/src RUN git clone https://github.com/mgonzs13/audio_common.git -WORKDIR /root/ros2_ws -RUN pip3 install -r src/requirements.txt +WORKDIR /root/ros2_ws RUN rosdep install --from-paths src --ignore-src -r -y # Install CUDA nvcc diff --git a/README.md b/README.md index 8c1da74..70e5d91 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ To run whisper_ros with CUDA, first, you must install the [CUDA Toolkit](https:/ $ cd ~/ros2_ws/src $ git clone https://github.com/mgonzs13/audio_common.git $ git clone https://github.com/mgonzs13/whisper_ros.git -$ pip3 install -r whisper_ros/requirements.txt $ cd ~/ros2_ws $ rosdep install --from-paths src --ignore-src -r -y $ colcon build --cmake-args -DGGML_CUDA=ON # add this for CUDA diff --git a/onnxruntime_vendor/CMakeLists.txt b/onnxruntime_vendor/CMakeLists.txt new file mode 100644 index 0000000..a667079 --- /dev/null +++ b/onnxruntime_vendor/CMakeLists.txt @@ -0,0 +1,37 @@
+cmake_minimum_required(VERSION 3.8)
+project(onnxruntime_vendor)
+
+# Version and URL of the prebuilt ONNX Runtime release archive (linux-x64)
+set(ONNXRUNTIME_VERSION "1.18.1") # Specify the desired ONNX Runtime version
+set(ONNXRUNTIME_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz")
+
+# ROS 2 package configuration
+find_package(ament_cmake REQUIRED)
+
+# Directory the archive unpacks to inside the build tree
+set(ONNXRUNTIME_INSTALL_DIR "${CMAKE_BINARY_DIR}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}")
+
+# curl performs the download at build time
+find_program(CURL_EXECUTABLE curl REQUIRED)
+
+if(NOT CURL_EXECUTABLE)
+  message(FATAL_ERROR "curl is required to download ONNX Runtime but was not found.")
+endif()
+
+# Download and extract the ONNX Runtime archive as part of the default build
+add_custom_target(download_onnxruntime ALL
+  COMMENT "Downloading and extracting ONNX Runtime ${ONNXRUNTIME_VERSION}"
+  COMMAND ${CURL_EXECUTABLE} -L -o onnxruntime.tgz ${ONNXRUNTIME_URL} >/dev/null 2>&1
+  COMMAND ${CMAKE_COMMAND} -E tar xzf onnxruntime.tgz
+)
+
+# install(DIRECTORY) needs a DESTINATION; copy the contents into lib/ and include/
+install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/lib/ DESTINATION lib)
+install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/include/ DESTINATION include)
+
+# Export the installed include dir (not the build-tree path) and the library
+ament_export_include_directories(include)
+ament_export_libraries(onnxruntime)
+
+# Export the package
+ament_package()
\ No newline at end of file diff --git a/onnxruntime_vendor/package.xml b/onnxruntime_vendor/package.xml new file mode 100644 index 0000000..f30a7cc --- /dev/null +++ b/onnxruntime_vendor/package.xml @@ -0,0 +1,18 @@ + + + + onnxruntime_vendor + 1.3.1 + Vendor package for onnxruntime + Miguel Ángel González Santamarta + MIT + + ament_cmake + + ament_lint_auto + ament_lint_common + + + ament_cmake + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 8207fe4..0000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -onnxruntime==1.18.1 -huggingface-hub==0.23.4 -silero-vad==5.1 -pyaudio==0.2.14 \ No newline at end of file diff --git a/whisper_bringup/launch/silero-vad.launch.py b/whisper_bringup/launch/silero-vad.launch.py new file mode 100644 index 0000000..9f08d6f --- /dev/null +++ b/whisper_bringup/launch/silero-vad.launch.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2023 Miguel Ángel González Santamarta + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom
the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.
+
+
+from launch_ros.actions import Node
+from launch import LaunchDescription, LaunchContext
+from launch.substitutions import LaunchConfiguration
+from launch.actions import OpaqueFunction, DeclareLaunchArgument
+from huggingface_hub import hf_hub_download
+
+
+def generate_launch_description():
+    # Launches silero_vad_node; the model path is resolved when the launch runs.
+    def run_silero_vad(context: LaunchContext, repo, file, model_path):
+        repo = str(context.perform_substitution(repo))
+        file = str(context.perform_substitution(file))
+        model_path = str(context.perform_substitution(model_path))
+        # An empty model_path means: fetch repo/file via hf_hub_download.
+        if not model_path:
+            model_path = hf_hub_download(
+                repo_id=repo, filename=file, force_download=False
+            )
+        # The node action is handed back to OpaqueFunction wrapped in a 1-tuple.
+        return (
+            Node(
+                package="whisper_ros",
+                executable="silero_vad_node",
+                name="silero_vad_node",
+                namespace="whisper",
+                parameters=[
+                    {
+                        "enabled": LaunchConfiguration("enabled", default=True),
+                        "model_path": model_path,
+                        "sample_rate": LaunchConfiguration("sample_rate", default=16000),
+                        "frame_size_ms": LaunchConfiguration("frame_size_ms", default=32),
+                        "threshold": LaunchConfiguration("threshold", default=0.5),
+                        "min_silence_ms": LaunchConfiguration(
+                            "min_silence_ms", default=0
+                        ),
+                        "speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=32),
+                        "min_speech_ms": LaunchConfiguration("min_speech_ms", default=32),
+                        "max_speech_s": LaunchConfiguration(
+                            "max_speech_s", default=float("inf")
+                        ),
+                    }
+                ],
+                remappings=[("audio", "/audio/in")],
+            ),
+        )
+    # Launch arguments selecting the model: Hub repo + filename, or a local path.
+    model_repo = LaunchConfiguration("model_repo")
+    model_repo_cmd = DeclareLaunchArgument(
+        "model_repo",
+        default_value="deepghs/silero-vad-onnx",
+        description="Hugging Face model repo",
+    )
+
+    model_filename = LaunchConfiguration("model_filename")
+    model_filename_cmd = DeclareLaunchArgument(
+        "model_filename",
+        default_value="silero_vad.onnx",
+        description="Hugging Face model filename",
+    )
+
+    model_path = LaunchConfiguration("model_path")
+    model_path_cmd = DeclareLaunchArgument(
+        "model_path", default_value="", description="Local path to the model file"
+    )
+
+    return LaunchDescription(
+        [
+            model_repo_cmd,
+            model_filename_cmd,
+            model_path_cmd,
+            OpaqueFunction(
+                function=run_silero_vad, args=[model_repo, model_filename, model_path]
+            ),
+        ]
+    )
diff --git a/whisper_bringup/launch/whisper.launch.py b/whisper_bringup/launch/whisper.launch.py index f470cd4..d38b555 100644 --- a/whisper_bringup/launch/whisper.launch.py +++ b/whisper_bringup/launch/whisper.launch.py @@ -21,12 +21,15 @@ # SOFTWARE.
-from launch import LaunchDescription, LaunchContext +import os from launch_ros.actions import Node +from launch import LaunchDescription, LaunchContext +from launch.conditions import IfCondition, UnlessCondition from launch.substitutions import LaunchConfiguration, PythonExpression -from launch.actions import OpaqueFunction, DeclareLaunchArgument +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.actions import OpaqueFunction, DeclareLaunchArgument, IncludeLaunchDescription +from ament_index_python.packages import get_package_share_directory from huggingface_hub import hf_hub_download -from launch.conditions import IfCondition, UnlessCondition def generate_launch_description(): @@ -126,19 +129,42 @@ def run_whisper(context: LaunchContext, repo, file, model_path): model_repo_cmd = DeclareLaunchArgument( "model_repo", default_value="ggerganov/whisper.cpp", - description="Hugging Face model repo", + description="Hugging Face model repo for Whisper", ) model_filename = LaunchConfiguration("model_filename") model_filename_cmd = DeclareLaunchArgument( "model_filename", default_value="ggml-large-v3-turbo-q5_0.bin", - description="Hugging Face model filename", + description="Hugging Face model filename for Whisper", ) model_path = LaunchConfiguration("model_path") model_path_cmd = DeclareLaunchArgument( - "model_path", default_value="", description="Local path to the model file" + "model_path", + default_value="", + description="Local path to the model file for Whisper", + ) + + silero_vad_model_repo = LaunchConfiguration("silero_vad_model_repo") + silero_vad_model_repo_cmd = DeclareLaunchArgument( + "silero_vad_model_repo", + default_value="onnx-community/silero-vad", + description="Hugging Face model repo for SileroVAD", + ) + + silero_vad_model_filename = LaunchConfiguration("silero_vad_model_filename") + silero_vad_model_filename_cmd = DeclareLaunchArgument( + "silero_vad_model_filename", + default_value="onnx/model.onnx", + 
description="Hugging Face model filename for SileroVAD", + ) + + silero_vad_model_path = LaunchConfiguration("silero_vad_model_path") + silero_vad_model_path_cmd = DeclareLaunchArgument( + "silero_vad_model_path", + default_value="", + description="Local path to the model file for SileroVAD", ) return LaunchDescription( @@ -147,24 +173,30 @@ def run_whisper(context: LaunchContext, repo, file, model_path): model_repo_cmd, model_filename_cmd, model_path_cmd, + silero_vad_model_repo_cmd, + silero_vad_model_filename_cmd, + silero_vad_model_path_cmd, OpaqueFunction( - function=run_whisper, args=[model_repo, model_filename, model_path] + function=run_whisper, + args=[model_repo, model_filename, model_path], ), - Node( - package="whisper_ros", - executable="silero_vad_node", - name="silero_vad_node", - namespace="whisper", - parameters=[ - { - "enabled": LaunchConfiguration( - "vad_enabled", - default=PythonExpression([LaunchConfiguration("stream")]), - ), - "threshold": LaunchConfiguration("vad_threshold", default=0.5), - } - ], - remappings=[("audio", "/audio/in")], + IncludeLaunchDescription( + PythonLaunchDescriptionSource( + os.path.join( + get_package_share_directory("whisper_bringup"), + "launch", + "silero-vad.launch.py", + ) + ), + launch_arguments={ + "enabled": LaunchConfiguration( + "vad_enabled", + default=PythonExpression([LaunchConfiguration("stream")]), + ), + "model_repo": silero_vad_model_repo, + "model_filename": silero_vad_model_filename, + "model_path": silero_vad_model_path, + }.items(), ), Node( package="audio_common", diff --git a/whisper_ros/CMakeLists.txt b/whisper_ros/CMakeLists.txt index 5a2bc13..78045b8 100644 --- a/whisper_ros/CMakeLists.txt +++ b/whisper_ros/CMakeLists.txt @@ -12,11 +12,18 @@ find_package(rclcpp_action REQUIRED) find_package(rclcpp_lifecycle REQUIRED) find_package(std_msgs REQUIRED) find_package(std_srvs REQUIRED) +find_package(audio_common_msgs REQUIRED) find_package(whisper_msgs REQUIRED) find_package(whisper_cpp_vendor 
REQUIRED) +find_package(onnxruntime_vendor REQUIRED) +find_library(PORTAUDIO_LIB portaudio REQUIRED) -include_directories(include) +include_directories( + include + ${PORTAUDIO_INCLUDE_DIR} +) +# whisper_node add_executable(whisper_node src/whisper_main.cpp src/whisper_ros/whisper_node.cpp @@ -36,6 +43,7 @@ ament_target_dependencies(whisper_node whisper_cpp_vendor ) +# whisper_server_node add_executable(whisper_server_node src/whisper_server_main.cpp src/whisper_ros/whisper_server_node.cpp @@ -56,10 +64,28 @@ ament_target_dependencies(whisper_server_node whisper_cpp_vendor ) -ament_export_dependencies(whisper_cpp_vendor) +# silero_vad_node +add_executable(silero_vad_node + src/silero_vad_main.cpp + src/silero_vad/silero_vad_node.cpp + src/silero_vad/vad_iterator.cpp + src/silero_vad/timestamp.cpp +) +target_link_libraries(silero_vad_node ${PORTAUDIO_LIB}) +ament_target_dependencies(silero_vad_node + rclcpp + rclcpp_lifecycle + std_msgs + std_srvs + audio_common_msgs + onnxruntime_vendor +) +# Export dependencies +ament_export_dependencies(whisper_cpp_vendor) +ament_export_dependencies(onnxruntime_vendor) -# INSTALL +# Install install(TARGETS whisper_node DESTINATION lib/${PROJECT_NAME} @@ -70,10 +96,9 @@ install(TARGETS DESTINATION lib/${PROJECT_NAME} ) -install(PROGRAMS - whisper_ros/silero_vad_node.py +install(TARGETS + silero_vad_node DESTINATION lib/${PROJECT_NAME} - RENAME silero_vad_node ) ament_package() diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp new file mode 100644 index 0000000..4c9f376 --- /dev/null +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -0,0 +1,126 @@ +// MIT License + +// Copyright (c) 2024 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the 
rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef SILERO_VAD_NODE_HPP +#define SILERO_VAD_NODE_HPP + +#include +#include +#include +#include + +#include "audio_common_msgs/msg/audio_stamped.hpp" +#include "std_msgs/msg/float32_multi_array.hpp" +#include "std_srvs/srv/set_bool.hpp" + +#include "silero_vad/vad_iterator.hpp" + +namespace silero_vad { + +class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { + +public: + SileroVadNode(); + + rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn + on_configure(const rclcpp_lifecycle::State &); + rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn + on_activate(const rclcpp_lifecycle::State &); + rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn + on_deactivate(const rclcpp_lifecycle::State &); + rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn + on_cleanup(const rclcpp_lifecycle::State &); + rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn + on_shutdown(const rclcpp_lifecycle::State &); + +protected: + bool enabled; + std::atomic listening; + std::vector data; + std::unique_ptr vad_iterator; + +private: + 
std::string model_path_; + int sample_rate_; + int frame_size_ms_; + float threshold_; + int min_silence_ms_; + int speech_pad_ms_; + int min_speech_ms_; + float max_speech_s_; + + rclcpp::Publisher::SharedPtr publisher_; + rclcpp::Subscription::SharedPtr + subscription_; + + rclcpp::Service::SharedPtr enable_srv_; + + void + audio_callback(const audio_common_msgs::msg::AudioStamped::SharedPtr msg); + + void enable_cb(const std::shared_ptr request, + std::shared_ptr response); + + template + std::vector convert_to_float(const std::vector &input) { + static_assert(std::is_integral::value, + "Input type must be an integral type."); + + std::vector output; + output.reserve(input.size()); + + if constexpr (std::is_same::value) { + // uint8_t data normalized to [0.0, 1.0] + for (T value : input) { + output.push_back(static_cast(value) / 255.0f); + } + } else if constexpr (std::is_same::value) { + // int8_t data normalized to [-1.0, 1.0] + constexpr float scale = 1.0f / 127.0f; // Max positive value of int8_t + for (T value : input) { + output.push_back(static_cast(value) * scale); + } + } else if constexpr (std::is_same::value) { + // int16_t data normalized to [-1.0, 1.0] + constexpr float scale = 1.0f / 32767.0f; // Max positive value of int16_t + for (T value : input) { + output.push_back(static_cast(value) * scale); + } + } else if constexpr (std::is_same::value) { + // int32_t data normalized to [-1.0, 1.0] + constexpr float scale = + 1.0f / 2147483647.0f; // Max positive value of int32_t + for (T value : input) { + output.push_back(static_cast(value) * scale); + } + } else { + throw std::invalid_argument( + "Unsupported data type for audio conversion."); + } + + return output; + } +}; + +} // namespace silero_vad + +#endif \ No newline at end of file diff --git a/whisper_ros/include/silero_vad/timestamp.hpp b/whisper_ros/include/silero_vad/timestamp.hpp new file mode 100644 index 0000000..030a8bc --- /dev/null +++ b/whisper_ros/include/silero_vad/timestamp.hpp @@ 
-0,0 +1,46 @@ +// MIT License + +// Copyright (c) 2024 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifndef TIMESTAMPT_HPP +#define TIMESTAMPT_HPP + +#include + +namespace silero_vad { + +class Timestamp { +public: + int start; + int end; + float speech_prob; + + Timestamp(int start = -1, int end = -1, float speech_prob = 0); + + Timestamp &operator=(const Timestamp &other); + bool operator==(const Timestamp &other) const; + + std::string to_string() const; +}; + +} // namespace silero_vad + +#endif \ No newline at end of file diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp new file mode 100644 index 0000000..61a9943 --- /dev/null +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -0,0 +1,94 @@ +// MIT License + +// Copyright (c) 2024 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifndef VAD_ITERATOR_HPP +#define VAD_ITERATOR_HPP + +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" +#include "silero_vad/timestamp.hpp" + +namespace silero_vad { + +class VadIterator { + +public: + VadIterator(const std::string &model_path, int sample_rate = 16000, + int frame_size_ms = 32, float threshold = 0.5f, + int min_silence_ms = 0, int speech_pad_ms = 32, + int min_speech_ms = 32, + float max_speech_s = std::numeric_limits::infinity()); + + void reset_states(); + Timestamp predict(const std::vector &data); + +private: + Ort::Env env; + Ort::SessionOptions session_options; + std::shared_ptr session; + Ort::AllocatorWithDefaultOptions allocator; + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); + + // Model configuration + float threshold; + int sample_rate; + int sr_per_ms; + int64_t window_size_samples; + int min_speech_samples; + int speech_pad_samples; + float max_speech_samples; + unsigned int min_silence_samples; + unsigned int min_silence_samples_at_max_speech; + + // Model state + bool triggered = false; + unsigned int temp_end = 0; + unsigned int current_sample = 0; + int prev_end = 0; + int next_start = 0; + + std::vector ort_inputs; + std::vector input_node_names = {"input", "state", "sr"}; + + std::vector input; + std::vector state; + std::vector sr; + std::vector context; + + int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = {2, 1, 128}; + const int64_t sr_node_dims[1] = {1}; + + std::vector ort_outputs; + std::vector output_node_names = {"output", "stateN"}; + + void init_onnx_model(const std::string &model_path); +}; + +} // namespace silero_vad + +#endif \ No newline at end of file diff --git a/whisper_ros/package.xml b/whisper_ros/package.xml index 08e75b2..7fed619 100644 --- a/whisper_ros/package.xml +++ b/whisper_ros/package.xml @@ -21,6 +21,8 @@ audio_common_msgs whisper_msgs whisper_cpp_vendor + onnxruntime_vendor + portaudio19-dev 
ament_cmake diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp new file mode 100644 index 0000000..b56b4d5 --- /dev/null +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -0,0 +1,250 @@ +// MIT License + +// Copyright (c) 2024 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include +#include +#include +#include +#include + +#include + +#include "silero_vad/silero_vad_node.hpp" + +using namespace silero_vad; +using std::placeholders::_1; +using std::placeholders::_2; + +SileroVadNode::SileroVadNode() + : rclcpp_lifecycle::LifecycleNode("silero_vad_node"), listening(false) { + + this->declare_parameter("enabled", true); + this->declare_parameter("model_path", ""); + this->declare_parameter("sample_rate", 16000); + this->declare_parameter("frame_size_ms", 32); + this->declare_parameter("threshold", 0.5f); + this->declare_parameter("min_silence_ms", 100); + this->declare_parameter("speech_pad_ms", 30); + this->declare_parameter("min_speech_ms", 32); + this->declare_parameter("max_speech_s", + std::numeric_limits::infinity()); +} + +rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn +SileroVadNode::on_configure(const rclcpp_lifecycle::State &) { + + RCLCPP_INFO(get_logger(), "[%s] Configuring...", this->get_name()); + + // get params + this->get_parameter("enabled", this->enabled); + this->get_parameter("model_path", this->model_path_); + this->get_parameter("sample_rate", this->sample_rate_); + this->get_parameter("frame_size_ms", this->frame_size_ms_); + this->get_parameter("threshold", this->threshold_); + this->get_parameter("min_silence_ms", this->min_silence_ms_); + this->get_parameter("speech_pad_ms", this->speech_pad_ms_); + this->get_parameter("min_speech_ms", this->min_speech_ms_); + this->get_parameter("max_speech_s", this->max_speech_s_); + + RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name()); + + return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface:: + CallbackReturn::SUCCESS; +} + +rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn +SileroVadNode::on_activate(const rclcpp_lifecycle::State &) { + + RCLCPP_INFO(get_logger(), "[%s] Activating...", this->get_name()); + + // create silero-vad + this->vad_iterator = std::make_unique( + this->model_path_, 
this->sample_rate_, this->frame_size_ms_, + this->threshold_, this->min_silence_ms_, this->speech_pad_ms_, + this->min_speech_ms_, this->max_speech_s_); + + this->publisher_ = + this->create_publisher("vad", 10); + this->subscription_ = + this->create_subscription( + "audio", rclcpp::SensorDataQoS(), + std::bind(&SileroVadNode::audio_callback, this, _1)); + + this->enable_srv_ = this->create_service( + "enable_vad", std::bind(&SileroVadNode::enable_cb, this, _1, _2)); + + RCLCPP_INFO(get_logger(), "[%s] Activated", this->get_name()); + + return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface:: + CallbackReturn::SUCCESS; +} + +rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn +SileroVadNode::on_deactivate(const rclcpp_lifecycle::State &) { + + RCLCPP_INFO(get_logger(), "[%s] Deactivating...", this->get_name()); + + // reset silero + this->vad_iterator->reset_states(); + + this->publisher_.reset(); + this->publisher_ = nullptr; + + this->subscription_.reset(); + this->subscription_ = nullptr; + + this->enable_srv_.reset(); + this->enable_srv_ = nullptr; + + RCLCPP_INFO(get_logger(), "[%s] Deactivated", this->get_name()); + + return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface:: + CallbackReturn::SUCCESS; +} + +rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn +SileroVadNode::on_cleanup(const rclcpp_lifecycle::State &) { + + RCLCPP_INFO(get_logger(), "[%s] Cleaning up...", this->get_name()); + RCLCPP_INFO(get_logger(), "[%s] Cleaned up", this->get_name()); + + return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface:: + CallbackReturn::SUCCESS; +} + +rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn +SileroVadNode::on_shutdown(const rclcpp_lifecycle::State &) { + + RCLCPP_INFO(get_logger(), "[%s] Shutting down...", this->get_name()); + RCLCPP_INFO(get_logger(), "[%s] Shutted down", this->get_name()); + + return 
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface:: + CallbackReturn::SUCCESS; +} + +void SileroVadNode::audio_callback( + const audio_common_msgs::msg::AudioStamped::SharedPtr msg) { + + if (!this->enabled) { + return; + } + + std::vector data; + + try { + switch (msg->audio.info.format) { + case paFloat32: { + data = msg->audio.audio_data.float32_data; + break; + } + case paInt32: { + data = this->convert_to_float(msg->audio.audio_data.int32_data); + break; + } + case paInt16: { + data = this->convert_to_float(msg->audio.audio_data.int16_data); + break; + } + case paInt8: { + data = this->convert_to_float(msg->audio.audio_data.int8_data); + break; + } + case paUInt8: { + data = this->convert_to_float(msg->audio.audio_data.uint8_data); + break; + } + default: + RCLCPP_ERROR(this->get_logger(), "Unsupported format"); + return; + } + + } catch (const std::exception &e) { + RCLCPP_ERROR(this->get_logger(), "Error while processing audio data: %s", + e.what()); + return; + } + + // Predict if speech starts or ends + auto timestamp = this->vad_iterator->predict(data); + // RCLCPP_INFO(this->get_logger(), "Timestampt: %s", + // timestamp.to_string().c_str()); + + // Check if speech starts + if (timestamp.start != -1 && timestamp.end == -1 && !this->listening) { + RCLCPP_INFO(this->get_logger(), "Speech starts..."); + this->listening.store(true); + this->data.clear(); + } + + // Add audio if listening + if (this->listening) { + for (auto d : data) { + this->data.push_back(d); + } + } + + // Check if speech ends + if (timestamp.start == -1 && timestamp.end != -1 && this->listening) { + RCLCPP_INFO(this->get_logger(), "Speech ends..."); + + if (this->data.size() / msg->audio.info.rate < 1.0) { + int pad_size = + msg->audio.info.chunk + msg->audio.info.rate - this->data.size(); + for (int i = 0; i < pad_size; i++) { + this->data.push_back(0.0); + } + } + + this->listening.store(false); + auto vad_msg = std_msgs::msg::Float32MultiArray(); + vad_msg.data = this->data; 
+ this->publisher_->publish(vad_msg); + this->data.clear(); + } +} + +void SileroVadNode::enable_cb( + const std::shared_ptr request, + std::shared_ptr response) { + + response->success = true; + + if (request->data && !this->enabled) { + response->message = "SileroVAD enabled"; + this->enabled = true; + this->data.clear(); + this->vad_iterator->reset_states(); + + } else if (request->data && this->enabled) { + response->message = "SileroVAD already enabled"; + + } else if (!request->data && this->enabled) { + response->message = "SileroVAD disabled"; + this->listening.store(false); + this->data.clear(); + + } else if (!request->data && !this->enabled) { + response->message = "SileroVAD already disabled"; + } + + RCLCPP_INFO(this->get_logger(), response->message.c_str()); +} diff --git a/whisper_ros/src/silero_vad/timestamp.cpp b/whisper_ros/src/silero_vad/timestamp.cpp new file mode 100644 index 0000000..427a2d3 --- /dev/null +++ b/whisper_ros/src/silero_vad/timestamp.cpp @@ -0,0 +1,46 @@ +// MIT License + +// Copyright (c) 2023 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "silero_vad/timestamp.hpp" + +using namespace silero_vad; + +Timestamp::Timestamp(int start, int end, float speech_prob) + : start(start), end(end), speech_prob(speech_prob) {} + +Timestamp &Timestamp::operator=(const Timestamp &other) { + this->start = other.start; + this->end = other.end; + this->speech_prob = other.speech_prob; + return *this; +} + +bool Timestamp::operator==(const Timestamp &other) const { + return this->start == other.start && this->end == other.end; +} + +std::string Timestamp::to_string() const { + char buffer[256]; + snprintf(buffer, sizeof(buffer), "{start:%08d,end:%08d,prob:%f}", this->start, + this->end, this->speech_prob * 100); + return std::string(buffer); +} \ No newline at end of file diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp new file mode 100644 index 0000000..27853d7 --- /dev/null +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -0,0 +1,144 @@ +// MIT License + +// Copyright (c) 2023 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+ +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include + +#include "silero_vad/vad_iterator.hpp" + +using namespace silero_vad; + +VadIterator::VadIterator(const std::string &model_path, int sample_rate, + int frame_size_ms, float threshold, int min_silence_ms, + int speech_pad_ms, int min_speech_ms, + float max_speech_s) + : env(ORT_LOGGING_LEVEL_WARNING, "VadIterator"), threshold(threshold), + sample_rate(sample_rate), sr_per_ms(sample_rate / 1000), + window_size_samples(frame_size_ms * sr_per_ms), + min_speech_samples(sr_per_ms * min_speech_ms), + speech_pad_samples(sr_per_ms * speech_pad_ms), + max_speech_samples(sample_rate * max_speech_s - window_size_samples - + 2 * speech_pad_samples), + min_silence_samples(sr_per_ms * min_silence_ms), + min_silence_samples_at_max_speech(sr_per_ms * 98), + state(2 * 1 * 128, 0.0f), sr(1, sample_rate), context(64, 0.0f) { + + // this->input.resize(window_size_samples); + this->input_node_dims[0] = 1; + this->input_node_dims[1] = window_size_samples; + this->init_onnx_model(model_path); +} + +void VadIterator::init_onnx_model(const std::string &model_path) { + this->session_options.SetIntraOpNumThreads(1); + this->session_options.SetInterOpNumThreads(1); + this->session_options.SetGraphOptimizationLevel( + GraphOptimizationLevel::ORT_ENABLE_ALL); + this->session = std::make_shared(this->env, model_path.c_str(), + this->session_options); +} + +void VadIterator::reset_states() { + std::fill(this->state.begin(), this->state.end(), 0.0f); + 
std::fill(this->context.begin(), this->context.end(), 0.0f); + this->triggered = false; + this->temp_end = 0; + this->current_sample = 0; +} + +Timestamp VadIterator::predict(const std::vector &data) { + // Create input tensors + this->input.clear(); + for (auto ele : this->context) { + this->input.push_back(ele); + } + + for (auto ele : data) { + this->input.push_back(ele); + } + + Ort::Value input_tensor = Ort::Value::CreateTensor( + this->memory_info, this->input.data(), this->input.size(), + this->input_node_dims, 2); + Ort::Value state_tensor = Ort::Value::CreateTensor( + this->memory_info, this->state.data(), this->state.size(), + this->state_node_dims, 3); + Ort::Value sr_tensor = + Ort::Value::CreateTensor(this->memory_info, this->sr.data(), + this->sr.size(), this->sr_node_dims, 1); + + // Clear and add inputs + this->ort_inputs.clear(); + this->ort_inputs.emplace_back(std::move(input_tensor)); + this->ort_inputs.emplace_back(std::move(state_tensor)); + this->ort_inputs.emplace_back(std::move(sr_tensor)); + + // Run inference + this->ort_outputs = this->session->Run( + Ort::RunOptions{nullptr}, this->input_node_names.data(), + this->ort_inputs.data(), this->ort_inputs.size(), + this->output_node_names.data(), this->output_node_names.size()); + + // Process output + float speech_prob = this->ort_outputs[0].GetTensorMutableData()[0]; + float *updated_state = this->ort_outputs[1].GetTensorMutableData(); + std::copy(updated_state, updated_state + this->state.size(), + this->state.begin()); + + for (int i = 63; i >= 0; i--) { + this->context.push_back(data.at(data.size() - i)); + } + + // Handle result + this->current_sample += this->window_size_samples; + + if (speech_prob >= this->threshold) { + if (this->temp_end != 0) { + this->temp_end = 0; + } + + if (!this->triggered) { + this->triggered = true; + return Timestamp(this->current_sample - this->speech_pad_samples - + this->window_size_samples, + -1, speech_prob); + } + + } else if (speech_prob < 
this->threshold - 0.15 && this->triggered) { + if (this->temp_end == 0) { + this->temp_end = this->current_sample; + } + + if (this->current_sample - this->temp_end >= this->min_silence_samples) { + this->temp_end = 0; + this->triggered = false; + return Timestamp(-1, + this->temp_end + this->speech_pad_samples - + this->window_size_samples, + speech_prob); + } + } + + return Timestamp(-1, -1, speech_prob); +} diff --git a/whisper_ros/src/silero_vad_main.cpp b/whisper_ros/src/silero_vad_main.cpp new file mode 100644 index 0000000..982ef01 --- /dev/null +++ b/whisper_ros/src/silero_vad_main.cpp @@ -0,0 +1,40 @@ +// MIT License + +// Copyright (c) 2023 Miguel Ángel González Santamarta + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "silero_vad/silero_vad_node.hpp" + +using namespace silero_vad; + +int main(int argc, char *argv[]) { + rclcpp::init(argc, argv); + + auto node = std::make_shared(); + node->configure(); + node->activate(); + + rclcpp::executors::SingleThreadedExecutor executor; + executor.add_node(node->get_node_base_interface()); + executor.spin(); + + rclcpp::shutdown(); + return 0; +} \ No newline at end of file diff --git a/whisper_ros/src/whisper_main.cpp b/whisper_ros/src/whisper_main.cpp index 6761459..55f0f53 100644 --- a/whisper_ros/src/whisper_main.cpp +++ b/whisper_ros/src/whisper_main.cpp @@ -20,8 +20,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include - #include "whisper_ros/whisper_node.hpp" using namespace whisper_ros; diff --git a/whisper_ros/src/whisper_ros/whisper_base_node.cpp b/whisper_ros/src/whisper_ros/whisper_base_node.cpp index fa38a4b..907d023 100644 --- a/whisper_ros/src/whisper_ros/whisper_base_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_base_node.cpp @@ -20,10 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include - -#include "whisper.h" #include "whisper_ros/whisper_base_node.hpp" +#include "whisper.h" using namespace whisper_ros; using std::placeholders::_1; diff --git a/whisper_ros/src/whisper_ros/whisper_node.cpp b/whisper_ros/src/whisper_ros/whisper_node.cpp index 2c3ec9d..ecc1f7f 100644 --- a/whisper_ros/src/whisper_ros/whisper_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_node.cpp @@ -20,10 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include - -#include "whisper.h" #include "whisper_ros/whisper_node.hpp" +#include "whisper.h" using namespace whisper_ros; using std::placeholders::_1; diff --git a/whisper_ros/src/whisper_ros/whisper_server_node.cpp b/whisper_ros/src/whisper_ros/whisper_server_node.cpp index 720cb43..4339df4 100644 --- a/whisper_ros/src/whisper_ros/whisper_server_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_server_node.cpp @@ -21,7 +21,6 @@ // SOFTWARE. #include -#include #include "whisper.h" #include "whisper_ros/whisper_server_node.hpp" diff --git a/whisper_ros/src/whisper_server_main.cpp b/whisper_ros/src/whisper_server_main.cpp index 8a5ea46..14627e2 100644 --- a/whisper_ros/src/whisper_server_main.cpp +++ b/whisper_ros/src/whisper_server_main.cpp @@ -20,8 +20,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include - #include "whisper_ros/whisper_server_node.hpp" using namespace whisper_ros; From fef6b9e31f6714537a4cd5f199a853e8c810acd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 16:48:14 +0100 Subject: [PATCH 02/18] minor fixes for compile onnxruntime and run silero --- onnxruntime_vendor/CMakeLists.txt | 2 +- whisper_ros/src/silero_vad/vad_iterator.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime_vendor/CMakeLists.txt b/onnxruntime_vendor/CMakeLists.txt index a667079..2c74212 100644 --- a/onnxruntime_vendor/CMakeLists.txt +++ b/onnxruntime_vendor/CMakeLists.txt @@ -30,7 +30,7 @@ install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/lib) install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/include) # Export the onnxruntime library for downstream packages -ament_export_include_directories(${ONNXRUNTIME_INSTALL_DIR}/include) +ament_export_include_directories(include) ament_export_libraries(onnxruntime) # Export the package diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index 
27853d7..5500c78 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -106,7 +106,7 @@ Timestamp VadIterator::predict(const std::vector &data) { std::copy(updated_state, updated_state + this->state.size(), this->state.begin()); - for (int i = 63; i >= 0; i--) { + for (int i = 64; i > 0; i--) { this->context.push_back(data.at(data.size() - i)); } From dd32b5eac641febbe40230fb09446e33086ab2e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 16:53:07 +0100 Subject: [PATCH 03/18] DESTINATION added to install --- onnxruntime_vendor/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime_vendor/CMakeLists.txt b/onnxruntime_vendor/CMakeLists.txt index 2c74212..a3ddd69 100644 --- a/onnxruntime_vendor/CMakeLists.txt +++ b/onnxruntime_vendor/CMakeLists.txt @@ -26,8 +26,8 @@ add_custom_target(download_onnxruntime ALL ) # Install the ONNX Runtime library and include files -install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/lib) -install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/include) +install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/lib DESTINATION .) +install(DIRECTORY ${ONNXRUNTIME_INSTALL_DIR}/include DESTINATION .) 
# Export the onnxruntime library for downstream packages ament_export_include_directories(include) From 870c6558003ad93feb3d20946956c3e1b73bd28d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 18:37:35 +0100 Subject: [PATCH 04/18] removing unused params form vad_iterator --- whisper_bringup/launch/silero-vad.launch.py | 8 +- .../include/silero_vad/silero_vad_node.hpp | 2 - .../include/silero_vad/vad_iterator.hpp | 10 +-- .../src/silero_vad/silero_vad_node.cpp | 14 +--- whisper_ros/src/silero_vad/timestamp.cpp | 2 +- whisper_ros/src/silero_vad/vad_iterator.cpp | 77 ++++++++++--------- 6 files changed, 50 insertions(+), 63 deletions(-) diff --git a/whisper_bringup/launch/silero-vad.launch.py b/whisper_bringup/launch/silero-vad.launch.py index 9f08d6f..c13633a 100644 --- a/whisper_bringup/launch/silero-vad.launch.py +++ b/whisper_bringup/launch/silero-vad.launch.py @@ -54,13 +54,9 @@ def run_silero_vad(context: LaunchContext, repo, file, model_path): "frame_size_ms": LaunchConfiguration("frame_size_ms", default=32), "threshold": LaunchConfiguration("threshold", default=0.5), "min_silence_ms": LaunchConfiguration( - "min_silence_ms", default=0 - ), - "speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=32), - "min_speech_ms": LaunchConfiguration("min_speech_ms", default=32), - "max_speech_s": LaunchConfiguration( - "max_speech_s", default=float("inf") + "min_silence_ms", default=100 ), + "speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=30), } ], remappings=[("audio", "/audio/in")], diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index 4c9f376..7cb6a82 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -65,8 +65,6 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { float threshold_; int min_silence_ms_; int speech_pad_ms_; - int 
min_speech_ms_; - float max_speech_s_; rclcpp::Publisher::SharedPtr publisher_; rclcpp::Subscription::SharedPtr diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp index 61a9943..b6f2911 100644 --- a/whisper_ros/include/silero_vad/vad_iterator.hpp +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -38,9 +38,7 @@ class VadIterator { public: VadIterator(const std::string &model_path, int sample_rate = 16000, int frame_size_ms = 32, float threshold = 0.5f, - int min_silence_ms = 0, int speech_pad_ms = 32, - int min_speech_ms = 32, - float max_speech_s = std::numeric_limits::infinity()); + int min_silence_ms = 100, int speech_pad_ms = 30); void reset_states(); Timestamp predict(const std::vector &data); @@ -58,11 +56,9 @@ class VadIterator { int sample_rate; int sr_per_ms; int64_t window_size_samples; - int min_speech_samples; int speech_pad_samples; - float max_speech_samples; unsigned int min_silence_samples; - unsigned int min_silence_samples_at_max_speech; + int context_size; // Model state bool triggered = false; @@ -75,9 +71,9 @@ class VadIterator { std::vector input_node_names = {"input", "state", "sr"}; std::vector input; + std::vector context; std::vector state; std::vector sr; - std::vector context; int64_t input_node_dims[2] = {}; const int64_t state_node_dims[3] = {2, 1, 128}; diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp index b56b4d5..e04721e 100644 --- a/whisper_ros/src/silero_vad/silero_vad_node.cpp +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -44,9 +44,6 @@ SileroVadNode::SileroVadNode() this->declare_parameter("threshold", 0.5f); this->declare_parameter("min_silence_ms", 100); this->declare_parameter("speech_pad_ms", 30); - this->declare_parameter("min_speech_ms", 32); - this->declare_parameter("max_speech_s", - std::numeric_limits::infinity()); } 
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn @@ -62,8 +59,6 @@ SileroVadNode::on_configure(const rclcpp_lifecycle::State &) { this->get_parameter("threshold", this->threshold_); this->get_parameter("min_silence_ms", this->min_silence_ms_); this->get_parameter("speech_pad_ms", this->speech_pad_ms_); - this->get_parameter("min_speech_ms", this->min_speech_ms_); - this->get_parameter("max_speech_s", this->max_speech_s_); RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name()); @@ -79,8 +74,7 @@ SileroVadNode::on_activate(const rclcpp_lifecycle::State &) { // create silero-vad this->vad_iterator = std::make_unique( this->model_path_, this->sample_rate_, this->frame_size_ms_, - this->threshold_, this->min_silence_ms_, this->speech_pad_ms_, - this->min_speech_ms_, this->max_speech_s_); + this->threshold_, this->min_silence_ms_, this->speech_pad_ms_); this->publisher_ = this->create_publisher("vad", 10); @@ -185,8 +179,6 @@ void SileroVadNode::audio_callback( // Predict if speech starts or ends auto timestamp = this->vad_iterator->predict(data); - // RCLCPP_INFO(this->get_logger(), "Timestampt: %s", - // timestamp.to_string().c_str()); // Check if speech starts if (timestamp.start != -1 && timestamp.end == -1 && !this->listening) { @@ -209,9 +201,7 @@ void SileroVadNode::audio_callback( if (this->data.size() / msg->audio.info.rate < 1.0) { int pad_size = msg->audio.info.chunk + msg->audio.info.rate - this->data.size(); - for (int i = 0; i < pad_size; i++) { - this->data.push_back(0.0); - } + this->data.insert(this->data.end(), pad_size, 0.0f); } this->listening.store(false); diff --git a/whisper_ros/src/silero_vad/timestamp.cpp b/whisper_ros/src/silero_vad/timestamp.cpp index 427a2d3..5ed9f67 100644 --- a/whisper_ros/src/silero_vad/timestamp.cpp +++ b/whisper_ros/src/silero_vad/timestamp.cpp @@ -1,6 +1,6 @@ // MIT License -// Copyright (c) 2023 Miguel Ángel González Santamarta +// Copyright (c) 2024 Miguel Ángel González Santamarta // 
Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index 5500c78..323f05f 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -1,6 +1,6 @@ // MIT License -// Copyright (c) 2023 Miguel Ángel González Santamarta +// Copyright (c) 2024 Miguel Ángel González Santamarta // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,6 +20,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#include #include #include #include @@ -31,23 +32,24 @@ using namespace silero_vad; VadIterator::VadIterator(const std::string &model_path, int sample_rate, int frame_size_ms, float threshold, int min_silence_ms, - int speech_pad_ms, int min_speech_ms, - float max_speech_s) + int speech_pad_ms) : env(ORT_LOGGING_LEVEL_WARNING, "VadIterator"), threshold(threshold), sample_rate(sample_rate), sr_per_ms(sample_rate / 1000), window_size_samples(frame_size_ms * sr_per_ms), - min_speech_samples(sr_per_ms * min_speech_ms), speech_pad_samples(sr_per_ms * speech_pad_ms), - max_speech_samples(sample_rate * max_speech_s - window_size_samples - - 2 * speech_pad_samples), min_silence_samples(sr_per_ms * min_silence_ms), - min_silence_samples_at_max_speech(sr_per_ms * 98), - state(2 * 1 * 128, 0.0f), sr(1, sample_rate), context(64, 0.0f) { + context_size(sample_rate == 16000 ? 
64 : 32), context(context_size, 0.0f), + state(2 * 1 * 128, 0.0f), sr(1, sample_rate) { - // this->input.resize(window_size_samples); this->input_node_dims[0] = 1; this->input_node_dims[1] = window_size_samples; - this->init_onnx_model(model_path); + + try { + this->init_onnx_model(model_path); + } catch (const std::exception &e) { + throw std::runtime_error("Failed to initialize ONNX model: " + + std::string(e.what())); + } } void VadIterator::init_onnx_model(const std::string &model_path) { @@ -55,8 +57,14 @@ void VadIterator::init_onnx_model(const std::string &model_path) { this->session_options.SetInterOpNumThreads(1); this->session_options.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL); - this->session = std::make_shared(this->env, model_path.c_str(), - this->session_options); + + try { + this->session = std::make_shared( + this->env, model_path.c_str(), this->session_options); + } catch (const std::exception &e) { + throw std::runtime_error("Failed to create ONNX session: " + + std::string(e.what())); + } } void VadIterator::reset_states() { @@ -68,16 +76,13 @@ void VadIterator::reset_states() { } Timestamp VadIterator::predict(const std::vector &data) { - // Create input tensors + // Pre-fill input with context this->input.clear(); - for (auto ele : this->context) { - this->input.push_back(ele); - } - - for (auto ele : data) { - this->input.push_back(ele); - } + this->input.reserve(context.size() + data.size()); + this->input.insert(input.end(), context.begin(), context.end()); + this->input.insert(input.end(), data.begin(), data.end()); + // Create input tensors Ort::Value input_tensor = Ort::Value::CreateTensor( this->memory_info, this->input.data(), this->input.size(), this->input_node_dims, 2); @@ -95,10 +100,14 @@ Timestamp VadIterator::predict(const std::vector &data) { this->ort_inputs.emplace_back(std::move(sr_tensor)); // Run inference - this->ort_outputs = this->session->Run( - Ort::RunOptions{nullptr}, 
this->input_node_names.data(), - this->ort_inputs.data(), this->ort_inputs.size(), - this->output_node_names.data(), this->output_node_names.size()); + try { + this->ort_outputs = session->Run( + Ort::RunOptions{nullptr}, this->input_node_names.data(), + this->ort_inputs.data(), this->ort_inputs.size(), + this->output_node_names.data(), this->output_node_names.size()); + } catch (const std::exception &e) { + throw std::runtime_error("ONNX inference failed: " + std::string(e.what())); + } // Process output float speech_prob = this->ort_outputs[0].GetTensorMutableData()[0]; @@ -106,9 +115,8 @@ Timestamp VadIterator::predict(const std::vector &data) { std::copy(updated_state, updated_state + this->state.size(), this->state.begin()); - for (int i = 64; i > 0; i--) { - this->context.push_back(data.at(data.size() - i)); - } + // Update context with the last 64 samples of data + this->context.assign(data.end() - context_size, data.end()); // Handle result this->current_sample += this->window_size_samples; @@ -119,10 +127,10 @@ Timestamp VadIterator::predict(const std::vector &data) { } if (!this->triggered) { + int start_timestwamp = this->current_sample - this->speech_pad_samples - + this->window_size_samples; this->triggered = true; - return Timestamp(this->current_sample - this->speech_pad_samples - - this->window_size_samples, - -1, speech_prob); + return Timestamp(start_timestwamp, -1, speech_prob); } } else if (speech_prob < this->threshold - 0.15 && this->triggered) { @@ -131,12 +139,11 @@ Timestamp VadIterator::predict(const std::vector &data) { } if (this->current_sample - this->temp_end >= this->min_silence_samples) { - this->temp_end = 0; + int end_timestamp = + this->temp_end + this->speech_pad_samples - this->window_size_samples; this->triggered = false; - return Timestamp(-1, - this->temp_end + this->speech_pad_samples - - this->window_size_samples, - speech_prob); + this->temp_end = 0; + return Timestamp(-1, end_timestamp, speech_prob); } } From 
d49c721d5bb6138b8cbbcb1c05a86ddb55c0be1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 21:05:35 +0100 Subject: [PATCH 05/18] new model repo --- whisper_bringup/launch/silero-vad.launch.py | 6 +++--- whisper_bringup/launch/whisper.launch.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/whisper_bringup/launch/silero-vad.launch.py b/whisper_bringup/launch/silero-vad.launch.py index c13633a..786916e 100644 --- a/whisper_bringup/launch/silero-vad.launch.py +++ b/whisper_bringup/launch/silero-vad.launch.py @@ -54,9 +54,9 @@ def run_silero_vad(context: LaunchContext, repo, file, model_path): "frame_size_ms": LaunchConfiguration("frame_size_ms", default=32), "threshold": LaunchConfiguration("threshold", default=0.5), "min_silence_ms": LaunchConfiguration( - "min_silence_ms", default=100 + "min_silence_ms", default=128 ), - "speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=30), + "speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=32), } ], remappings=[("audio", "/audio/in")], @@ -66,7 +66,7 @@ def run_silero_vad(context: LaunchContext, repo, file, model_path): model_repo = LaunchConfiguration("model_repo") model_repo_cmd = DeclareLaunchArgument( "model_repo", - default_value="deepghs/silero-vad-onnx", + default_value="mgonzs13/silero-vad-onnx", description="Hugging Face model repo", ) diff --git a/whisper_bringup/launch/whisper.launch.py b/whisper_bringup/launch/whisper.launch.py index d38b555..b925ece 100644 --- a/whisper_bringup/launch/whisper.launch.py +++ b/whisper_bringup/launch/whisper.launch.py @@ -149,14 +149,14 @@ def run_whisper(context: LaunchContext, repo, file, model_path): silero_vad_model_repo = LaunchConfiguration("silero_vad_model_repo") silero_vad_model_repo_cmd = DeclareLaunchArgument( "silero_vad_model_repo", - default_value="onnx-community/silero-vad", + default_value="mgonzs13/silero-vad-onnx", description="Hugging Face model repo for 
SileroVAD", ) silero_vad_model_filename = LaunchConfiguration("silero_vad_model_filename") silero_vad_model_filename_cmd = DeclareLaunchArgument( "silero_vad_model_filename", - default_value="onnx/model.onnx", + default_value="silero_vad.onnx", description="Hugging Face model filename for SileroVAD", ) From d99e2356af7cf1406681ccde54f73e99ef206b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 21:07:57 +0100 Subject: [PATCH 06/18] minor fixes --- whisper_ros/src/silero_vad/silero_vad_node.cpp | 4 +--- whisper_ros/src/silero_vad/vad_iterator.cpp | 12 +++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp index e04721e..7a541d1 100644 --- a/whisper_ros/src/silero_vad/silero_vad_node.cpp +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -189,9 +189,7 @@ void SileroVadNode::audio_callback( // Add audio if listening if (this->listening) { - for (auto d : data) { - this->data.push_back(d); - } + this->data.insert(this->data.end(), data.begin(), data.end()); } // Check if speech ends diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index 323f05f..bcd0061 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -42,7 +42,8 @@ VadIterator::VadIterator(const std::string &model_path, int sample_rate, state(2 * 1 * 128, 0.0f), sr(1, sample_rate) { this->input_node_dims[0] = 1; - this->input_node_dims[1] = window_size_samples; + this->input_node_dims[1] = this->window_size_samples; + this->input.reserve(context_size + this->window_size_samples); try { this->init_onnx_model(model_path); @@ -78,9 +79,9 @@ void VadIterator::reset_states() { Timestamp VadIterator::predict(const std::vector &data) { // Pre-fill input with context this->input.clear(); - this->input.reserve(context.size() + 
data.size()); - this->input.insert(input.end(), context.begin(), context.end()); - this->input.insert(input.end(), data.begin(), data.end()); + this->input.insert(this->input.end(), this->context.begin(), + this->context.end()); + this->input.insert(this->input.end(), data.begin(), data.end()); // Create input tensors Ort::Value input_tensor = Ort::Value::CreateTensor( @@ -132,8 +133,9 @@ Timestamp VadIterator::predict(const std::vector &data) { this->triggered = true; return Timestamp(start_timestwamp, -1, speech_prob); } + } - } else if (speech_prob < this->threshold - 0.15 && this->triggered) { + if (speech_prob < this->threshold - 0.15 && this->triggered) { if (this->temp_end == 0) { this->temp_end = this->current_sample; } From 651909834ec0572a67aa96532665f78d71746038 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 22:14:32 +0100 Subject: [PATCH 07/18] input_node_dims fixed --- whisper_ros/src/silero_vad/vad_iterator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index bcd0061..04fca36 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -42,8 +42,8 @@ VadIterator::VadIterator(const std::string &model_path, int sample_rate, state(2 * 1 * 128, 0.0f), sr(1, sample_rate) { this->input_node_dims[0] = 1; - this->input_node_dims[1] = this->window_size_samples; - this->input.reserve(context_size + this->window_size_samples); + this->input_node_dims[1] = this->context_size + this->window_size_samples; + this->input.reserve(this->context_size + this->window_size_samples); try { this->init_onnx_model(model_path); From 8c6cd54e87d92d25cedee0ca13876e804240bc04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 22:15:16 +0100 Subject: [PATCH 08/18] add the previous and the 
next chunks to publish --- .../include/silero_vad/silero_vad_node.hpp | 4 ++- .../src/silero_vad/silero_vad_node.cpp | 36 +++++++++++++++---- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index 7cb6a82..3ada436 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -53,12 +53,14 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { on_shutdown(const rclcpp_lifecycle::State &); protected: - bool enabled; + std::atomic enabled; std::atomic listening; + std::atomic publish; std::vector data; std::unique_ptr vad_iterator; private: + std::vector prev_data; std::string model_path_; int sample_rate_; int frame_size_ms_; diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp index 7a541d1..b72e2ba 100644 --- a/whisper_ros/src/silero_vad/silero_vad_node.cpp +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -35,7 +35,8 @@ using std::placeholders::_1; using std::placeholders::_2; SileroVadNode::SileroVadNode() - : rclcpp_lifecycle::LifecycleNode("silero_vad_node"), listening(false) { + : rclcpp_lifecycle::LifecycleNode("silero_vad_node"), listening(false), + publish(false) { this->declare_parameter("enabled", true); this->declare_parameter("model_path", ""); @@ -52,7 +53,10 @@ SileroVadNode::on_configure(const rclcpp_lifecycle::State &) { RCLCPP_INFO(get_logger(), "[%s] Configuring...", this->get_name()); // get params - this->get_parameter("enabled", this->enabled); + bool enabled; + this->get_parameter("enabled", enabled); + this->enabled.store(enabled); + this->get_parameter("model_path", this->model_path_); this->get_parameter("sample_rate", this->sample_rate_); this->get_parameter("frame_size_ms", this->frame_size_ms_); @@ -185,6 +189,11 @@ void SileroVadNode::audio_callback( RCLCPP_INFO(this->get_logger(), 
"Speech starts..."); this->listening.store(true); this->data.clear(); + + if (this->prev_data.size()) { + this->data.insert(this->data.end(), this->prev_data.begin(), + this->prev_data.end()); + } } // Add audio if listening @@ -192,10 +201,8 @@ void SileroVadNode::audio_callback( this->data.insert(this->data.end(), data.begin(), data.end()); } - // Check if speech ends - if (timestamp.start == -1 && timestamp.end != -1 && this->listening) { - RCLCPP_INFO(this->get_logger(), "Speech ends..."); - + // Check if publish + if (this->publish) { if (this->data.size() / msg->audio.info.rate < 1.0) { int pad_size = msg->audio.info.chunk + msg->audio.info.rate - this->data.size(); @@ -203,11 +210,20 @@ void SileroVadNode::audio_callback( } this->listening.store(false); + this->publish.store(false); auto vad_msg = std_msgs::msg::Float32MultiArray(); vad_msg.data = this->data; this->publisher_->publish(vad_msg); this->data.clear(); } + + // Check if speech ends + if (timestamp.start == -1 && timestamp.end != -1 && this->listening) { + RCLCPP_INFO(this->get_logger(), "Speech ends..."); + this->publish.store(true); + } + + this->prev_data = data; } void SileroVadNode::enable_cb( @@ -218,8 +234,11 @@ void SileroVadNode::enable_cb( if (request->data && !this->enabled) { response->message = "SileroVAD enabled"; - this->enabled = true; + this->enabled.store(true); + this->listening.store(false); + this->publish.store(false); this->data.clear(); + this->prev_data.clear(); this->vad_iterator->reset_states(); } else if (request->data && this->enabled) { @@ -227,8 +246,11 @@ void SileroVadNode::enable_cb( } else if (!request->data && this->enabled) { response->message = "SileroVAD disabled"; + this->enabled.store(false); this->listening.store(false); + this->publish.store(false); this->data.clear(); + this->prev_data.clear(); } else if (!request->data && !this->enabled) { response->message = "SileroVAD already disabled"; From 0eedfbb1a185e5543a9133632f2961b1f3224663 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Thu, 26 Dec 2024 22:15:35 +0100 Subject: [PATCH 09/18] removing Python silero version --- whisper_ros/whisper_ros/__init__.py | 0 whisper_ros/whisper_ros/silero_vad_node.py | 187 --------------------- 2 files changed, 187 deletions(-) delete mode 100644 whisper_ros/whisper_ros/__init__.py delete mode 100644 whisper_ros/whisper_ros/silero_vad_node.py diff --git a/whisper_ros/whisper_ros/__init__.py b/whisper_ros/whisper_ros/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/whisper_ros/whisper_ros/silero_vad_node.py b/whisper_ros/whisper_ros/silero_vad_node.py deleted file mode 100644 index cd0e045..0000000 --- a/whisper_ros/whisper_ros/silero_vad_node.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python3 - -# MIT License - -# Copyright (c) 2023 Miguel Ángel González Santamarta - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- - -import pyaudio -import numpy as np -from typing import List -from silero_vad import VADIterator, load_silero_vad - -import rclpy -from rclpy.node import Node -from rclpy.qos import qos_profile_sensor_data -from std_msgs.msg import Float32MultiArray -from std_srvs.srv import SetBool - -from audio_common_msgs.msg import Audio -from audio_common_msgs.msg import AudioStamped - - -def int2float(audio_data: np.ndarray) -> np.ndarray: - - if audio_data is None or len(audio_data) == 0: - return None - - from_type = type(audio_data[0]) - - if from_type == np.uint8: - audio_data = (audio_data.astype(np.float32) - 128) / 128.0 - - elif from_type == np.int8: - audio_data = audio_data.astype(np.float32) / 128.0 - - elif from_type == np.int16: - audio_data = audio_data.astype(np.float32) / 32768.0 - - elif from_type == np.int32: - audio_data = audio_data.astype(np.float32) / 2147483648.0 - - elif from_type == np.float32: - audio_data = audio_data - - else: - return None - - return audio_data - - -def msg_to_array(msg: Audio) -> np.ndarray: - data = None - audio_format = msg.info.format - pyaudio_to_np = { - pyaudio.paFloat32: np.float32, - pyaudio.paInt32: np.int32, - pyaudio.paInt16: np.int16, - pyaudio.paInt8: np.int8, - pyaudio.paUInt8: np.uint8, - } - - if audio_format == pyaudio.paFloat32: - data = msg.audio_data.float32_data - elif audio_format == pyaudio.paInt32: - data = msg.audio_data.int32_data - elif audio_format == pyaudio.paInt16: - data = msg.audio_data.int16_data - elif audio_format == pyaudio.paInt8: - data = msg.audio_data.int8_data - elif audio_format == pyaudio.paUInt8: - data = msg.audio_data.uint8_data - if data is not None: - data = np.frombuffer(data, pyaudio_to_np[audio_format]) - - return data - - -class SileroVadNode(Node): - - def __init__(self) -> None: - - super().__init__("silero_vad_node") - - self.recording = False - self.data: List[float] = [] - - self.declare_parameter("enabled", True) - self.enabled = self.chunk = ( - 
self.get_parameter("enabled").get_parameter_value().bool_value - ) - - self.declare_parameter("threshold", 0.5) - self.threshold = self.chunk = ( - self.get_parameter("threshold").get_parameter_value().double_value - ) - - # create silero model - model = load_silero_vad(onnx=True) - self.vad_iterator = VADIterator(model, threshold=self.threshold) - - # srvs, subs, pubs - self._enable_srv = self.create_service(SetBool, "enable_vad", self.enable_cb) - self._pub = self.create_publisher(Float32MultiArray, "vad", 10) - self._sub = self.create_subscription( - AudioStamped, "audio", self.audio_cb, qos_profile_sensor_data - ) - - self.get_logger().info("Silero VAD node started") - - def audio_cb(self, msg: AudioStamped) -> None: - - if not self.enabled: - return - - audio_array = int2float(msg_to_array(msg.audio)) - - if audio_array is None: - self.get_logger().error(f"Format {msg.audio.info.format} unknown") - return - speech_dict = self.vad_iterator(audio_array) - - if speech_dict: - self.get_logger().info(str(speech_dict)) - - if not self.recording and "start" in speech_dict: - self.recording = True - self.data = [] - - elif self.recording and "end" in speech_dict: - self.recording = False - self.data.extend(audio_array.tolist()) - - if len(self.data) / msg.audio.info.rate < 1.0: - pad_size = msg.audio.info.chunk + msg.audio.info.rate - len(self.data) - self.data.extend(pad_size * [0.0]) - - vad_msg = Float32MultiArray() - vad_msg.data = self.data - self._pub.publish(vad_msg) - - if self.recording: - self.data.extend(audio_array.tolist()) - - def enable_cb(self, req: SetBool.Request, res: SetBool.Response) -> SetBool.Response: - res.success = True - self.enabled = req.data - - if self.enabled: - res.message = "Silero enabled" - self.vad_iterator.reset_states() - else: - res.message = "Silero disabled" - self.recording = False - self.data = [] - - self.get_logger().info(res.message) - return res - - -def main(): - rclpy.init() - node = SileroVadNode() - rclpy.spin(node) 
- node.destroy_node() - rclpy.shutdown() - - -if __name__ == "__main__": - main() From 1f7ff0512051e511edf5caab2dc349b29467cbf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 11:39:58 +0100 Subject: [PATCH 10/18] improving vad publishing --- whisper_ros/src/silero_vad/silero_vad_node.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp index b72e2ba..31ca13e 100644 --- a/whisper_ros/src/silero_vad/silero_vad_node.cpp +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -203,17 +203,19 @@ void SileroVadNode::audio_callback( // Check if publish if (this->publish) { - if (this->data.size() / msg->audio.info.rate < 1.0) { + auto vad_msg = std_msgs::msg::Float32MultiArray(); + vad_msg.data.assign(this->data.begin(), this->data.end()); + + if (vad_msg.data.size() / msg->audio.info.rate < 1.0) { int pad_size = msg->audio.info.chunk + msg->audio.info.rate - this->data.size(); - this->data.insert(this->data.end(), pad_size, 0.0f); + vad_msg.data.insert(vad_msg.data.end(), pad_size, 0.0f); } + this->publisher_->publish(vad_msg); + this->listening.store(false); this->publish.store(false); - auto vad_msg = std_msgs::msg::Float32MultiArray(); - vad_msg.data = this->data; - this->publisher_->publish(vad_msg); this->data.clear(); } From a40cbc9c0adb711f9c4c122baff19348d112e2e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 11:42:25 +0100 Subject: [PATCH 11/18] fixing ifndef guards names --- whisper_ros/include/silero_vad/silero_vad_node.hpp | 4 ++-- whisper_ros/include/silero_vad/timestamp.hpp | 4 ++-- whisper_ros/include/silero_vad/vad_iterator.hpp | 4 ++-- whisper_ros/include/whisper_ros/whisper.hpp | 4 ++-- whisper_ros/include/whisper_ros/whisper_base_node.hpp | 4 ++-- whisper_ros/include/whisper_ros/whisper_node.hpp 
| 4 ++-- whisper_ros/include/whisper_ros/whisper_server_node.hpp | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index 3ada436..cc2de31 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef SILERO_VAD_NODE_HPP -#define SILERO_VAD_NODE_HPP +#ifndef SILERO_VAD_SILERO_VAD_NODE_HPP +#define SILERO_VAD_SILERO_VAD_NODE_HPP #include #include diff --git a/whisper_ros/include/silero_vad/timestamp.hpp b/whisper_ros/include/silero_vad/timestamp.hpp index 030a8bc..daa9c3e 100644 --- a/whisper_ros/include/silero_vad/timestamp.hpp +++ b/whisper_ros/include/silero_vad/timestamp.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef TIMESTAMPT_HPP -#define TIMESTAMPT_HPP +#ifndef SILERO_VAD__TIMESTAMPT_HPP +#define SILERO_VAD__TIMESTAMPT_HPP #include diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp index b6f2911..2e22fda 100644 --- a/whisper_ros/include/silero_vad/vad_iterator.hpp +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef VAD_ITERATOR_HPP -#define VAD_ITERATOR_HPP +#ifndef SILERO_VAD__VAD_ITERATOR_HPP +#define SILERO_VAD__VAD_ITERATOR_HPP #include #include diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 3adc707..3af0148 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#ifndef WHISPER_HPP -#define WHISPER_HPP +#ifndef WHISPER_ROS__WHISPER_HPP +#define WHISPER_ROS__WHISPER_HPP #include #include diff --git a/whisper_ros/include/whisper_ros/whisper_base_node.hpp b/whisper_ros/include/whisper_ros/whisper_base_node.hpp index 68c58e9..bf57db6 100644 --- a/whisper_ros/include/whisper_ros/whisper_base_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_base_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef WHISPER_BASE_NODE_HPP -#define WHISPER_BASE_NODE_HPP +#ifndef WHISPER_ROS__WHISPER_BASE_NODE_HPP +#define WHISPER_ROS__WHISPER_BASE_NODE_HPP #include #include diff --git a/whisper_ros/include/whisper_ros/whisper_node.hpp b/whisper_ros/include/whisper_ros/whisper_node.hpp index ab71c20..77761bf 100644 --- a/whisper_ros/include/whisper_ros/whisper_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef WHISPER_NODE_HPP -#define WHISPER_NODE_HPP +#ifndef WHISPER_ROS__WHISPER_NODE_HPP +#define WHISPER_ROS__WHISPER_NODE_HPP #include diff --git a/whisper_ros/include/whisper_ros/whisper_server_node.hpp b/whisper_ros/include/whisper_ros/whisper_server_node.hpp index c3ea1ee..3d7bd24 100644 --- a/whisper_ros/include/whisper_ros/whisper_server_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_server_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#ifndef WHISPER_SERVER_NODE_HPP -#define WHISPER_SERVER_NODE_HPP +#ifndef WHISPER_ROS__WHISPER_SERVER_NODE_HPP +#define WHISPER_ROS__WHISPER_SERVER_NODE_HPP #include #include From 4a77cf27431a04b3506479bb76a90aa3fd08a3bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 11:48:02 +0100 Subject: [PATCH 12/18] fixing license comments --- whisper_ros/include/silero_vad/silero_vad_node.hpp | 8 ++++---- whisper_ros/include/silero_vad/timestamp.hpp | 8 ++++---- whisper_ros/include/silero_vad/vad_iterator.hpp | 8 ++++---- whisper_ros/include/whisper_ros/whisper.hpp | 8 ++++---- whisper_ros/include/whisper_ros/whisper_base_node.hpp | 8 ++++---- whisper_ros/include/whisper_ros/whisper_node.hpp | 8 ++++---- .../include/whisper_ros/whisper_server_node.hpp | 8 ++++---- whisper_ros/src/silero_vad/silero_vad_node.cpp | 8 ++++---- whisper_ros/src/silero_vad/timestamp.cpp | 8 ++++---- whisper_ros/src/silero_vad/vad_iterator.cpp | 8 ++++---- whisper_ros/src/silero_vad_main.cpp | 10 +++++----- whisper_ros/src/whisper_main.cpp | 8 ++++---- whisper_ros/src/whisper_ros/whisper.cpp | 8 ++++---- whisper_ros/src/whisper_ros/whisper_base_node.cpp | 8 ++++---- whisper_ros/src/whisper_ros/whisper_node.cpp | 8 ++++---- whisper_ros/src/whisper_ros/whisper_server_node.cpp | 8 ++++---- whisper_ros/src/whisper_server_main.cpp | 8 ++++---- 17 files changed, 69 insertions(+), 69 deletions(-) diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index cc2de31..ba10e97 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without 
restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE diff --git a/whisper_ros/include/silero_vad/timestamp.hpp b/whisper_ros/include/silero_vad/timestamp.hpp index daa9c3e..cd41b08 100644 --- a/whisper_ros/include/silero_vad/timestamp.hpp +++ b/whisper_ros/include/silero_vad/timestamp.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp index 2e22fda..a8345d4 100644 --- a/whisper_ros/include/silero_vad/vad_iterator.hpp +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 3af0148..3eb4e4d 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/include/whisper_ros/whisper_base_node.hpp b/whisper_ros/include/whisper_ros/whisper_base_node.hpp index bf57db6..aab1c12 100644 --- a/whisper_ros/include/whisper_ros/whisper_base_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_base_node.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/include/whisper_ros/whisper_node.hpp b/whisper_ros/include/whisper_ros/whisper_node.hpp index 77761bf..d6ea89b 100644 --- a/whisper_ros/include/whisper_ros/whisper_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_node.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/include/whisper_ros/whisper_server_node.hpp b/whisper_ros/include/whisper_ros/whisper_server_node.hpp index 3d7bd24..555d62d 100644 --- a/whisper_ros/include/whisper_ros/whisper_server_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_server_node.hpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/silero_vad/silero_vad_node.cpp b/whisper_ros/src/silero_vad/silero_vad_node.cpp index 31ca13e..87681c4 100644 --- a/whisper_ros/src/silero_vad/silero_vad_node.cpp +++ b/whisper_ros/src/silero_vad/silero_vad_node.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/silero_vad/timestamp.cpp b/whisper_ros/src/silero_vad/timestamp.cpp index 5ed9f67..7c5072a 100644 --- a/whisper_ros/src/silero_vad/timestamp.cpp +++ b/whisper_ros/src/silero_vad/timestamp.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index 04fca36..c6669a0 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/silero_vad_main.cpp b/whisper_ros/src/silero_vad_main.cpp index 982ef01..fc30b85 100644 --- a/whisper_ros/src/silero_vad_main.cpp +++ b/whisper_ros/src/silero_vad_main.cpp @@ -1,17 +1,17 @@ // MIT License - -// Copyright (c) 2023 Miguel Ángel González Santamarta - +// +// Copyright (c) 2024 Miguel Ángel González Santamarta +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_main.cpp b/whisper_ros/src/whisper_main.cpp index 55f0f53..699d0c7 100644 --- a/whisper_ros/src/whisper_main.cpp +++ b/whisper_ros/src/whisper_main.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_ros/whisper.cpp b/whisper_ros/src/whisper_ros/whisper.cpp index 1999ba2..8106735 100644 --- a/whisper_ros/src/whisper_ros/whisper.cpp +++ b/whisper_ros/src/whisper_ros/whisper.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_ros/whisper_base_node.cpp b/whisper_ros/src/whisper_ros/whisper_base_node.cpp index 907d023..527eba7 100644 --- a/whisper_ros/src/whisper_ros/whisper_base_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_base_node.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_ros/whisper_node.cpp b/whisper_ros/src/whisper_ros/whisper_node.cpp index ecc1f7f..258b9c0 100644 --- a/whisper_ros/src/whisper_ros/whisper_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_node.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2023 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_ros/whisper_server_node.cpp b/whisper_ros/src/whisper_ros/whisper_server_node.cpp index 4339df4..c0eba67 100644 --- a/whisper_ros/src/whisper_ros/whisper_server_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_server_node.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE diff --git a/whisper_ros/src/whisper_server_main.cpp b/whisper_ros/src/whisper_server_main.cpp index 14627e2..81aff7f 100644 --- a/whisper_ros/src/whisper_server_main.cpp +++ b/whisper_ros/src/whisper_server_main.cpp @@ -1,17 +1,17 @@ // MIT License - +// // Copyright (c) 2024 Miguel Ángel González Santamarta - +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: - +// // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. - +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE From 520de600b3eab64eb38b41326523c7f848ff4d24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 11:50:15 +0100 Subject: [PATCH 13/18] struct TranscriptionOutput name fixed --- whisper_ros/include/whisper_ros/whisper.hpp | 4 ++-- whisper_ros/src/whisper_ros/whisper.cpp | 4 ++-- whisper_ros/src/whisper_ros/whisper_base_node.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 3eb4e4d..01ffdec 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -37,7 +37,7 @@ #define WHISPER_LOG_INFO(text, ...) 
\ fprintf(stderr, "[INFO] " text "\n", ##__VA_ARGS__) -struct transcription_output { +struct TranscriptionOutput { std::string text; float prob; }; @@ -52,7 +52,7 @@ class Whisper { const struct whisper_full_params &wparams); ~Whisper(); - struct transcription_output transcribe(const std::vector &pcmf32); + struct TranscriptionOutput transcribe(const std::vector &pcmf32); std::string trim(const std::string &s); std::string timestamp_to_str(int64_t t, bool comma = false); diff --git a/whisper_ros/src/whisper_ros/whisper.cpp b/whisper_ros/src/whisper_ros/whisper.cpp index 8106735..a8a0ab2 100644 --- a/whisper_ros/src/whisper_ros/whisper.cpp +++ b/whisper_ros/src/whisper_ros/whisper.cpp @@ -72,11 +72,11 @@ Whisper::Whisper(const std::string &model, Whisper::~Whisper() { whisper_free(this->ctx); } -struct transcription_output +struct TranscriptionOutput Whisper::transcribe(const std::vector &pcmf32) { int prob_n = 0; - struct transcription_output result; + struct TranscriptionOutput result; result.text = ""; result.prob = 0.0f; diff --git a/whisper_ros/src/whisper_ros/whisper_base_node.cpp b/whisper_ros/src/whisper_ros/whisper_base_node.cpp index 527eba7..2f6ce8b 100644 --- a/whisper_ros/src/whisper_ros/whisper_base_node.cpp +++ b/whisper_ros/src/whisper_ros/whisper_base_node.cpp @@ -264,7 +264,7 @@ WhisperBaseNode::transcribe(const std::vector &audio) { auto start_time = this->get_clock()->now(); RCLCPP_INFO(this->get_logger(), "Transcribing"); - transcription_output result = this->whisper->transcribe(audio); + struct TranscriptionOutput result = this->whisper->transcribe(audio); std::string text = this->whisper->trim(result.text); auto end_time = this->get_clock()->now(); From 6c608e3a724f1d3da82da43b701d2ba1222be35a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 11:59:27 +0100 Subject: [PATCH 14/18] new whisper logs --- whisper_ros/CMakeLists.txt | 2 + whisper_ros/include/whisper_ros/logs.hpp | 80 
+++++++++++++ whisper_ros/include/whisper_ros/whisper.hpp | 8 -- whisper_ros/src/whisper_ros/logs.cpp | 118 ++++++++++++++++++++ whisper_ros/src/whisper_ros/whisper.cpp | 14 ++- 5 files changed, 212 insertions(+), 10 deletions(-) create mode 100644 whisper_ros/include/whisper_ros/logs.hpp create mode 100644 whisper_ros/src/whisper_ros/logs.cpp diff --git a/whisper_ros/CMakeLists.txt b/whisper_ros/CMakeLists.txt index 78045b8..8864ddc 100644 --- a/whisper_ros/CMakeLists.txt +++ b/whisper_ros/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(whisper_node src/whisper_ros/whisper_node.cpp src/whisper_ros/whisper_base_node.cpp src/whisper_ros/whisper.cpp + src/whisper_ros/logs.cpp ) target_link_libraries(whisper_node whisper_cpp_vendor::grammar @@ -49,6 +50,7 @@ add_executable(whisper_server_node src/whisper_ros/whisper_server_node.cpp src/whisper_ros/whisper_base_node.cpp src/whisper_ros/whisper.cpp + src/whisper_ros/logs.cpp ) target_link_libraries(whisper_server_node whisper_cpp_vendor::grammar diff --git a/whisper_ros/include/whisper_ros/logs.hpp b/whisper_ros/include/whisper_ros/logs.hpp new file mode 100644 index 0000000..c35f15c --- /dev/null +++ b/whisper_ros/include/whisper_ros/logs.hpp @@ -0,0 +1,80 @@ +// Copyright (C) 2024 Miguel Ángel González Santamarta +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . 
+ +#ifndef WHISPER_ROS__LOGS_HPP +#define WHISPER_ROS__LOGS_HPP + +#include +#include +#include + +namespace whisper_ros { + +/** + * @brief Type definition for a logging function. + * + * This type represents a function pointer that takes a file name, + * function name, line number, log message, and a variable number of + * additional arguments for formatting the log message. + * + * @param file The name of the source file where the log function is called. + * @param function The name of the function where the log function is called. + * @param line The line number in the source file where the log function is + * called. + * @param text The format string for the log message, similar to printf. + * @param ... Additional arguments for the format string. + */ +typedef void (*LogFunction)(const char *file, const char *function, int line, + const char *text, ...); + +// Declare function pointers for logging at different severity levels +extern LogFunction log_error; ///< Pointer to the error logging function +extern LogFunction log_warn; ///< Pointer to the warning logging function +extern LogFunction log_info; ///< Pointer to the info logging function +extern LogFunction log_debug; ///< Pointer to the debug logging function + +/** + * @brief Extracts the filename from a given file path. + * + * This function takes a full path to a file and returns just the file name. + * + * @param path The full path to the file. + * @return A pointer to the extracted filename. + */ +inline const char *extract_filename(const char *path) { + const char *filename = std::strrchr(path, '/'); + if (!filename) { + filename = std::strrchr(path, '\\'); // handle Windows-style paths + } + return filename ? filename + 1 : path; +} + +// Macros for logging with automatic file and function information +#define WHISPER_LOG_ERROR(text, ...) \ + whisper_ros::log_error(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ + text, ##__VA_ARGS__) +#define WHISPER_LOG_WARN(text, ...) 
\ + whisper_ros::log_warn(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ + text, ##__VA_ARGS__) +#define WHISPER_LOG_INFO(text, ...) \ + whisper_ros::log_info(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ + text, ##__VA_ARGS__) +#define WHISPER_LOG_DEBUG(text, ...) \ + whisper_ros::log_debug(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ + text, ##__VA_ARGS__) + +} // namespace whisper_ros + +#endif // WHISPER_ROS__LOGS_HPP \ No newline at end of file diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 01ffdec..7e2ef40 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -29,14 +29,6 @@ #include "grammar-parser.h" #include "whisper.h" -// whisper logs -#define WHISPER_LOG_ERROR(text, ...) \ - fprintf(stderr, "[ERROR] " text "\n", ##__VA_ARGS__) -#define WHISPER_LOG_WARN(text, ...) \ - fprintf(stderr, "[WARN] " text "\n", ##__VA_ARGS__) -#define WHISPER_LOG_INFO(text, ...) \ - fprintf(stderr, "[INFO] " text "\n", ##__VA_ARGS__) - struct TranscriptionOutput { std::string text; float prob; diff --git a/whisper_ros/src/whisper_ros/logs.cpp b/whisper_ros/src/whisper_ros/logs.cpp new file mode 100644 index 0000000..5232493 --- /dev/null +++ b/whisper_ros/src/whisper_ros/logs.cpp @@ -0,0 +1,118 @@ +// Copyright (C) 2024 Miguel Ángel González Santamarta +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. 
+// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "whisper_ros/logs.hpp" + +namespace whisper_ros { + +/** + * @brief Default error logging function. + * + * This function logs an error message to stderr with the format: + * [ERROR] [file:function:line] message. + * + * @param file The name of the source file where the log function is called. + * @param function The name of the function where the log function is called. + * @param line The line number in the source file where the log function is + * called. + * @param text The format string for the log message. + * @param ... Additional arguments for the format string. + */ +void default_log_error(const char *file, const char *function, int line, + const char *text, ...) { + va_list args; + va_start(args, text); + fprintf(stderr, "[ERROR] [%s:%s:%d] ", file, function, line); + vfprintf(stderr, text, args); + fprintf(stderr, "\n"); + va_end(args); +} + +/** + * @brief Default warning logging function. + * + * This function logs a warning message to stderr with the format: + * [WARN] [file:function:line] message. + * + * @param file The name of the source file where the log function is called. + * @param function The name of the function where the log function is called. + * @param line The line number in the source file where the log function is + * called. + * @param text The format string for the log message. + * @param ... Additional arguments for the format string. + */ +void default_log_warn(const char *file, const char *function, int line, + const char *text, ...) { + va_list args; + va_start(args, text); + fprintf(stderr, "[WARN] [%s:%s:%d] ", file, function, line); + vfprintf(stderr, text, args); + fprintf(stderr, "\n"); + va_end(args); +} + +/** + * @brief Default info logging function. + * + * This function logs an informational message to stderr with the format: + * [INFO] [file:function:line] message. 
+ * + * @param file The name of the source file where the log function is called. + * @param function The name of the function where the log function is called. + * @param line The line number in the source file where the log function is + * called. + * @param text The format string for the log message. + * @param ... Additional arguments for the format string. + */ +void default_log_info(const char *file, const char *function, int line, + const char *text, ...) { + va_list args; + va_start(args, text); + fprintf(stderr, "[INFO] [%s:%s:%d] ", file, function, line); + vfprintf(stderr, text, args); + fprintf(stderr, "\n"); + va_end(args); +} + +/** + * @brief Default debug logging function. + * + * This function logs a debug message to stderr with the format: + * [DEBUG] [file:function:line] message. + * + * @param file The name of the source file where the log function is called. + * @param function The name of the function where the log function is called. + * @param line The line number in the source file where the log function is + * called. + * @param text The format string for the log message. + * @param ... Additional arguments for the format string. + */ +void default_log_debug(const char *file, const char *function, int line, + const char *text, ...) 
{ + va_list args; + va_start(args, text); + fprintf(stderr, "[DEBUG] [%s:%s:%d] ", file, function, line); + vfprintf(stderr, text, args); + fprintf(stderr, "\n"); + va_end(args); +} + +// Initialize the function pointers with default log functions +LogFunction log_error = default_log_error; +LogFunction log_warn = default_log_warn; +LogFunction log_info = default_log_info; +LogFunction log_debug = default_log_debug; + +} // namespace whisper_ros \ No newline at end of file diff --git a/whisper_ros/src/whisper_ros/whisper.cpp b/whisper_ros/src/whisper_ros/whisper.cpp index a8a0ab2..15fdb50 100644 --- a/whisper_ros/src/whisper_ros/whisper.cpp +++ b/whisper_ros/src/whisper_ros/whisper.cpp @@ -24,6 +24,7 @@ #include #include "grammar-parser.h" +#include "whisper_ros/logs.hpp" #include "whisper_ros/whisper.hpp" using namespace whisper_ros; @@ -42,7 +43,7 @@ Whisper::Whisper(const std::string &model, this->ctx = whisper_init_from_file_with_params(model.c_str(), cparams); if (this->ctx == nullptr) { - WHISPER_LOG_ERROR("failed to initialize whisper context\n"); + WHISPER_LOG_ERROR("Failed to initialize whisper context\n"); } if (!whisper_is_multilingual(this->ctx)) { @@ -75,6 +76,8 @@ Whisper::~Whisper() { whisper_free(this->ctx); } struct TranscriptionOutput Whisper::transcribe(const std::vector &pcmf32) { + WHISPER_LOG_DEBUG("Starting transcription"); + int prob_n = 0; struct TranscriptionOutput result; result.text = ""; @@ -137,9 +140,11 @@ std::string Whisper::timestamp_to_str(int64_t t, bool comma) { bool Whisper::set_grammar(const std::string grammar, const std::string start_rule, float grammar_penalty) { + WHISPER_LOG_DEBUG("Setting new grammar"); this->grammar_parsed = grammar_parser::parse(grammar.c_str()); if (this->grammar_parsed.rules.empty()) { + WHISPER_LOG_ERROR("Error setting the grammar"); return false; } @@ -155,6 +160,7 @@ bool Whisper::set_grammar(const std::string grammar, } void Whisper::reset_grammar() { + WHISPER_LOG_DEBUG("Resetting grammar"); 
this->wparams.grammar_rules = nullptr; this->wparams.n_grammar_rules = 0; this->wparams.i_start_rule = 0; @@ -162,7 +168,11 @@ void Whisper::reset_grammar() { } void Whisper::set_init_prompt(const std::string prompt) { + WHISPER_LOG_DEBUG("Resetting initial prompt"); this->wparams.initial_prompt = prompt.c_str(); } -void Whisper::reset_init_prompt() { this->wparams.initial_prompt = ""; } +void Whisper::reset_init_prompt() { + WHISPER_LOG_DEBUG("Resetting initial prompt"); + this->wparams.initial_prompt = ""; +} From a71f68215470089a3428b0c4e999e6c9abad4062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 12:06:30 +0100 Subject: [PATCH 15/18] whisper_utils created --- whisper_ros/CMakeLists.txt | 4 +-- .../{whisper_ros => whisper_utils}/logs.hpp | 26 +++++++++---------- whisper_ros/src/whisper_ros/whisper.cpp | 2 +- .../{whisper_ros => whisper_utils}/logs.cpp | 6 ++--- 4 files changed, 19 insertions(+), 19 deletions(-) rename whisper_ros/include/{whisper_ros => whisper_utils}/logs.hpp (78%) rename whisper_ros/src/{whisper_ros => whisper_utils}/logs.cpp (97%) diff --git a/whisper_ros/CMakeLists.txt b/whisper_ros/CMakeLists.txt index 8864ddc..22332cd 100644 --- a/whisper_ros/CMakeLists.txt +++ b/whisper_ros/CMakeLists.txt @@ -29,7 +29,7 @@ add_executable(whisper_node src/whisper_ros/whisper_node.cpp src/whisper_ros/whisper_base_node.cpp src/whisper_ros/whisper.cpp - src/whisper_ros/logs.cpp + src/whisper_utils/logs.cpp ) target_link_libraries(whisper_node whisper_cpp_vendor::grammar @@ -50,7 +50,7 @@ add_executable(whisper_server_node src/whisper_ros/whisper_server_node.cpp src/whisper_ros/whisper_base_node.cpp src/whisper_ros/whisper.cpp - src/whisper_ros/logs.cpp + src/whisper_utils/logs.cpp ) target_link_libraries(whisper_server_node whisper_cpp_vendor::grammar diff --git a/whisper_ros/include/whisper_ros/logs.hpp b/whisper_ros/include/whisper_utils/logs.hpp similarity index 78% rename from 
whisper_ros/include/whisper_ros/logs.hpp rename to whisper_ros/include/whisper_utils/logs.hpp index c35f15c..b2b6b98 100644 --- a/whisper_ros/include/whisper_ros/logs.hpp +++ b/whisper_ros/include/whisper_utils/logs.hpp @@ -13,14 +13,14 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -#ifndef WHISPER_ROS__LOGS_HPP -#define WHISPER_ROS__LOGS_HPP +#ifndef WHISPER_UTILS__LOGS_HPP +#define WHISPER_UTILS__LOGS_HPP #include #include #include -namespace whisper_ros { +namespace whisper_utils { /** * @brief Type definition for a logging function. @@ -63,18 +63,18 @@ inline const char *extract_filename(const char *path) { // Macros for logging with automatic file and function information #define WHISPER_LOG_ERROR(text, ...) \ - whisper_ros::log_error(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ - text, ##__VA_ARGS__) + whisper_utils::log_error(whisper_utils::extract_filename(__FILE__), \ + __FUNCTION__, __LINE__, text, ##__VA_ARGS__) #define WHISPER_LOG_WARN(text, ...) \ - whisper_ros::log_warn(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ - text, ##__VA_ARGS__) + whisper_utils::log_warn(whisper_utils::extract_filename(__FILE__), \ + __FUNCTION__, __LINE__, text, ##__VA_ARGS__) #define WHISPER_LOG_INFO(text, ...) \ - whisper_ros::log_info(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ - text, ##__VA_ARGS__) + whisper_utils::log_info(whisper_utils::extract_filename(__FILE__), \ + __FUNCTION__, __LINE__, text, ##__VA_ARGS__) #define WHISPER_LOG_DEBUG(text, ...) 
\ - whisper_ros::log_debug(extract_filename(__FILE__), __FUNCTION__, __LINE__, \ - text, ##__VA_ARGS__) + whisper_utils::log_debug(whisper_utils::extract_filename(__FILE__), \ + __FUNCTION__, __LINE__, text, ##__VA_ARGS__) -} // namespace whisper_ros +} // namespace whisper_utils -#endif // WHISPER_ROS__LOGS_HPP \ No newline at end of file +#endif // WHISPER_UTILS__LOGS_HPP \ No newline at end of file diff --git a/whisper_ros/src/whisper_ros/whisper.cpp b/whisper_ros/src/whisper_ros/whisper.cpp index 15fdb50..db2d8f1 100644 --- a/whisper_ros/src/whisper_ros/whisper.cpp +++ b/whisper_ros/src/whisper_ros/whisper.cpp @@ -24,8 +24,8 @@ #include #include "grammar-parser.h" -#include "whisper_ros/logs.hpp" #include "whisper_ros/whisper.hpp" +#include "whisper_utils/logs.hpp" using namespace whisper_ros; diff --git a/whisper_ros/src/whisper_ros/logs.cpp b/whisper_ros/src/whisper_utils/logs.cpp similarity index 97% rename from whisper_ros/src/whisper_ros/logs.cpp rename to whisper_ros/src/whisper_utils/logs.cpp index 5232493..de68752 100644 --- a/whisper_ros/src/whisper_ros/logs.cpp +++ b/whisper_ros/src/whisper_utils/logs.cpp @@ -13,9 +13,9 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -#include "whisper_ros/logs.hpp" +#include "whisper_utils/logs.hpp" -namespace whisper_ros { +namespace whisper_utils { /** * @brief Default error logging function. 
@@ -115,4 +115,4 @@ LogFunction log_warn = default_log_warn; LogFunction log_info = default_log_info; LogFunction log_debug = default_log_debug; -} // namespace whisper_ros \ No newline at end of file +} // namespace whisper_utils \ No newline at end of file From 16bb40647368b01bab0fa4b76bf759f6d726d928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 12:24:44 +0100 Subject: [PATCH 16/18] improving code comments --- .../include/silero_vad/silero_vad_node.hpp | 63 ++++++++++++--- whisper_ros/include/silero_vad/timestamp.hpp | 21 +++++ .../include/silero_vad/vad_iterator.hpp | 72 ++++++++++++++++- whisper_ros/include/whisper_ros/whisper.hpp | 49 ++++++++++- .../include/whisper_ros/whisper_base_node.hpp | 81 +++++++++++++++++-- .../include/whisper_ros/whisper_node.hpp | 65 ++++++++++++++- .../whisper_ros/whisper_server_node.hpp | 62 +++++++++++++- 7 files changed, 391 insertions(+), 22 deletions(-) diff --git a/whisper_ros/include/silero_vad/silero_vad_node.hpp b/whisper_ros/include/silero_vad/silero_vad_node.hpp index ba10e97..cfc8c5c 100644 --- a/whisper_ros/include/silero_vad/silero_vad_node.hpp +++ b/whisper_ros/include/silero_vad/silero_vad_node.hpp @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#ifndef SILERO_VAD_SILERO_VAD_NODE_HPP -#define SILERO_VAD_SILERO_VAD_NODE_HPP +#ifndef SILERO_VAD__SILERO_VAD_NODE_HPP +#define SILERO_VAD__SILERO_VAD_NODE_HPP #include #include @@ -36,50 +36,95 @@ namespace silero_vad { +/// @class SileroVadNode +/// @brief A ROS 2 lifecycle node for performing voice activity detection. class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { public: + /// @brief Constructs a new SileroVadNode object. SileroVadNode(); + /// @brief Callback for configuring the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of configuration. 
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_configure(const rclcpp_lifecycle::State &); + on_configure(const rclcpp_lifecycle::State &state); + + /// @brief Callback for activating the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of activation. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_activate(const rclcpp_lifecycle::State &); + on_activate(const rclcpp_lifecycle::State &state); + + /// @brief Callback for deactivating the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of deactivation. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_deactivate(const rclcpp_lifecycle::State &); + on_deactivate(const rclcpp_lifecycle::State &state); + + /// @brief Callback for cleaning up the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of cleanup. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_cleanup(const rclcpp_lifecycle::State &); + on_cleanup(const rclcpp_lifecycle::State &state); + + /// @brief Callback for shutting down the lifecycle node. + /// @param state The current state of the node. + /// @return Success or failure of shutdown. rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_shutdown(const rclcpp_lifecycle::State &); + on_shutdown(const rclcpp_lifecycle::State &state); protected: + /// Indicates if VAD is enabled. std::atomic enabled; + /// Indicates if VAD is in listening mode. std::atomic listening; + /// Indicates if audio data should be published. std::atomic publish; + /// Buffer for storing audio data. std::vector data; + /// Pointer to the VAD iterator. std::unique_ptr vad_iterator; private: + /// Buffer for storing previous audio data. std::vector prev_data; + /// Path to the VAD model. 
std::string model_path_; + /// Sampling rate of the audio data. int sample_rate_; + /// Frame size in milliseconds. int frame_size_ms_; + /// Threshold for VAD decision-making. float threshold_; + /// Minimum silence duration in milliseconds. int min_silence_ms_; + /// Padding duration for detected speech in milliseconds. int speech_pad_ms_; + /// Publisher for VAD output. rclcpp::Publisher::SharedPtr publisher_; + /// Subscription for audio input. rclcpp::Subscription::SharedPtr subscription_; - + /// Service for enabling/disabling VAD. rclcpp::Service::SharedPtr enable_srv_; + /// @brief Callback for handling incoming audio data. + /// @param msg The audio message containing the audio data. void audio_callback(const audio_common_msgs::msg::AudioStamped::SharedPtr msg); + /// @brief Callback for enabling/disabling the VAD. + /// @param request The service request containing the desired enable state. + /// @param response The service response indicating success or failure. void enable_cb(const std::shared_ptr request, std::shared_ptr response); + /// @brief Converts audio data to a float vector normalized to [-1.0, 1.0]. + /// @tparam T The input audio data type. + /// @param input The input audio data. + /// @return A vector of normalized float audio data. template std::vector convert_to_float(const std::vector &input) { static_assert(std::is_integral::value, @@ -123,4 +168,4 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode { } // namespace silero_vad -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/silero_vad/timestamp.hpp b/whisper_ros/include/silero_vad/timestamp.hpp index cd41b08..c6e7f31 100644 --- a/whisper_ros/include/silero_vad/timestamp.hpp +++ b/whisper_ros/include/silero_vad/timestamp.hpp @@ -27,17 +27,38 @@ namespace silero_vad { +/// @class Timestamp +/// @brief Represents a time interval with speech probability. class Timestamp { public: + /// The start time of the interval, in milliseconds. 
int start; + + /// The end time of the interval, in milliseconds. int end; + + /// The probability of speech detected in the interval, ranging from 0 to 1. float speech_prob; + /// @brief Constructs a `Timestamp` object. + /// @param start The start time of the interval (default: -1). + /// @param end The end time of the interval (default: -1). + /// @param speech_prob The speech probability (default: 0). Timestamp(int start = -1, int end = -1, float speech_prob = 0); + /// @brief Assigns the values of another `Timestamp` to this instance. + /// @param other The `Timestamp` to copy from. + /// @return A reference to this `Timestamp`. Timestamp &operator=(const Timestamp &other); + + /// @brief Compares two `Timestamp` objects for equality. + /// @param other The `Timestamp` to compare with. + /// @return `true` if the start and end times are equal; `false` otherwise. bool operator==(const Timestamp &other) const; + /// @brief Converts the `Timestamp` object to a string representation. + /// @return A string representing the `Timestamp` in the format + /// `{start:...,end:...,prob:...}`. std::string to_string() const; }; diff --git a/whisper_ros/include/silero_vad/vad_iterator.hpp b/whisper_ros/include/silero_vad/vad_iterator.hpp index a8345d4..6198921 100644 --- a/whisper_ros/include/silero_vad/vad_iterator.hpp +++ b/whisper_ros/include/silero_vad/vad_iterator.hpp @@ -33,58 +33,122 @@ namespace silero_vad { +/// @class VadIterator +/// @brief Implements a Voice Activity Detection (VAD) iterator using an ONNX +/// model. +/// +/// This class provides methods to load a pre-trained ONNX VAD model, process +/// audio data, and predict the presence of speech. It manages the model state +/// and handles input/output tensors for inference. class VadIterator { - public: + /// @brief Constructs a VadIterator object. + /// + /// @param model_path Path to the ONNX model file. + /// @param sample_rate The audio sample rate in Hz (default: 16000). 
+ /// @param frame_size_ms Size of the audio frame in milliseconds (default: + /// 32). + /// @param threshold The threshold for speech detection (default: 0.5). + /// @param min_silence_ms Minimum silence duration in milliseconds to mark the + /// end of speech (default: 100). + /// @param speech_pad_ms Additional padding in milliseconds added to speech + /// segments (default: 30). VadIterator(const std::string &model_path, int sample_rate = 16000, int frame_size_ms = 32, float threshold = 0.5f, int min_silence_ms = 100, int speech_pad_ms = 30); + /// @brief Resets the internal state of the model. + /// + /// Clears the state, context, and resets internal flags related to speech + /// detection. void reset_states(); + + /// @brief Processes audio data and predicts speech segments. + /// + /// @param data A vector of audio samples (single-channel, float values). + /// @return A Timestamp object containing start and end times of detected + /// speech, or -1 for inactive values. Timestamp predict(const std::vector &data); private: + /// ONNX Runtime environment. Ort::Env env; + /// ONNX session options. Ort::SessionOptions session_options; + /// ONNX session for running inference. std::shared_ptr session; + /// Memory allocator for ONNX runtime. Ort::AllocatorWithDefaultOptions allocator; + /// Memory info for tensor allocation. Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); - // Model configuration + /// Detection threshold for speech probability. float threshold; + /// Audio sample rate in Hz. int sample_rate; + /// Samples per millisecond. int sr_per_ms; + /// Number of samples in a single frame. int64_t window_size_samples; + /// Padding in samples added to speech segments. int speech_pad_samples; + /// Minimum silence duration in samples to mark the end of speech. unsigned int min_silence_samples; + /// Size of the context buffer. int context_size; - // Model state + /// Indicates whether speech has been detected. 
bool triggered = false; + /// Temporary end position during silence detection. unsigned int temp_end = 0; + /// Current sample position in the input stream. unsigned int current_sample = 0; + /// End sample of the last speech segment. int prev_end = 0; + /// Start sample of the next speech segment. int next_start = 0; + /// ONNX model input tensors. std::vector ort_inputs; + + /// Names of the input nodes in the ONNX model. std::vector input_node_names = {"input", "state", "sr"}; + /// Input buffer for audio data and context. std::vector input; + + /// Context buffer storing past audio samples. std::vector context; + + /// Internal state of the model. std::vector state; + + /// Sample rate tensor. std::vector sr; + /// Dimensions for the input tensor. int64_t input_node_dims[2] = {}; + + /// Dimensions for the state tensor. const int64_t state_node_dims[3] = {2, 1, 128}; + + /// Dimensions for the sample rate tensor. const int64_t sr_node_dims[1] = {1}; + /// ONNX model output tensors. std::vector ort_outputs; + + /// Names of the output nodes in the ONNX model. std::vector output_node_names = {"output", "stateN"}; + /// @brief Initializes the ONNX model session. + /// + /// @param model_path Path to the ONNX model file. + /// @throws std::runtime_error If the ONNX session initialization fails. void init_onnx_model(const std::string &model_path); }; } // namespace silero_vad -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper.hpp b/whisper_ros/include/whisper_ros/whisper.hpp index 7e2ef40..dadae67 100644 --- a/whisper_ros/include/whisper_ros/whisper.hpp +++ b/whisper_ros/include/whisper_ros/whisper.hpp @@ -29,40 +29,87 @@ #include "grammar-parser.h" #include "whisper.h" +/// Represents the result of a transcription operation. struct TranscriptionOutput { + /// The transcribed text. std::string text; + + /// The confidence probability of the transcription. 
float prob; }; namespace whisper_ros { +/// Class for performing speech-to-text transcription using the Whisper model. class Whisper { public: + /// Constructs a Whisper object with the specified model and parameters. + /// @param model The path to the Whisper model file. + /// @param openvino_encode_device The OpenVINO device used for encoding. + /// @param n_processors Number of processors to use for parallel processing. + /// @param cparams Whisper context parameters. + /// @param wparams Whisper full parameters. Whisper(const std::string &model, const std::string &openvino_encode_device, int n_processors, const struct whisper_context_params &cparams, const struct whisper_full_params &wparams); + + /// Destructor to clean up resources used by the Whisper object. ~Whisper(); + /// Transcribes the given audio data. + /// @param pcmf32 A vector of audio samples in 32-bit float format. + /// @return A TranscriptionOutput structure containing the transcription text + /// and confidence probability. struct TranscriptionOutput transcribe(const std::vector &pcmf32); + + /// Trims leading and trailing whitespace from the input string. + /// @param s The string to trim. + /// @return The trimmed string. std::string trim(const std::string &s); + + /// Converts a timestamp to a string format. + /// @param t The timestamp in 10 ms units. + /// @param comma If true, use a comma as the decimal separator; otherwise, use + /// a period. + /// @return The formatted timestamp as a string. std::string timestamp_to_str(int64_t t, bool comma = false); + /// Sets a grammar for transcription with a starting rule and penalty. + /// @param grammar The grammar rules as a string. + /// @param start_rule The starting rule for the grammar. + /// @param grammar_penalty A penalty factor for grammar violations. + /// @return True if the grammar is set successfully; false otherwise. 
bool set_grammar(const std::string grammar, const std::string start_rule, float grammar_penalty); + + /// Resets the grammar to its default state. void reset_grammar(); + + /// Sets an initial prompt for transcription. + /// @param prompt The initial prompt text. void set_init_prompt(const std::string prompt); + + /// Resets the initial prompt to its default state. void reset_init_prompt(); protected: + /// Number of processors used for parallel processing. int n_processors; + + /// Parameters used for full transcription tasks. struct whisper_full_params wparams; + /// The Whisper context. struct whisper_context *ctx; + + /// Parsed grammar state. grammar_parser::parse_state grammar_parsed; + + /// Grammar rules derived from the parsed grammar. std::vector grammar_rules; }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper_base_node.hpp b/whisper_ros/include/whisper_ros/whisper_base_node.hpp index aab1c12..03399b5 100644 --- a/whisper_ros/include/whisper_ros/whisper_base_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_base_node.hpp @@ -35,39 +35,108 @@ namespace whisper_ros { using CallbackReturn = rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn; +/** + * @class WhisperBaseNode + * @brief A ROS 2 Lifecycle Node for interfacing with the Whisper speech-to-text + * system. + */ class WhisperBaseNode : public rclcpp_lifecycle::LifecycleNode { public: + /** + * @brief Constructs a new WhisperBaseNode instance. + */ WhisperBaseNode(); + /** + * @brief Configures the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_configure(const rclcpp_lifecycle::State &); + on_configure(const rclcpp_lifecycle::State &state); + + /** + * @brief Activates the node during its lifecycle. 
+ * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_activate(const rclcpp_lifecycle::State &); + on_activate(const rclcpp_lifecycle::State &state); + + /** + * @brief Deactivates the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_deactivate(const rclcpp_lifecycle::State &); + on_deactivate(const rclcpp_lifecycle::State &state); + + /** + * @brief Cleans up resources during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_cleanup(const rclcpp_lifecycle::State &); + on_cleanup(const rclcpp_lifecycle::State &state); + + /** + * @brief Shuts down the node during its lifecycle. + * + * @param state The current state of the lifecycle node. + * @return CallbackReturn indicating success or failure. + */ rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn - on_shutdown(const rclcpp_lifecycle::State &); + on_shutdown(const rclcpp_lifecycle::State &state); + /** + * @brief Activates ROS 2 interfaces specific to this node. To be implemented + * in derived classes. + */ virtual void activate_ros_interfaces(){}; + + /** + * @brief Deactivates ROS 2 interfaces specific to this node. To be + * implemented in derived classes. + */ virtual void deactivate_ros_interfaces(){}; protected: + /// @brief The language for transcription (e.g., "en"). std::string language; + + /// @brief Shared pointer to the Whisper speech-to-text processor. std::shared_ptr whisper; + /** + * @brief Transcribes a given audio input to text. 
+ * + * @param audio A vector of floating-point audio samples. + * @return A Transcription message containing the text and metadata. + */ whisper_msgs::msg::Transcription transcribe(const std::vector &audio); private: + /// @brief The Whisper model to use. std::string model; + + /// @brief Device identifier for OpenVINO encoding (e.g., "CPU"). std::string openvino_encode_device; + + /// @brief Number of processors to use. int n_processors; + + /// @brief Whisper context parameters. struct whisper_context_params cparams = whisper_context_default_params(); + + /// @brief Whisper full transcription parameters. struct whisper_full_params wparams; }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif // WHISPER_ROS__WHISPER_BASE_NODE_HPP diff --git a/whisper_ros/include/whisper_ros/whisper_node.hpp b/whisper_ros/include/whisper_ros/whisper_node.hpp index d6ea89b..73f53b8 100644 --- a/whisper_ros/include/whisper_ros/whisper_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_node.hpp @@ -35,36 +35,99 @@ namespace whisper_ros { +/** + * @class WhisperNode + * @brief A ROS 2 node that extends WhisperBaseNode to handle + * transcription-related functionality. + */ class WhisperNode : public WhisperBaseNode { public: + /** + * @brief Constructs a WhisperNode instance. + */ WhisperNode(); + /** + * @brief Activates ROS 2 interfaces such as publishers, subscriptions, and + * services. + */ void activate_ros_interfaces(); + + /** + * @brief Deactivates ROS 2 interfaces by resetting and clearing publishers, + * subscriptions, and services. + */ void deactivate_ros_interfaces(); private: + /// Publisher for transcription messages. rclcpp::Publisher::SharedPtr publisher_; + + /// Subscription for voice activity detection (VAD) data. rclcpp::Subscription::SharedPtr subscription_; + /// Service for setting grammar configuration. rclcpp::Service::SharedPtr set_grammar_service_; + + /// Service for resetting grammar configuration. 
rclcpp::Service::SharedPtr reset_grammar_service_; + + /// Service for setting the initial transcription prompt. rclcpp::Service::SharedPtr set_init_prompt_service_; + + /// Service for resetting the initial transcription prompt. rclcpp::Service::SharedPtr reset_init_prompt_service_; + /** + * @brief Callback for processing voice activity detection (VAD) messages. + * + * @param msg A shared pointer to the received Float32MultiArray message. + */ void vad_callback(const std_msgs::msg::Float32MultiArray::SharedPtr msg); + + /** + * @brief Callback for the SetGrammar service. + * + * @param request Shared pointer to the request containing grammar + * configuration. + * @param response Shared pointer to the response indicating success or + * failure. + */ void set_grammar_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the ResetGrammar service. + * + * @param request Shared pointer to the Empty request (unused). + * @param response Shared pointer to the Empty response (unused). + */ void reset_grammar_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the SetInitPrompt service. + * + * @param request Shared pointer to the request containing the initial prompt + * text. + * @param response Shared pointer to the response indicating success. + */ void set_init_prompt_service_callback( const std::shared_ptr request, std::shared_ptr response); + + /** + * @brief Callback for the ResetInitPrompt service. + * + * @param request Shared pointer to the Empty request (unused). + * @param response Shared pointer to the Empty response (unused). 
+ */ void reset_init_prompt_service_callback( const std::shared_ptr request, std::shared_ptr response); @@ -72,4 +135,4 @@ class WhisperNode : public WhisperBaseNode { } // namespace whisper_ros -#endif \ No newline at end of file +#endif diff --git a/whisper_ros/include/whisper_ros/whisper_server_node.hpp b/whisper_ros/include/whisper_ros/whisper_server_node.hpp index 555d62d..6adbae7 100644 --- a/whisper_ros/include/whisper_ros/whisper_server_node.hpp +++ b/whisper_ros/include/whisper_ros/whisper_server_node.hpp @@ -37,41 +37,101 @@ namespace whisper_ros { +/** + * @class WhisperServerNode + * @brief A node providing speech-to-text (STT) functionality using Whisper and + * Silero VAD. + */ class WhisperServerNode : public WhisperBaseNode { using STT = whisper_msgs::action::STT; using GoalHandleSTT = rclcpp_action::ServerGoalHandle; public: + /** + * @brief Constructor for the WhisperServerNode class. + */ WhisperServerNode(); + /** + * @brief Activates ROS interfaces including subscriptions, action servers, + * and clients. + */ void activate_ros_interfaces(); + + /** + * @brief Deactivates ROS interfaces by releasing subscriptions, action + * servers, and clients. + */ void deactivate_ros_interfaces(); protected: + /** + * @brief Enables or disables Silero VAD. + * @param enable If true, enables Silero VAD; otherwise, disables it. + */ void enable_silero(bool enable); private: + /// The transcription message containing the converted text. whisper_msgs::msg::Transcription transcription_msg; + + /// Mutex for synchronizing access to the transcription message. std::mutex transcription_mutex; + + /// Condition variable for waiting on transcription updates. std::condition_variable transcription_cond; + /// Handle for the active STT goal. std::shared_ptr goal_handle_; + + /// Subscription to receive VAD data. rclcpp::Subscription::SharedPtr subscription_; + + /// Client to enable or disable Silero VAD. 
rclcpp::Client::SharedPtr enable_silero_client_; + + /// Action server for handling STT requests. rclcpp_action::Server::SharedPtr action_server_; + /** + * @brief Callback for processing VAD data. + * @param msg The incoming VAD data message. + */ void vad_callback(const std_msgs::msg::Float32MultiArray::SharedPtr msg); + + /** + * @brief Executes the STT goal. + * @param goal_handle The goal handle representing the active STT request. + */ void execute(const std::shared_ptr goal_handle); + + /** + * @brief Handles an incoming goal request. + * @param uuid The unique identifier of the goal. + * @param goal The goal details. + * @return Response indicating whether the goal is accepted or rejected. + */ rclcpp_action::GoalResponse handle_goal(const rclcpp_action::GoalUUID &uuid, std::shared_ptr goal); + + /** + * @brief Handles a cancellation request for an active goal. + * @param goal_handle The handle of the goal to be canceled. + * @return Response indicating whether the cancellation is accepted. + */ rclcpp_action::CancelResponse handle_cancel(const std::shared_ptr goal_handle); + + /** + * @brief Handles acceptance of a goal and starts execution. + * @param goal_handle The handle of the accepted goal. 
+ */ void handle_accepted(const std::shared_ptr goal_handle); }; } // namespace whisper_ros -#endif \ No newline at end of file +#endif From eb6fae69b6666d088fe1cb82cbe903c83ac0b12c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 12:33:44 +0100 Subject: [PATCH 17/18] whisper logs added to vad_iterator --- whisper_ros/CMakeLists.txt | 1 + whisper_ros/src/silero_vad/vad_iterator.cpp | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/whisper_ros/CMakeLists.txt b/whisper_ros/CMakeLists.txt index 22332cd..7c2ecc7 100644 --- a/whisper_ros/CMakeLists.txt +++ b/whisper_ros/CMakeLists.txt @@ -72,6 +72,7 @@ add_executable(silero_vad_node src/silero_vad/silero_vad_node.cpp src/silero_vad/vad_iterator.cpp src/silero_vad/timestamp.cpp + src/whisper_utils/logs.cpp ) target_link_libraries(silero_vad_node ${PORTAUDIO_LIB}) ament_target_dependencies(silero_vad_node diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp index c6669a0..1499bea 100644 --- a/whisper_ros/src/silero_vad/vad_iterator.cpp +++ b/whisper_ros/src/silero_vad/vad_iterator.cpp @@ -27,6 +27,7 @@ #include #include "silero_vad/vad_iterator.hpp" +#include "whisper_utils/logs.hpp" using namespace silero_vad; @@ -48,9 +49,11 @@ VadIterator::VadIterator(const std::string &model_path, int sample_rate, try { this->init_onnx_model(model_path); } catch (const std::exception &e) { - throw std::runtime_error("Failed to initialize ONNX model: " + - std::string(e.what())); + WHISPER_LOG_ERROR("Failed to initialize ONNX model: %s", e.what()); + return; } + + WHISPER_LOG_INFO("SileroVAD Iterator started"); } void VadIterator::init_onnx_model(const std::string &model_path) { @@ -77,6 +80,9 @@ void VadIterator::reset_states() { } Timestamp VadIterator::predict(const std::vector &data) { + + WHISPER_LOG_INFO("Processing audio data"); + // Pre-fill input with context 
this->input.clear(); this->input.insert(this->input.end(), this->context.begin(), @@ -107,7 +113,8 @@ Timestamp VadIterator::predict(const std::vector &data) { this->ort_inputs.data(), this->ort_inputs.size(), this->output_node_names.data(), this->output_node_names.size()); } catch (const std::exception &e) { - throw std::runtime_error("ONNX inference failed: " + std::string(e.what())); + WHISPER_LOG_ERROR("ONNX inference failed: %s", e.what()); + return Timestamp(-1, -1, 0.0f); } // Process output @@ -115,6 +122,7 @@ Timestamp VadIterator::predict(const std::vector &data) { float *updated_state = this->ort_outputs[1].GetTensorMutableData(); std::copy(updated_state, updated_state + this->state.size(), this->state.begin()); + WHISPER_LOG_DEBUG("Speech probability %f", speech_prob); // Update context with the last 64 samples of data this->context.assign(data.end() - context_size, data.end()); @@ -131,6 +139,7 @@ Timestamp VadIterator::predict(const std::vector &data) { int start_timestwamp = this->current_sample - this->speech_pad_samples - this->window_size_samples; this->triggered = true; + WHISPER_LOG_DEBUG("Speech starts at %d", start_timestwamp); return Timestamp(start_timestwamp, -1, speech_prob); } } @@ -145,6 +154,7 @@ Timestamp VadIterator::predict(const std::vector &data) { this->temp_end + this->speech_pad_samples - this->window_size_samples; this->triggered = false; this->temp_end = 0; + WHISPER_LOG_DEBUG("Speech ends at %d", end_timestamp); return Timestamp(-1, end_timestamp, speech_prob); } } From e16e9ab55ac7754c048879fc2e7c3fe9552d4ac6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Fri, 27 Dec 2024 12:45:46 +0100 Subject: [PATCH 18/18] adding set whisper log level --- whisper_ros/include/whisper_utils/logs.hpp | 49 ++++++++++++++++++++- whisper_ros/src/silero_vad/vad_iterator.cpp | 2 +- whisper_ros/src/whisper_utils/logs.cpp | 5 +++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git 
a/whisper_ros/include/whisper_utils/logs.hpp b/whisper_ros/include/whisper_utils/logs.hpp
index b2b6b98..b50c83c 100644
--- a/whisper_ros/include/whisper_utils/logs.hpp
+++ b/whisper_ros/include/whisper_utils/logs.hpp
@@ -45,6 +45,37 @@ extern LogFunction log_warn; ///< Pointer to the warning logging function
 extern LogFunction log_info; ///< Pointer to the info logging function
 extern LogFunction log_debug; ///< Pointer to the debug logging function
 
+/**
+ * @brief Enum representing different log levels for controlling log verbosity.
+ *
+ * This enum defines the severity levels of logs that can be used to control
+ * which log messages should be displayed. The levels are ordered from most
+ * severe to least severe. A message is shown only when its level is at least
+ * as severe as the current log level (its enumerator value is <= log_level).
+ */
+enum LogLevel {
+  /// Log level for error messages. Only critical errors should be logged.
+  ERROR = 0,
+  /// Log level for warning messages. Indicate potential issues that are not
+  /// critical.
+  WARN,
+  /// Log level for informational messages. General runtime information about
+  /// the system's state.
+  INFO,
+  /// Log level for debug messages. Used for detailed information, mainly for
+  /// developers.
+  DEBUG
+};
+
+/**
+ * @brief The current log level for the application.
+ *
+ * This global variable holds the current log level, which determines the
+ * verbosity of the logs. Messages at least as severe as this level are
+ * displayed. The default level is set to INFO.
+ */
+extern LogLevel log_level;
+
 /**
  * @brief Extracts the filename from a given file path.
  *
@@ -61,20 +92,48 @@ inline const char *extract_filename(const char *path) {
   return filename ? filename + 1 : path;
 }
 
-// Macros for logging with automatic file and function information
+// Logging macros: each forwards the calling file name, function name and line
+// number to the matching log function, but only when whisper_utils::log_level
+// is >= the macro's level. The do { ... } while (0) wrapper makes every macro
+// expand to a single statement, so it is safe inside an unbraced if/else.
 #define WHISPER_LOG_ERROR(text, ...)                                           \
-  whisper_utils::log_error(whisper_utils::extract_filename(__FILE__),          \
-                           __FUNCTION__, __LINE__, text, ##__VA_ARGS__)
+  do {                                                                         \
+    if (whisper_utils::log_level >= whisper_utils::ERROR)                      \
+      whisper_utils::log_error(whisper_utils::extract_filename(__FILE__),      \
+                               __FUNCTION__, __LINE__, text, ##__VA_ARGS__);   \
+  } while (0)
+
 #define WHISPER_LOG_WARN(text, ...)                                            \
-  whisper_utils::log_warn(whisper_utils::extract_filename(__FILE__),           \
-                          __FUNCTION__, __LINE__, text, ##__VA_ARGS__)
+  do {                                                                         \
+    if (whisper_utils::log_level >= whisper_utils::WARN)                       \
+      whisper_utils::log_warn(whisper_utils::extract_filename(__FILE__),       \
+                              __FUNCTION__, __LINE__, text, ##__VA_ARGS__);    \
+  } while (0)
+
 #define WHISPER_LOG_INFO(text, ...)                                            \
-  whisper_utils::log_info(whisper_utils::extract_filename(__FILE__),           \
-                          __FUNCTION__, __LINE__, text, ##__VA_ARGS__)
+  do {                                                                         \
+    if (whisper_utils::log_level >= whisper_utils::INFO)                       \
+      whisper_utils::log_info(whisper_utils::extract_filename(__FILE__),       \
+                              __FUNCTION__, __LINE__, text, ##__VA_ARGS__);    \
+  } while (0)
+
 #define WHISPER_LOG_DEBUG(text, ...)                                           \
-  whisper_utils::log_debug(whisper_utils::extract_filename(__FILE__),          \
-                           __FUNCTION__, __LINE__, text, ##__VA_ARGS__)
+  do {                                                                         \
+    if (whisper_utils::log_level >= whisper_utils::DEBUG)                      \
+      whisper_utils::log_debug(whisper_utils::extract_filename(__FILE__),      \
+                               __FUNCTION__, __LINE__, text, ##__VA_ARGS__);   \
+  } while (0)
 
+/**
+ * @brief Sets the log level for the logs.
+ *
+ * This function allows the user to specify the log level: error, warning,
+ * info, or debug.
+ *
+ * @param log_level Log level.
+ */
+void set_log_level(LogLevel log_level);
+
 } // namespace whisper_utils
 
 #endif // WHISPER_UTILS__LOGS_HPP
\ No newline at end of file
diff --git a/whisper_ros/src/silero_vad/vad_iterator.cpp b/whisper_ros/src/silero_vad/vad_iterator.cpp
index 1499bea..f850986 100644
--- a/whisper_ros/src/silero_vad/vad_iterator.cpp
+++ b/whisper_ros/src/silero_vad/vad_iterator.cpp
@@ -81,7 +81,7 @@ void VadIterator::reset_states() {
 
 Timestamp VadIterator::predict(const std::vector &data) {
 
-  WHISPER_LOG_INFO("Processing audio data");
+  WHISPER_LOG_DEBUG("Processing audio data");
 
   // Pre-fill input with context
   this->input.clear();
diff --git a/whisper_ros/src/whisper_utils/logs.cpp b/whisper_ros/src/whisper_utils/logs.cpp
index de68752..1b59752 100644
--- a/whisper_ros/src/whisper_utils/logs.cpp
+++ b/whisper_ros/src/whisper_utils/logs.cpp
@@ -115,4 +115,9 @@ LogFunction log_warn = default_log_warn;
 LogFunction log_info = default_log_info;
 LogFunction log_debug = default_log_debug;
 
+// Initialize the
log level to INFO
+LogLevel log_level = INFO;
+
+// The parameter shadows the global, so the assignment target is qualified.
+void set_log_level(LogLevel log_level) { whisper_utils::log_level = log_level; }
 } // namespace whisper_utils
\ No newline at end of file