diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 7ed20b2..7ee6981 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,11 @@
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Changelog for package ROS 2 Whisper
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+1.1.0 (2023-09-01)
+------------------
+* `whisper_demos`: Improved terminal output
+* `whisper_server`: Improved state machine
+
 1.0.0 (2023-08-31)
 ------------------
 * Initial release
diff --git a/audio_listener/audio_listener/audio_listener.py b/audio_listener/audio_listener/audio_listener.py
index 6f50d2c..1d97b1e 100644
--- a/audio_listener/audio_listener/audio_listener.py
+++ b/audio_listener/audio_listener/audio_listener.py
@@ -4,7 +4,6 @@
 import pyaudio
 import rclpy
 from rclpy.node import Node
-from rclpy.qos import qos_profile_system_default
 from std_msgs.msg import Int16MultiArray, MultiArrayDimension
 
 
@@ -16,7 +15,7 @@ def __init__(self, node_name: str) -> None:
             namespace="",
             parameters=[
                 ("channels", 1),
-                ("frames_per_buffer", 1024),
+                ("frames_per_buffer", 1000),
                 ("rate", 16000),
             ],
         )
@@ -39,7 +38,7 @@ def __init__(self, node_name: str) -> None:
         )
 
         self.audio_publisher_ = self.create_publisher(
-            Int16MultiArray, "~/audio", qos_profile_system_default
-        )
+            Int16MultiArray, "~/audio", 5
+        )
 
         self.audio_publisher_timer_ = self.create_timer(
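A note on the audio_listener changes above: at the configured 16 kHz sample rate, 1000 frames per buffer means every Int16MultiArray message carries exactly 62.5 ms of audio, and the explicit queue depth of 5 that replaces qos_profile_system_default holds roughly 312 ms of audio before the oldest message is dropped. A quick sanity check of that arithmetic (plain Python, values taken from the parameters above):

    # Message timing for the audio_listener parameters above.
    rate = 16000              # samples per second
    frames_per_buffer = 1000  # samples per published message
    queue_depth = 5           # new publisher queue depth

    chunk_ms = 1000 * frames_per_buffer / rate
    print(f"one message: {chunk_ms} ms")                # 62.5 ms
    print(f"queue holds: {queue_depth * chunk_ms} ms")  # 312.5 ms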
diff --git a/audio_listener/package.xml b/audio_listener/package.xml
index 50fd27e..1173439 100644
--- a/audio_listener/package.xml
+++ b/audio_listener/package.xml
@@ -2,7 +2,7 @@
   <name>audio_listener</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>Audio common replica.</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/audio_listener/setup.py b/audio_listener/setup.py
index ee443b7..78dde99 100644
--- a/audio_listener/setup.py
+++ b/audio_listener/setup.py
@@ -4,7 +4,7 @@
 setup(
     name=package_name,
-    version="1.0.0",
+    version="1.1.0",
     packages=find_packages(exclude=["test"]),
     data_files=[
         ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
diff --git a/whisper_bringup/launch/bringup.launch.py b/whisper_bringup/launch/bringup.launch.py
index 8b5312c..80c0eff 100644
--- a/whisper_bringup/launch/bringup.launch.py
+++ b/whisper_bringup/launch/bringup.launch.py
@@ -19,6 +19,9 @@ def generate_launch_description() -> LaunchDescription:
     ld.add_action(WhisperServerMixin.arg_model_name())
     ld.add_action(WhisperServerMixin.arg_n_threads())
     ld.add_action(WhisperServerMixin.arg_language())
+    ld.add_action(WhisperServerMixin.arg_batch_capacity())
+    ld.add_action(WhisperServerMixin.arg_buffer_capacity())
+    ld.add_action(WhisperServerMixin.arg_carry_over_capacity())
     ld.add_action(
         WhisperServerMixin.composable_node_container(
             composable_node_descriptions=[
@@ -27,6 +30,9 @@ def generate_launch_description() -> LaunchDescription:
                 WhisperServerMixin.param_model_name(),
                 WhisperServerMixin.param_n_threads(),
                 WhisperServerMixin.param_language(),
+                WhisperServerMixin.param_batch_capacity(),
+                WhisperServerMixin.param_buffer_capacity(),
+                WhisperServerMixin.param_carry_over_capacity(),
             ],
             remappings=[("/whisper/audio", "/audio_listener/audio")],
             namespace="whisper",
diff --git a/whisper_bringup/package.xml b/whisper_bringup/package.xml
index 5b3e24c..dea8663 100644
--- a/whisper_bringup/package.xml
+++ b/whisper_bringup/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_bringup</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>TODO: Package description</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_cpp_vendor/package.xml b/whisper_cpp_vendor/package.xml
index 8f8a631..af45c55 100644
--- a/whisper_cpp_vendor/package.xml
+++ b/whisper_cpp_vendor/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_cpp_vendor</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>Vendor package for whisper.cpp.</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_demos/package.xml b/whisper_demos/package.xml
index c6f7d66..7e8a69a 100644
--- a/whisper_demos/package.xml
+++ b/whisper_demos/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_demos</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>Demos for using the ROS 2 whisper package.</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_demos/setup.py b/whisper_demos/setup.py
index db694f3..5a1f143 100644
--- a/whisper_demos/setup.py
+++ b/whisper_demos/setup.py
@@ -4,7 +4,7 @@
 setup(
     name=package_name,
-    version="1.0.0",
+    version="1.1.0",
     packages=find_packages(exclude=["test"]),
     data_files=[
         ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
diff --git a/whisper_demos/whisper_demos/whisper_on_key.py b/whisper_demos/whisper_demos/whisper_on_key.py
index 2d4312e..f1849d3 100644
--- a/whisper_demos/whisper_demos/whisper_on_key.py
+++ b/whisper_demos/whisper_demos/whisper_on_key.py
@@ -1,18 +1,22 @@
+import sys
+
 import rclpy
 from builtin_interfaces.msg import Duration
 from pynput.keyboard import Key, Listener
 from rclpy.action import ActionClient
 from rclpy.node import Node
 from rclpy.task import Future
-from whisper_msgs.action import Inference
 from whisper_msgs.action._inference import Inference_FeedbackMessage
+from whisper_msgs.action import Inference
+
 
 class WhisperOnKey(Node):
     def __init__(self, node_name: str) -> None:
         super().__init__(node_name=node_name)
 
         # whisper
+        self.batch_idx = -1
         self.whisper_client = ActionClient(self, Inference, "/whisper/inference")
 
         while not self.whisper_client.wait_for_server(1):
@@ -63,10 +67,15 @@ def on_goal_accepted(self, future: Future) -> None:
 
     def on_done(self, future: Future) -> None:
         result: Inference.Result = future.result().result
-        self.get_logger().info(f"Result: {result.text}")
+        self.get_logger().info(f"Result: {result.transcriptions}")
 
     def on_feedback(self, feedback_msg: Inference_FeedbackMessage) -> None:
-        self.get_logger().info(f"{feedback_msg.feedback.text}")
+        if self.batch_idx != feedback_msg.feedback.batch_idx:
+            print("")
+            self.batch_idx = feedback_msg.feedback.batch_idx
+        sys.stdout.write("\033[K")
+        print(f"{feedback_msg.feedback.transcription}")
+        sys.stdout.write("\033[F")
 
     def info_string(self) -> str:
         return (
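The rewritten on_feedback above produces the improved terminal output mentioned in the changelog: "\033[K" erases the current line and "\033[F" moves the cursor up one line, so successive transcriptions of the same batch overwrite each other in place, and a newline is only emitted when batch_idx advances. A minimal standalone sketch of the same technique (hypothetical strings, no ROS required):

    import sys
    import time

    # Redraw a single terminal line per update, as whisper_on_key.py does.
    for transcription in ["he", "hello wo", "hello world"]:
        sys.stdout.write("\033[K")  # erase the current line
        print(transcription)        # print and move to the next line
        sys.stdout.write("\033[F")  # move the cursor back up
        time.sleep(0.5)
    print()  # finally, leave the completed line in place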
diff --git a/whisper_msgs/action/Inference.action b/whisper_msgs/action/Inference.action
index 34f0f16..46cb6f0 100644
--- a/whisper_msgs/action/Inference.action
+++ b/whisper_msgs/action/Inference.action
@@ -1,7 +1,7 @@
 builtin_interfaces/Duration max_duration
 ---
 string info
-string[] text
+string[] transcriptions
 ---
 uint16 batch_idx
-string text
+string transcription
diff --git a/whisper_msgs/package.xml b/whisper_msgs/package.xml
index e1c6aca..3edc307 100644
--- a/whisper_msgs/package.xml
+++ b/whisper_msgs/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_msgs</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>Messages for the ROS 2 whisper package</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_server/config/whisper.yaml b/whisper_server/config/whisper.yaml
index 0de05b2..ef04f03 100644
--- a/whisper_server/config/whisper.yaml
+++ b/whisper_server/config/whisper.yaml
@@ -1,6 +1,12 @@
 whisper:
   ros__parameters:
-    model_name: "base.en" # other models https://huggingface.co/ggerganov/whisper.cpp
+    # whisper
+    model_name: "tiny.en" # other models https://huggingface.co/ggerganov/whisper.cpp
     language: "en"
     n_threads: 4
     print_progress: false
+
+    # buffer
+    batch_capacity: 6 # seconds
+    buffer_capacity: 2 # seconds
+    carry_over_capacity: 200 # milliseconds
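For scale: combined with the 16 kHz mono stream from audio_listener, the capacities configured above translate into sample counts as follows (simple arithmetic, not code from the repository):

    rate = 16000                  # Hz, from the audio_listener parameters
    batch_capacity_s = 6          # seconds
    buffer_capacity_s = 2         # seconds
    carry_over_capacity_ms = 200  # milliseconds

    print(batch_capacity_s * rate)                # 96000 samples per batch
    print(buffer_capacity_s * rate)               # 32000 samples buffered
    print(carry_over_capacity_ms * rate // 1000)  # 3200 samples carried over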
diff --git a/whisper_server/include/whisper_server/inference_node.hpp b/whisper_server/include/whisper_server/inference_node.hpp
index b455208..5d1594a 100644
--- a/whisper_server/include/whisper_server/inference_node.hpp
+++ b/whisper_server/include/whisper_server/inference_node.hpp
@@ -1,7 +1,6 @@
 #ifndef WHISPER_NODES__INFERENCE_NODE_HPP_
 #define WHISPER_NODES__INFERENCE_NODE_HPP_
 
-#include <atomic>
 #include <memory>
 #include <string>
 #include <vector>
@@ -48,14 +47,13 @@ class InferenceNode {
   void on_inference_accepted_(const std::shared_ptr<GoalHandleInference> goal_handle);
   std::string inference_(const std::vector<float> &audio);
   rclcpp::Time inference_start_time_;
-  std::atomic_bool running_inference_;
 
   // whisper
-  void initialize_whisper_();
-  ModelManager model_manager_;
-  BatchedBuffer batched_buffer_;
-  Whisper whisper_;
+  std::unique_ptr<ModelManager> model_manager_;
+  std::unique_ptr<BatchedBuffer> batched_buffer_;
+  std::unique_ptr<Whisper> whisper_;
   std::string language_;
+  void initialize_whisper_();
 };
 } // end of namespace whisper
 #endif // WHISPER_NODES__INFERENCE_NODE_HPP_
diff --git a/whisper_server/package.xml b/whisper_server/package.xml
index c7c9379..62e32eb 100644
--- a/whisper_server/package.xml
+++ b/whisper_server/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_server</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>ROS 2 whisper.cpp inference server.</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_server/src/inference_node.cpp b/whisper_server/src/inference_node.cpp
index 59618d6..c8e34b8 100644
--- a/whisper_server/src/inference_node.cpp
+++ b/whisper_server/src/inference_node.cpp
@@ -2,7 +2,7 @@
 namespace whisper {
 InferenceNode::InferenceNode(const rclcpp::Node::SharedPtr node_ptr)
-    : node_ptr_(node_ptr), running_inference_(false), language_("en") {
+    : node_ptr_(node_ptr), language_("en") {
   declare_parameters_();
 
   auto cb_group = node_ptr_->create_callback_group(rclcpp::CallbackGroupType::Reentrant);
@@ -11,7 +11,7 @@ InferenceNode::InferenceNode(const rclcpp::Node::SharedPtr node_ptr)
 
   // audio subscription
   audio_sub_ = node_ptr_->create_subscription<std_msgs::msg::Int16MultiArray>(
-      "audio", 10, std::bind(&InferenceNode::on_audio_, this, std::placeholders::_1), options);
+      "audio", 5, std::bind(&InferenceNode::on_audio_, this, std::placeholders::_1), options);
 
   // inference action server
   inference_action_server_ = rclcpp_action::create_server<whisper_msgs::action::Inference>(
@@ -24,12 +24,25 @@
   on_parameter_set_handle_ = node_ptr_->add_on_set_parameters_callback(
       std::bind(&InferenceNode::on_parameter_set_, this, std::placeholders::_1));
 
-  // initialize model
+  // whisper
+  model_manager_ = std::make_unique<ModelManager>();
+  batched_buffer_ = std::make_unique<BatchedBuffer>(
+      std::chrono::seconds(node_ptr_->get_parameter("batch_capacity").as_int()),
+      std::chrono::seconds(node_ptr_->get_parameter("buffer_capacity").as_int()),
+      std::chrono::milliseconds(node_ptr_->get_parameter("carry_over_capacity").as_int()));
+  whisper_ = std::make_unique<Whisper>();
+
   initialize_whisper_();
 }
 
 void InferenceNode::declare_parameters_() {
-  node_ptr_->declare_parameter("model_name", "base.en");
+  // buffer parameters
+  node_ptr_->declare_parameter("batch_capacity", 6);
+  node_ptr_->declare_parameter("buffer_capacity", 2);
+  node_ptr_->declare_parameter("carry_over_capacity", 200);
+
+  // whisper parameters
+  node_ptr_->declare_parameter("model_name", "tiny.en");
   // consider other parameters:
   // https://github.com/ggerganov/whisper.cpp/blob/a4bb2df36aeb4e6cfb0c1ca9fbcf749ef39cc852/whisper.h#L351
   node_ptr_->declare_parameter("language", "en");
@@ -41,10 +54,10 @@ void InferenceNode::initialize_whisper_() {
   std::string model_name = node_ptr_->get_parameter("model_name").as_string();
   RCLCPP_INFO(node_ptr_->get_logger(), "Checking whether model %s is available...",
               model_name.c_str());
-  if (!model_manager_.is_available(model_name)) {
+  if (!model_manager_->is_available(model_name)) {
     RCLCPP_INFO(node_ptr_->get_logger(), "Model %s is not available. Attempting download...",
                 model_name.c_str());
-    if (model_manager_.make_available(model_name) != 0) {
+    if (model_manager_->make_available(model_name) != 0) {
       std::string err_msg = "Failed to download model " + model_name + ".";
       RCLCPP_ERROR(node_ptr_->get_logger(), err_msg.c_str());
       throw std::runtime_error(err_msg);
@@ -54,13 +67,13 @@ void InferenceNode::initialize_whisper_() {
   RCLCPP_INFO(node_ptr_->get_logger(), "Model %s is available.", model_name.c_str());
 
   RCLCPP_INFO(node_ptr_->get_logger(), "Initializing model %s...", model_name.c_str());
-  whisper_.initialize(model_manager_.get_model_path(model_name));
+  whisper_->initialize(model_manager_->get_model_path(model_name));
   RCLCPP_INFO(node_ptr_->get_logger(), "Model %s initialized.", model_name.c_str());
 
   language_ = node_ptr_->get_parameter("language").as_string();
-  whisper_.params.language = language_.c_str();
-  whisper_.params.n_threads = node_ptr_->get_parameter("n_threads").as_int();
-  whisper_.params.print_progress = node_ptr_->get_parameter("print_progress").as_bool();
+  whisper_->params.language = language_.c_str();
+  whisper_->params.n_threads = node_ptr_->get_parameter("n_threads").as_int();
+  whisper_->params.print_progress = node_ptr_->get_parameter("print_progress").as_bool();
 }
 
 rcl_interfaces::msg::SetParametersResult
@@ -68,9 +81,9 @@ InferenceNode::on_parameter_set_(const std::vector<rclcpp::Parameter> &parameters) {
   rcl_interfaces::msg::SetParametersResult result;
   for (const auto &parameter : parameters) {
     if (parameter.get_name() == "n_threads") {
-      whisper_.params.n_threads = parameter.as_int();
+      whisper_->params.n_threads = parameter.as_int();
       RCLCPP_INFO(node_ptr_->get_logger(), "Parameter %s set to %d.", parameter.get_name().c_str(),
-                  whisper_.params.n_threads);
+                  whisper_->params.n_threads);
       continue;
     }
     result.reason = "Parameter " + parameter.get_name() + " not handled.";
@@ -82,57 +95,72 @@ InferenceNode::on_parameter_set_(const std::vector<rclcpp::Parameter> &parameters) {
 }
 
 void InferenceNode::on_audio_(const std_msgs::msg::Int16MultiArray::SharedPtr msg) {
-  batched_buffer_.enqueue(msg->data);
+  batched_buffer_->enqueue(msg->data);
 }
 
 rclcpp_action::GoalResponse
-InferenceNode::on_inference_(const rclcpp_action::GoalUUID &uuid,
-                             std::shared_ptr<const whisper_msgs::action::Inference::Goal> goal) {
+InferenceNode::on_inference_(const rclcpp_action::GoalUUID & /*uuid*/,
+                             std::shared_ptr<const whisper_msgs::action::Inference::Goal> /*goal*/) {
   RCLCPP_INFO(node_ptr_->get_logger(), "Received inference request.");
-  if (running_inference_) {
-    RCLCPP_WARN(node_ptr_->get_logger(), "Inference already running.");
-    return rclcpp_action::GoalResponse::REJECT;
-  }
   return rclcpp_action::GoalResponse::ACCEPT_AND_EXECUTE;
 }
 
 rclcpp_action::CancelResponse
-InferenceNode::on_cancel_inference_(const std::shared_ptr<GoalHandleInference> goal_handle) {
+InferenceNode::on_cancel_inference_(const std::shared_ptr<GoalHandleInference> /*goal_handle*/) {
   RCLCPP_INFO(node_ptr_->get_logger(), "Cancelling inference...");
   return rclcpp_action::CancelResponse::ACCEPT;
 }
 
 void InferenceNode::on_inference_accepted_(const std::shared_ptr<GoalHandleInference> goal_handle) {
   RCLCPP_INFO(node_ptr_->get_logger(), "Starting inference...");
-  running_inference_ = true;
-
-  auto result = std::make_shared<whisper_msgs::action::Inference::Result>();
   auto feedback = std::make_shared<whisper_msgs::action::Inference::Feedback>();
+  auto result = std::make_shared<whisper_msgs::action::Inference::Result>();
+  inference_start_time_ = node_ptr_->now();
+
+  while (rclcpp::ok()) {
+    if (node_ptr_->now() - inference_start_time_ > goal_handle->get_goal()->max_duration) {
+      result->info = "Inference timed out.";
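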
+      RCLCPP_INFO(node_ptr_->get_logger(), result->info.c_str());
+      goal_handle->succeed(result);
+      batched_buffer_->clear();
+      return;
+    }
 
-  auto loop_start_time = node_ptr_->now();
+    if (goal_handle->is_canceling()) {
+      result->info = "Inference cancelled.";
+      RCLCPP_INFO(node_ptr_->get_logger(), result->info.c_str());
+      goal_handle->canceled(result);
+      batched_buffer_->clear();
+      return;
+    }
 
-  while (rclcpp::ok() &&
-         node_ptr_->now() - loop_start_time < goal_handle->get_goal()->max_duration) {
     // run inference
-    auto text = inference_(batched_buffer_.dequeue());
+    auto transcription = inference_(batched_buffer_->dequeue());
 
-    // feedback data results
-    if (feedback->batch_idx != batched_buffer_.batch_idx()) {
-      result->text.push_back(feedback->text);
-    }
-    feedback->text = text;
-    feedback->batch_idx = batched_buffer_.batch_idx();
+    // feedback to client
+    feedback->transcription = transcription;
+    feedback->batch_idx = batched_buffer_->batch_idx();
     goal_handle->publish_feedback(feedback);
+
+    // update inference result
+    if (result->transcriptions.size() == batched_buffer_->batch_idx() + 1) {
+      result->transcriptions[result->transcriptions.size() - 1] = feedback->transcription;
+    } else {
+      result->transcriptions.push_back(feedback->transcription);
+    }
   }
 
-  running_inference_ = false;
-  goal_handle->succeed(result);
-  batched_buffer_.clear();
+  if (rclcpp::ok()) {
+    result->info = "Inference succeeded.";
+    RCLCPP_INFO(node_ptr_->get_logger(), result->info.c_str());
+    goal_handle->succeed(result);
+    batched_buffer_->clear();
+  }
 }
 
 std::string InferenceNode::inference_(const std::vector<float> &audio) {
   auto inference_start_time = node_ptr_->now();
-  auto text = whisper_.forward(audio);
+  auto transcription = whisper_->forward(audio);
   auto inference_duration =
       (node_ptr_->now() - inference_start_time).to_chrono<std::chrono::milliseconds>();
   if (inference_duration > whisper::count_to_time(audio.size())) {
@@ -140,6 +168,6 @@ std::string InferenceNode::inference_(const std::vector<float> &audio) {
         "Inference took longer than audio buffer size. This leads to un-inferenced audio "
         "data. Consider increasing thread number or compile with accelerator support.");
   }
-  return text;
+  return transcription;
 }
 } // end of namespace whisper
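The new result bookkeeping in on_inference_accepted_ keeps exactly one entry per audio batch: while batch_idx is unchanged, the latest transcription replaces the last entry; once the buffer advances to a new batch, a fresh entry is appended. The same merge logic in a few lines of Python (with a hypothetical feedback stream):

    # Mirrors the result->transcriptions update in on_inference_accepted_.
    transcriptions = []

    # (batch_idx, transcription) pairs as the server might emit them
    feedback = [(0, "he"), (0, "hello"), (1, "wor"), (1, "world")]

    for batch_idx, transcription in feedback:
        if len(transcriptions) == batch_idx + 1:
            transcriptions[-1] = transcription    # refine the current batch
        else:
            transcriptions.append(transcription)  # a new batch started

    print(transcriptions)  # ['hello', 'world']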
diff --git a/whisper_server/whisper_server_launch_mixin/whisper_server_mixin.py b/whisper_server/whisper_server_launch_mixin/whisper_server_mixin.py
index 815ae9c..b404422 100644
--- a/whisper_server/whisper_server_launch_mixin/whisper_server_mixin.py
+++ b/whisper_server/whisper_server_launch_mixin/whisper_server_mixin.py
@@ -11,12 +11,12 @@ class InferenceMixin:
     def arg_model_name() -> DeclareLaunchArgument:
         return DeclareLaunchArgument(
             name="model_name",
-            default_value="base.en",
+            default_value="tiny.en",
             description="Model name for whisper.cpp. Refer to https://huggingface.co/ggerganov/whisper.cpp.",
             choices=[
                 "tiny.en",
                 "tiny",
                 "base.en",
                 "base",
                 "small.en",
                 "small",
@@ -44,9 +44,33 @@ def arg_language() -> DeclareLaunchArgument:
             choices=["en", "auto"],
         )
 
+    @staticmethod
+    def arg_batch_capacity() -> DeclareLaunchArgument:
+        return DeclareLaunchArgument(
+            name="batch_capacity",
+            default_value="6",
+            description="Batch capacity in seconds.",
+        )
+
+    @staticmethod
+    def arg_buffer_capacity() -> DeclareLaunchArgument:
+        return DeclareLaunchArgument(
+            name="buffer_capacity",
+            default_value="2",
+            description="Buffer capacity in seconds.",
+        )
+
+    @staticmethod
+    def arg_carry_over_capacity() -> DeclareLaunchArgument:
+        return DeclareLaunchArgument(
+            name="carry_over_capacity",
+            default_value="200",
+            description="Carry over capacity in milliseconds.",
+        )
+
     @staticmethod
     def param_model_name() -> Dict[str, LaunchConfiguration]:
-        return {"model_name": LaunchConfiguration("model_name", default="base.en")}
+        return {"model_name": LaunchConfiguration("model_name", default="tiny.en")}
 
     @staticmethod
     def param_n_threads() -> Dict[str, LaunchConfiguration]:
@@ -56,6 +80,22 @@ def param_n_threads() -> Dict[str, LaunchConfiguration]:
     def param_language() -> Dict[str, LaunchConfiguration]:
         return {"language": LaunchConfiguration("language", default="en")}
 
+    @staticmethod
+    def param_batch_capacity() -> Dict[str, LaunchConfiguration]:
+        return {"batch_capacity": LaunchConfiguration("batch_capacity", default="6")}
+
+    @staticmethod
+    def param_buffer_capacity() -> Dict[str, LaunchConfiguration]:
+        return {"buffer_capacity": LaunchConfiguration("buffer_capacity", default="2")}
+
+    @staticmethod
+    def param_carry_over_capacity() -> Dict[str, LaunchConfiguration]:
+        return {
+            "carry_over_capacity": LaunchConfiguration(
+                "carry_over_capacity", default="200"
+            )
+        }
+
     @staticmethod
     def composable_node_inference(**kwargs) -> ComposableNode:
         return ComposableNode(
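Each new argument follows the pattern already used by arg_model_name: a DeclareLaunchArgument exposes the value on the command line and a matching LaunchConfiguration forwards it as a node parameter. A minimal self-contained sketch of that pairing (package and executable names here are hypothetical, not taken from this diff):

    from launch import LaunchDescription
    from launch.actions import DeclareLaunchArgument
    from launch.substitutions import LaunchConfiguration
    from launch_ros.actions import Node


    def generate_launch_description() -> LaunchDescription:
        return LaunchDescription(
            [
                # overridable as: ros2 launch ... batch_capacity:=8
                DeclareLaunchArgument(
                    name="batch_capacity",
                    default_value="6",
                    description="Batch capacity in seconds.",
                ),
                Node(
                    package="my_package",  # hypothetical
                    executable="my_node",  # hypothetical
                    parameters=[{"batch_capacity": LaunchConfiguration("batch_capacity")}],
                ),
            ]
        )

diff --git a/whisper_util/include/whisper_util/audio_buffers.hpp b/whisper_util/include/whisper_util/audio_buffers.hpp
index 603ec13..26140df 100644
--- a/whisper_util/include/whisper_util/audio_buffers.hpp
+++ b/whisper_util/include/whisper_util/audio_buffers.hpp
@@ -62,7 +62,7 @@ template <typename T> class RingBuffer {
 class BatchedBuffer {
 public:
   BatchedBuffer(
-      const std::chrono::milliseconds &batch_capacity = std::chrono::seconds(10),
+      const std::chrono::milliseconds &batch_capacity = std::chrono::seconds(6),
       const std::chrono::milliseconds &buffer_capacity = std::chrono::seconds(2),
       const std::chrono::milliseconds &carry_over_capacity = std::chrono::milliseconds(200));

On the BatchedBuffer defaults above: the default batch capacity drops from 10 s to 6 s, matching the new whisper.yaml and launch defaults, while 200 ms of audio is carried from one batch into the next — presumably so that speech spanning a batch boundary is not cut off mid-word (the precise semantics live in audio_buffers.cpp, which this diff only touches further below). A conceptual sketch of such a carry-over, not the repository implementation:

    # Conceptual carry-over between consecutive audio batches (assumed semantics).
    rate = 16000
    carry_over_samples = 200 * rate // 1000  # 200 ms -> 3200 samples

    def start_next_batch(previous_batch: list) -> list:
        # Seed the new batch with the tail of the finished one so that
        # audio crossing the batch boundary is heard by both batches.
        return previous_batch[-carry_over_samples:]

diff --git a/whisper_util/include/whisper_util/model_manager.hpp b/whisper_util/include/whisper_util/model_manager.hpp
index 33b6c8e..adfb420 100644
--- a/whisper_util/include/whisper_util/model_manager.hpp
+++ b/whisper_util/include/whisper_util/model_manager.hpp
@@ -15,9 +15,9 @@ class ModelManager {
       const std::string &cache_path = std::string(std::getenv("HOME")) + "/.cache/whisper.cpp");
 
   void mkdir(const std::string &path);
-  bool is_available(const std::string &model_name = "base.en");
-  int make_available(const std::string &model_name = "base.en");
-  std::string get_model_path(const std::string &model_name = "base.en");
+  bool is_available(const std::string &model_name = "tiny.en");
+  int make_available(const std::string &model_name = "tiny.en");
+  std::string get_model_path(const std::string &model_name = "tiny.en");
 
 protected:
   std::string model_name_to_file_name_(const std::string &model_name);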
diff --git a/whisper_util/package.xml b/whisper_util/package.xml
index e4cb2d7..e2af3e3 100644
--- a/whisper_util/package.xml
+++ b/whisper_util/package.xml
@@ -2,7 +2,7 @@
   <name>whisper_util</name>
-  <version>1.0.0</version>
+  <version>1.1.0</version>
   <description>ROS 2 wrapper for whisper.cpp.</description>
   <maintainer email="...">mhubii</maintainer>
   <license>MIT</license>
diff --git a/whisper_util/src/audio_buffers.cpp b/whisper_util/src/audio_buffers.cpp
index 2c68ca0..bec0bae 100644
--- a/whisper_util/src/audio_buffers.cpp
+++ b/whisper_util/src/audio_buffers.cpp
@@ -84,6 +84,7 @@ void BatchedBuffer::carry_over_() {
 
 void BatchedBuffer::clear() {
   std::lock_guard<std::mutex> lock(mutex_);
+  batch_idx_ = 0;
   audio_.clear();
   audio_buffer_.clear();
 }
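Finally, the batch_idx_ reset in BatchedBuffer::clear() matters for repeated goals: the server clears the buffer whenever a goal finishes, and whisper_on_key.py starts a new terminal line whenever the feedback batch_idx changes, so without the reset a second goal would continue counting where the first left off. The client-side tracking this interacts with, reduced to a testable sketch:

    # Client-side batch tracking as in WhisperOnKey.on_feedback.
    batch_idx = -1  # sentinel, as initialized in __init__


    def needs_new_line(feedback_batch_idx: int) -> bool:
        """Return True when the feedback belongs to a new batch."""
        global batch_idx
        new_batch = batch_idx != feedback_batch_idx
        batch_idx = feedback_batch_idx
        return new_batch


    # First goal: batches start at 0.
    assert needs_new_line(0) and not needs_new_line(0)
    assert needs_new_line(1)
    # Second goal: with batch_idx_ reset in clear(), feedback starts at 0 again.
    assert needs_new_line(0)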