
feat: move automatic speech recognition agent to rai_core #357

Merged: 16 commits merged into development from refactor/rai_asr on Jan 24, 2025

Conversation


@rachwalk rachwalk commented Jan 7, 2025

Purpose

Currently the entirety of ASR is set up within rai_asr, whereas the functionality for handling speech-recognition concepts should live in the core of rai, as described in #309.

Proposed Changes

This PR adds a ros2-independent agent which provides an API for building ASR systems using the models provided in rai_asr (or custom ones).

Issues

#309

Testing

For a minimal example to test the ASR workflow, you can use this script:

import time

import rclpy

from rai.agents import VoiceRecognitionAgent
from rai.communication import AudioInputDeviceConfig
from rai_asr.models import LocalWhisper, OpenAIWhisper, OpenWakeWord, SileroVAD

VAD_THRESHOLD = 0.1  # Note that this might be different depending on your device
OWW_THRESHOLD = 0.001  # Note that this might be different depending on your device

VAD_SAMPLING_RATE = 16000  # Or 8000
DEV_ID = 16  # check using `python -c "import sounddevice as sd; print(sd.query_devices(kind='input'))"`
DEVICE_SAMPLE_RATE = 44100  # Must be consistent with the DEV_ID sample rate
DEFAULT_BLOCKSIZE = 1280


microphone_configuration = AudioInputDeviceConfig(
    block_size=DEFAULT_BLOCKSIZE,
    consumer_sampling_rate=VAD_SAMPLING_RATE,
    dtype="int16",
    device_number=DEV_ID,
)
vad = SileroVAD(VAD_SAMPLING_RATE, VAD_THRESHOLD)
oww = OpenWakeWord("hey jarvis", OWW_THRESHOLD)
whisper = LocalWhisper("tiny", VAD_SAMPLING_RATE)
# whisper = OpenAIWhisper("whisper-1", VAD_SAMPLING_RATE, "en")

rclpy.init()
ros2_name = "rai_asr_agent"


agent = VoiceRecognitionAgent(DEV_ID, microphone_configuration, ros2_name, whisper, vad)
agent.add_detection_model(oww, pipeline="record")

agent.run()

try:
    # Keep the main thread alive while the agent threads run; stop with Ctrl+C.
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    pass
agent.stop()

Summary by CodeRabbit

  • New Features

    • Added voice recognition agent with support for microphone input and transcription
    • Introduced base classes for voice detection and transcription models
    • Implemented local and OpenAI Whisper transcription models
    • Added wake word and voice activity detection models
  • Bug Fixes

    • Updated pre-commit configuration to ignore additional linter errors
  • Refactor

    • Removed target_sampling_rate from audio input device configuration
    • Streamlined communication and audio processing components

@maciejmajek (Member) left a comment:
Solid 🥇. Wrote some comments, waiting for more.

Review threads (resolved):
  • src/rai_asr/rai_asr/models/local_whisper.py
  • src/rai/rai/agents/voice_agent.py
  • src/rai_asr/rai_asr/models/base.py
  • src/rai_asr/rai_asr/models/silero_vad.py
  • src/rai/rai/agents/voice_agent.py
  • src/rai_asr/rai_asr/models/base.py
  • src/rai_asr/rai_asr/models/local_whisper.py
  • src/rai_asr/rai_asr/models/open_ai_whisper.py
  • src/rai/rai/agents/voice_agent.py
  • src/rai/rai/agents/voice_agent.py
@rachwalk rachwalk force-pushed the refactor/rai_asr branch 2 times, most recently from 06fdbca to 225178f on January 15, 2025 10:54
@boczekbartek (Member) commented:
@rachwalk The snippet from the Testing section doesn't work for me:
python -c "import sounddevice as sd; print(sd.query_devices(kind="input"))"
I think it needs the inner \" escaped:
python -c "import sounddevice as sd; print(sd.query_devices(kind=\"input\"))"

@rachwalk rachwalk marked this pull request as ready for review January 23, 2025 15:04
@maciejmajek (Member) left a comment:
Looking good so far. Unfortunately, I'm not able to run this code today, so this is possibly the first batch of the review.

Review threads (resolved):
  • src/rai/rai/communication/ros2/connectors.py
  • src/rai_asr/rai_asr/models/local_whisper.py
  • src/rai_asr/rai_asr/models/open_wake_word.py
@maciejmajek (Member) commented:
@coderabbitai full review

coderabbitai bot commented Jan 24, 2025

Walkthrough

This pull request introduces a comprehensive set of changes across multiple files in the rai and rai_asr packages. The modifications primarily focus on establishing a new framework for voice recognition and transcription agents, including base classes, specific model implementations, and communication connectors. The changes add support for different voice detection and transcription models, create a standardized interface for agents, and update configuration handling for audio input devices.

Changes

File: Change Summary
  • .pre-commit-config.yaml: Added E203 to flake8 ignored error codes
  • src/rai/rai/agents/__init__.py: Added VoiceRecognitionAgent to __all__ list
  • src/rai/rai/agents/base.py: Introduced BaseAgent abstract base class with connector management and abstract run method
  • src/rai/rai/agents/voice_agent.py: Added VoiceRecognitionAgent with voice recognition and ROS2 integration capabilities
  • src/rai/rai/communication/__init__.py: Added imports for ROS2ARIConnector, ROS2ARIMessage, and AudioInputDeviceConfig
  • src/rai/rai/communication/ros2/connectors.py: Introduced ROS2ARIPayload TypedDict and updated ROS2ARIMessage constructor
  • src/rai/rai/communication/sound_device_connector.py: Removed target_sampling_rate attribute and updated related logic
  • src/rai_asr/rai_asr/models/__init__.py: Added imports for various transcription and voice detection models
  • src/rai_asr/rai_asr/models/base.py: Introduced BaseVoiceDetectionModel and BaseTranscriptionModel abstract base classes
  • Multiple model files: Added implementations for LocalWhisper, OpenAIWhisper, OpenWakeWord, and SileroVAD



@coderabbitai bot left a comment:
Actionable comments posted: 2

🔭 Outside diff range comments (2)
src/rai_asr/rai_asr/asr_clients.py (2)

Line range hint 46-54: Improve API key handling and error message.

The current implementation could be enhanced to:

  1. Support API key rotation
  2. Provide a more helpful error message
  3. Allow injection of the API client for testing
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
         api_key = os.getenv("OPENAI_API_KEY")
         if api_key is None:
-            raise ValueError("OPENAI_API_KEY environment variable is not set.")
+            raise ValueError(
+                "OPENAI_API_KEY environment variable is not set. "
+                "Please set it with your OpenAI API key or visit "
+                "https://platform.openai.com/account/api-keys to obtain one."
+            )
         self.api_key = api_key
-        self.openai_client = OpenAI()
+        self.openai_client = self._create_client()
+
+    def _create_client(self) -> OpenAI:
+        """Create OpenAI client with current API key."""
+        return OpenAI()

Line range hint 77-82: Add input validation and improve error handling in LocalWhisper.

The audio data conversion could fail silently with invalid input. Consider adding validation and proper error handling.

     def transcribe(self, data: NDArray[np.int16]) -> str:
+        if not isinstance(data, np.ndarray) or data.dtype != np.int16:
+            raise ValueError(
+                f"Expected np.ndarray with dtype np.int16, got {type(data)} with dtype {data.dtype}"
+            )
+        try:
             result = transcribe(self.whisper, data.astype(np.float32) / 32768.0)
             transcription = result["text"]
             return transcription
+        except Exception as e:
+            raise RuntimeError(f"Transcription failed: {str(e)}") from e
🧹 Nitpick comments (14)
src/rai_asr/rai_asr/models/open_wake_word.py (2)

24-36: Ensure robust error handling for model path validity.
Currently, the constructor downloads models but does not confirm the validity of the provided wake_word_model_path. Consider adding defensive checks to ensure the path is correct or raise a more descriptive error if it fails.

 def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
     super(OpenWakeWord, self).__init__()
     self.model_name = "open_wake_word"
     download_models()
+    if not wake_word_model_path:
+        raise ValueError("Invalid wake_word_model_path provided.")
     self.model = OWWModel(

43-43: Rename unused variable 'key' to '_key' to address static analysis.
Although we reference value, the key variable is unused, triggering a linter warning. This is a minor cleanup.

-for key, value in predictions.items():
+for _key, value in predictions.items():

src/rai_asr/rai_asr/models/open_ai_whisper.py (1)

43-52: Consider streaming audio directly without fully buffering the WAV data in memory.
To better handle large audio inputs, you might want to avoid fully buffering the file in memory. Streaming or chunk-based uploads could improve efficiency and reduce memory usage, though they are more complex to implement.

src/rai/rai/agents/voice_agent.py (2)

130-135: Manage thread joining carefully to avoid potential blocking.
Joining threads in the main loop can block if a transcription thread is long-running or stuck. Consider a timed or staged join approach to keep the agent responsive.
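As an illustration only, a minimal timed-join sketch; the helper name and the way threads are tracked are assumptions, not the agent's actual attributes:

    import threading

    def join_finished(threads: list[threading.Thread], timeout: float = 0.1) -> list[threading.Thread]:
        # Give each thread a bounded chance to finish, then keep the stragglers
        # for a later pass instead of blocking the main loop indefinitely.
        pending = []
        for t in threads:
            t.join(timeout=timeout)
            if t.is_alive():
                pending.append(t)
        return pending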


178-186: Review nested detection logic for correctness and potential expansions.
The should_record pipeline strictly returns True if any detection model signals. If you plan future expansions (e.g., multiple conditions or aggregated thresholds), consider a more flexible approach for combining detection results.
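If that flexibility is ever needed, a hedged sketch of a pluggable reducer; the function and parameter names are illustrative, not part of this PR:

    from typing import Callable, Iterable

    def combine_detections(
        results: Iterable[bool],
        reducer: Callable[[Iterable[bool]], bool] = any,
    ) -> bool:
        # `any` reproduces the current behaviour; `all`, or a reducer that
        # aggregates confidence scores, could be dropped in later, e.g.
        # combine_detections([vad_hit, wake_word_hit], reducer=all).
        return reducer(results)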

src/rai/rai/agents/__init__.py (2)

18-18: Add module documentation.
You might consider adding short docstrings or module-level comments explaining the purpose of each imported agent for clarity, especially since this is part of the public interface.


24-24: Ensure consistent usage examples.
Now that VoiceRecognitionAgent is exported publicly, ensure there's a minimal usage snippet in the project's documentation or docstrings. This helps users quickly get started.

src/rai/rai/agents/base.py (2)

23-29: Consider storing or removing extra arguments.
The constructor collects *args, **kwargs but does not use them. If not required, remove them; otherwise, store them in instance attributes to avoid dropping potential configuration.

 def __init__(
     self, connectors: Optional[dict[str, BaseConnector]] = None, *args, **kwargs
 ):
     if connectors is None:
         connectors = {}
     self.connectors: dict[str, BaseConnector] = connectors
+    # Example of storing any additional arguments if needed
+    self.extra_args = args
+    self.extra_kwargs = kwargs

30-32: Add docstring for the run method.
Although this is an abstract method, adding docstrings in the base class clarifies usage, helping implementers.

src/rai/rai/communication/__init__.py (1)

18-23: Document newly imported classes.
Provide a brief reference in your docs (or docstrings) explaining the role of ROS2ARIConnector, ROS2ARIMessage, and AudioInputDeviceConfig, so developers have a clear overview.

src/rai_asr/rai_asr/models/local_whisper.py (1)

36-45: Ensure robust handling for large or streaming inputs.

While the current implementation works for small in-memory arrays, consider whether you need partial or streamed transcription for large audio inputs. Whisper supports segments of audio in certain usage scenarios, which might be beneficial for real-time or significantly sized data.

src/rai_asr/rai_asr/models/silero_vad.py (1)

25-44: Use a named parameter for threshold and add input validation.

  1. Consider making threshold a named parameter (e.g., threshold: float = 0.5) in the constructor signature for clarity.
  2. In addition, ensure that the threshold lies within a valid range (e.g., 0.0–1.0) to prevent unexpected behavior.
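For illustration, a minimal sketch of the suggested constructor with validation; the other arguments are assumed rather than copied from the file:

    def __init__(self, sampling_rate: int = 16000, threshold: float = 0.5):
        super().__init__()
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be within [0.0, 1.0], got {threshold}")
        self.sampling_rate = sampling_rate
        self.threshold = threshold
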
src/rai/rai/communication/ros2/connectors.py (1)

26-28: Leverage TypedDict more extensively or consider a dataclass.

ROS2ARIPayload currently only has data: Any. If more fields are planned, consider a dataclass or expanding the TypedDict for better clarity and type-safety.
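A hedged sketch of the dataclass alternative, with a purely hypothetical extra field to show how the type could grow:

    from dataclasses import dataclass, field
    from typing import Any, Dict

    @dataclass
    class ROS2ARIPayload:
        data: Any
        # Hypothetical field, included only to illustrate future extension:
        metadata: Dict[str, Any] = field(default_factory=dict)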

src/rai_asr/rai_asr/asr_clients.py (1)

27-28: Enhance deprecation warning with migration details.

The warning comment should provide more context about where to find the new implementation in rai_core and when this file will be removed.

-# WARN: This file is going to be removed in favour of rai_asr.models
+# DEPRECATED: This file will be removed in favor of rai.agents.voice_agent.
+# Please migrate to the new implementation in rai_core. This file will be removed
+# when the migration is complete. See issue #309 for more details.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9322016 and 05cae9f.

📒 Files selected for processing (15)
  • .pre-commit-config.yaml (1 hunks)
  • src/rai/rai/agents/__init__.py (1 hunks)
  • src/rai/rai/agents/base.py (1 hunks)
  • src/rai/rai/agents/voice_agent.py (1 hunks)
  • src/rai/rai/communication/__init__.py (2 hunks)
  • src/rai/rai/communication/ros2/connectors.py (2 hunks)
  • src/rai/rai/communication/sound_device_connector.py (1 hunks)
  • src/rai_asr/rai_asr/asr_clients.py (1 hunks)
  • src/rai_asr/rai_asr/models/__init__.py (1 hunks)
  • src/rai_asr/rai_asr/models/base.py (1 hunks)
  • src/rai_asr/rai_asr/models/local_whisper.py (1 hunks)
  • src/rai_asr/rai_asr/models/open_ai_whisper.py (1 hunks)
  • src/rai_asr/rai_asr/models/open_wake_word.py (1 hunks)
  • src/rai_asr/rai_asr/models/silero_vad.py (1 hunks)
  • tests/communication/test_sound_device_connector.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tests/communication/test_sound_device_connector.py
🧰 Additional context used
🪛 Ruff (0.8.2)
src/rai_asr/rai_asr/models/open_wake_word.py

43-43: Loop control variable key not used within loop body

Rename unused key to _key

(B007)

🔇 Additional comments (13)
src/rai_asr/rai_asr/models/open_wake_word.py (2)

1-14: File header and licensing look good.
No issues identified in the license header block.


18-19: Consider conditional imports for openwakeword to reduce dependencies.
This echoes a previous concern about adding import guards to avoid hard dependencies if the model is optional. For instance, you could wrap the import in a try-except block to handle environments where openwakeword is not installed or not needed.
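A minimal sketch of such a guard, assuming the import paths already used in open_wake_word.py:

    try:
        from openwakeword.model import Model as OWWModel
        from openwakeword.utils import download_models
    except ImportError:  # keep openwakeword optional
        OWWModel = None
        download_models = None

    def _require_openwakeword() -> None:
        # Called from OpenWakeWord.__init__ before touching the library.
        if OWWModel is None:
            raise ImportError(
                "openwakeword is not installed; install it to use the OpenWakeWord model"
            )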

src/rai_asr/rai_asr/models/open_ai_whisper.py (1)

15-35: Verify the presence of the “OPENAI_API_KEY” environment variable early.
You’re handling the missing API key scenario with a ValueError. This is good. Also consider logging a clear explanation for troubleshooting if the variable is missing.

src/rai/rai/agents/voice_agent.py (2)

45-45: Use device name instead of ID for robust identification.
This comment was raised before, recommending strings for device identification due to device ID variability. Consider this approach to minimize user confusion.
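For illustration, a small sounddevice-based lookup; the helper name is hypothetical, and only query_devices is assumed from the library:

    import sounddevice as sd

    def find_input_device(name_fragment: str) -> int:
        # Return the index of the first input device whose name contains the fragment.
        for idx, dev in enumerate(sd.query_devices()):
            if dev["max_input_channels"] > 0 and name_fragment.lower() in dev["name"].lower():
                return idx
        raise ValueError(f"no input device matching {name_fragment!r}")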


97-105: Validate action_data usage.
You’re passing action_data=None to start_action. Confirm if this interface expects a data structure or if it’s purely optional. Properly documenting these arguments helps future maintainers.

src/rai/rai/communication/__init__.py (1)

33-37: Great job bringing new connectors to __all__.
These additions enhance discoverability of the new classes.

src/rai_asr/rai_asr/models/base.py (2)

25-29: Clarify method naming.
Previous feedback questioned the name “detected.” Consider renaming it to something like is_voice_detected or adding a docstring explaining what “detected” returns.


41-41: Confirm the data type usage.
A previous reviewer questioned whether np.int16 is always guaranteed. If you support other audio formats, either convert or validate them.
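As an illustration only, a small conversion/validation helper; the name and placement are hypothetical:

    import numpy as np
    from numpy.typing import NDArray

    def ensure_int16(audio: np.ndarray) -> NDArray[np.int16]:
        # Pass int16 through, rescale floats in [-1.0, 1.0], reject anything else.
        if audio.dtype == np.int16:
            return audio
        if np.issubdtype(audio.dtype, np.floating):
            return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        raise ValueError(f"unsupported audio dtype: {audio.dtype}")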

src/rai_asr/rai_asr/models/local_whisper.py (1)

29-32: Consider allowing keyword arguments for whisper.load_model.

As previously mentioned, allowing users to pass additional configuration parameters (e.g., whisper.load_model(self.model_name, device="cuda", **kwargs)) may provide more flexibility.

src/rai_asr/rai_asr/models/silero_vad.py (1)

51-61: Avoid repeated calls to model loading logic if performance is critical.

The VAD model is loaded once in the constructor, which is good for performance. Just ensure that if you add future functionalities (e.g., hot-swappable models), you handle them carefully to avoid re-initializations.

src/rai/rai/communication/sound_device_connector.py (1)

132-134: Double-check the resampling logic for edge-case input sizes.

When len(indata) is very small (e.g., less than one frame), or if indata is zero-length, resample() might behave unexpectedly. Ensure you handle edge cases gracefully, possibly by skipping or buffering samples.

Would you like a snippet showing a buffer-based approach to handle especially small audio snapshots?
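For reference, a minimal buffer-based sketch; scipy.signal.resample and the class name are stand-ins here, not the connector's actual API:

    import numpy as np
    from scipy.signal import resample

    class ResampleBuffer:
        """Accumulate small callback chunks until there is enough audio to resample."""

        def __init__(self, in_rate: int, out_rate: int, min_samples: int = 512):
            self.in_rate = in_rate
            self.out_rate = out_rate
            self.min_samples = min_samples
            self._buffer = np.empty(0, dtype=np.int16)

        def push(self, indata: np.ndarray):
            self._buffer = np.concatenate([self._buffer, indata.ravel()])
            if len(self._buffer) < self.min_samples:
                return None  # not enough audio yet; wait for the next callback
            chunk, self._buffer = self._buffer, np.empty(0, dtype=np.int16)
            out_len = int(round(len(chunk) * self.out_rate / self.in_rate))
            return resample(chunk, out_len)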

.pre-commit-config.yaml (1)

47-47: LGTM! Necessary change to align flake8 with Black.

The addition of E203 to the ignore list is correct as it prevents conflicts between flake8 and Black's handling of slice notation.

src/rai_asr/rai_asr/models/__init__.py (1)

15-28: Reconsider adding new features to rai_asr.

Since rai_asr is a ROS 2 package that will be deprecated (as mentioned in previous reviews), adding new features here seems counterintuitive. Consider moving these model definitions directly to rai_core as part of the migration.

Review threads (resolved):
  • src/rai/rai/agents/voice_agent.py
  • src/rai/rai/communication/ros2/connectors.py
@maciejmajek maciejmajek merged commit 1dda834 into development Jan 24, 2025
5 checks passed
@maciejmajek maciejmajek deleted the refactor/rai_asr branch January 24, 2025 16:03