diff --git a/README.md b/README.md index 246a060..2493e33 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,9 @@ audio_capture_event_handler = MyAudioCaptureEventHandler( event_handler=event_handler ) audio_capture = AudioCapture(audio_capture_event_handler, ...) + +audio_player.start() +audio_capture.start() ``` ## Installation diff --git a/samples/async/sample_realtime_ai_with_local_vad.py b/samples/async/sample_realtime_ai_with_local_vad.py index 8930b44..ced579e 100644 --- a/samples/async/sample_realtime_ai_with_local_vad.py +++ b/samples/async/sample_realtime_ai_with_local_vad.py @@ -343,12 +343,16 @@ async def main(): ) logger.info("Recording... Press Ctrl+C to stop.") + audio_player.start() + audio_capture.start() # Keep the loop running while the stream is active await asyncio.Event().wait() # Effectively blocks indefinitely except KeyboardInterrupt: logger.info("Recording stopped by user.") + audio_capture.stop() + audio_player.stop() except Exception as e: logger.error(f"Unexpected error: {e}") finally: diff --git a/samples/sample_realtime_ai_with_keyword_and_vad.py b/samples/sample_realtime_ai_with_keyword_and_vad.py index 3076ba9..36eca7a 100644 --- a/samples/sample_realtime_ai_with_keyword_and_vad.py +++ b/samples/sample_realtime_ai_with_keyword_and_vad.py @@ -3,6 +3,8 @@ import os, json from typing import Any, Dict import threading +import time +from enum import Enum, auto from utils.audio_playback import AudioPlayer from utils.audio_capture import AudioCapture, AudioCaptureEventHandler @@ -31,6 +33,12 @@ logger = logging.getLogger() +class ConversationState(Enum): + IDLE = auto() + KEYWORD_DETECTED = auto() + CONVERSATION_ACTIVE = auto() + + class MyAudioCaptureEventHandler(AudioCaptureEventHandler): def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHandler"): """ @@ -41,8 +49,7 @@ def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHand """ self._client = client self._event_handler = event_handler 
- self._keyword_detected = False - self._conversation_active = False + self._state = ConversationState.IDLE self._silence_timeout = 10 # Silence timeout in seconds for rearming keyword detection self._silence_timer = None @@ -52,26 +59,24 @@ def send_audio_data(self, audio_data: bytes): :param audio_data: Raw audio data in bytes. """ - if self._conversation_active: + if self._state == ConversationState.CONVERSATION_ACTIVE: logger.info("Sending audio data to the client.") self._client.send_audio(audio_data) def on_speech_start(self): """ Handles actions to perform when speech starts. - """ logger.info("Local VAD: User speech started") - logger.info(f"on_speech_start: Keyword detected: {self._keyword_detected}, Conversation active: {self._conversation_active}") + logger.info(f"on_speech_start: Current state: {self._state}") - if self._keyword_detected: - self._conversation_active = True - if self._silence_timer: - self._silence_timer.cancel() + if self._state == ConversationState.KEYWORD_DETECTED: + self._set_state(ConversationState.CONVERSATION_ACTIVE) + self._cancel_silence_timer() if (self._client.options.turn_detection is None and self._event_handler.is_audio_playing() and - self._conversation_active): + self._state == ConversationState.CONVERSATION_ACTIVE): logger.info("User started speaking while assistant is responding; interrupting the assistant's response.") self._client.clear_input_audio_buffer() self._client.cancel_response() @@ -82,9 +87,9 @@ def on_speech_end(self): Handles actions to perform when speech ends. 
""" logger.info("Local VAD: User speech ended") - logger.info(f"on_speech_end: Keyword detected: {self._keyword_detected}, Conversation active: {self._conversation_active}") + logger.info(f"on_speech_end: Current state: {self._state}") - if self._conversation_active and self._client.options.turn_detection is None: + if self._state == ConversationState.CONVERSATION_ACTIVE and self._client.options.turn_detection is None: logger.debug("Using local VAD; requesting the client to generate a response after speech ends.") self._client.generate_response() logger.debug("Conversation is active. Starting silence timer.") @@ -98,16 +103,20 @@ def on_keyword_detected(self, result): """ logger.info(f"Local Keyword: User keyword detected: {result}") self._client.send_text("Hello") - self._keyword_detected = True - self._conversation_active = True + self._set_state(ConversationState.KEYWORD_DETECTED) + self._start_silence_timer() def _start_silence_timer(self): + self._cancel_silence_timer() + self._silence_timer = threading.Timer(self._silence_timeout, self._reset_state_due_to_silence) + self._silence_timer.start() + + def _cancel_silence_timer(self): if self._silence_timer: self._silence_timer.cancel() - self._silence_timer = threading.Timer(self._silence_timeout, self._reset_keyword_detection) - self._silence_timer.start() + self._silence_timer = None - def _reset_keyword_detection(self): + def _reset_state_due_to_silence(self): if self._event_handler.is_audio_playing() or self._event_handler.is_function_processing(): logger.info("Assistant is responding or processing a function. Waiting to reset keyword detection.") self._start_silence_timer() @@ -116,9 +125,13 @@ def _reset_keyword_detection(self): logger.info("Silence timeout reached. 
Rearming keyword detection.") logger.debug("Clearing input audio buffer.") self._client.clear_input_audio_buffer() + self._set_state(ConversationState.IDLE) - self._keyword_detected = False - self._conversation_active = False + def _set_state(self, new_state: ConversationState): + logger.debug(f"Transitioning from {self._state} to {new_state}") + self._state = new_state + if new_state != ConversationState.CONVERSATION_ACTIVE: + self._cancel_silence_timer() class MyRealtimeEventHandler(RealtimeAIEventHandler): @@ -340,7 +353,7 @@ def main(): client = RealtimeAIClient(options, stream_options, event_handler) event_handler.set_client(client) client.start() - + audio_capture_event_handler = MyAudioCaptureEventHandler( client=client, event_handler=event_handler @@ -367,6 +380,8 @@ def main(): ) logger.info("Recording... Press Ctrl+C to stop.") + audio_player.start() + audio_capture.start() # Loop to ensure keyboard interrupt is caught correctly stop_event = threading.Event() @@ -375,6 +390,9 @@ def main(): stop_event.wait(timeout=0.1) except KeyboardInterrupt: logger.info("Recording stopped by user.") + audio_capture.stop() + audio_player.stop() + if audio_player: audio_player.close() if audio_capture: diff --git a/samples/sample_realtime_ai_with_local_vad.py b/samples/sample_realtime_ai_with_local_vad.py index fe6db18..1ba870d 100644 --- a/samples/sample_realtime_ai_with_local_vad.py +++ b/samples/sample_realtime_ai_with_local_vad.py @@ -296,7 +296,7 @@ def main(): client = RealtimeAIClient(options, stream_options, event_handler) event_handler.set_client(client) client.start() - + audio_capture_event_handler = MyAudioCaptureEventHandler( client=client, event_handler=event_handler @@ -322,6 +322,8 @@ def main(): ) logger.info("Recording... 
Press Ctrl+C to stop.") + audio_player.start() + audio_capture.start() # Loop to ensure keyboard interrupt is caught correctly stop_event = threading.Event() @@ -330,6 +332,9 @@ def main(): stop_event.wait(timeout=0.1) except KeyboardInterrupt: logger.info("Recording stopped by user.") + audio_capture.stop() + audio_player.stop() + if audio_player: audio_player.close() if audio_capture: diff --git a/samples/utils/audio_capture.py b/samples/utils/audio_capture.py index b97ef68..a5b61dc 100644 --- a/samples/utils/audio_capture.py +++ b/samples/utils/audio_capture.py @@ -97,18 +97,10 @@ def __init__( self.vad = None self.speech_started = False - if self.enable_wave_capture: - try: - self.wave_file = wave.open("microphone_output.wav", "wb") - self.wave_file.setnchannels(self.channels) - self.wave_file.setsampwidth(pyaudio.PyAudio().get_sample_size(FORMAT)) - self.wave_file.setframerate(self.sample_rate) - logger.info("Wave file initialized for capture.") - except Exception as e: - logger.error(f"Error opening wave file: {e}") - self.enable_wave_capture = False + self.wave_file = None + self.pyaudio_instance = pyaudio.PyAudio() + self.stream = None - # Initialize VAD if vad_parameters is not None: try: self.vad = VoiceActivityDetector(**vad_parameters) @@ -128,18 +120,47 @@ def __init__( self.keyword_recognizer = AzureKeywordRecognizer( model_file=keyword_model_file, callback=self._on_keyword_detected, - sample_rate=sample_rate, - channels=channels + sample_rate=self.sample_rate, + channels=self.channels ) - self.keyword_recognizer.start_recognition() logger.info("Keyword recognizer initialized.") except Exception as e: logger.error(f"Failed to initialize AzureKeywordRecognizer: {e}") - # Initialize PyAudio for input - self.p = pyaudio.PyAudio() + self.is_running = False + + def start(self): + """ + Starts the audio capture stream and initializes necessary components. 
+ """ + if self.is_running: + logger.warning("AudioCapture is already running.") + return + + if self.enable_wave_capture: + try: + self.wave_file = wave.open("microphone_output.wav", "wb") + self.wave_file.setnchannels(self.channels) + self.wave_file.setsampwidth(self.pyaudio_instance.get_sample_size(FORMAT)) + self.wave_file.setframerate(self.sample_rate) + logger.info("Wave file initialized for capture.") + except Exception as e: + logger.error(f"Error opening wave file: {e}") + self.enable_wave_capture = False + + if self.keyword_recognizer: + try: + self.keyword_recognizer.start_recognition() + logger.info("Keyword recognizer started.") + except Exception as e: + logger.error(f"Failed to start AzureKeywordRecognizer: {e}") + + # ensure the pyaudio instance is initialized + if not self.pyaudio_instance: + self.pyaudio_instance = pyaudio.PyAudio() + try: - self.stream = self.p.open( + self.stream = self.pyaudio_instance.open( format=FORMAT, channels=self.channels, rate=self.sample_rate, @@ -148,11 +169,53 @@ def __init__( stream_callback=self.handle_input_audio ) self.stream.start_stream() - logger.info("AudioCapture initialized and input stream started.") + self.is_running = True + logger.info("Audio stream started.") except Exception as e: logger.error(f"Failed to initialize PyAudio Input Stream: {e}") + self.is_running = False raise + def stop(self, terminate: bool = False): + """ + Stops the audio capture stream and releases all resources. 
+ """ + if not self.is_running: + logger.warning("AudioCapture is already stopped.") + return + + try: + if self.stream is not None: + self.stream.stop_stream() + self.stream.close() + logger.info("Audio stream stopped and closed.") + except Exception as e: + logger.error(f"Error stopping audio stream: {e}") + + if self.keyword_recognizer: + try: + self.keyword_recognizer.stop_recognition() + logger.info("Keyword recognizer stopped.") + except Exception as e: + logger.error(f"Error stopping AzureKeywordRecognizer: {e}") + + if self.enable_wave_capture and self.wave_file: + try: + self.wave_file.close() + logger.info("Wave file saved successfully.") + except Exception as e: + logger.error(f"Error closing wave file: {e}") + + try: + if self.pyaudio_instance is not None and terminate: + self.pyaudio_instance.terminate() + logger.info("PyAudio terminated.") + except Exception as e: + logger.error(f"Error terminating PyAudio: {e}") + + self.is_running = False + logger.info("AudioCapture has been stopped.") + def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status): """ Combined callback function for PyAudio input stream. 
@@ -166,7 +229,7 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) """ if status: logger.warning(f"Input Stream Status: {status}") - + try: audio_data = np.frombuffer(indata, dtype=np.int16).copy() except ValueError as e: @@ -175,7 +238,7 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) if self.vad is None: self.event_handler.send_audio_data(indata) - if self.enable_wave_capture: + if self.enable_wave_capture and self.wave_file: try: self.wave_file.writeframes(indata) except Exception as e: @@ -194,10 +257,16 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) if is_speech: if not self.speech_started: logger.info("Speech started") - self.buffer_pointer = self._update_buffer(audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size) - current_buffer = self._get_buffer_content(self.audio_buffer, self.buffer_pointer, self.buffer_size).copy() - - fade_length = min(self.cross_fade_samples, len(current_buffer), len(audio_data)) + self.buffer_pointer = self._update_buffer( + audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size + ) + current_buffer = self._get_buffer_content( + self.audio_buffer, self.buffer_pointer, self.buffer_size + ).copy() + + fade_length = min( + self.cross_fade_samples, len(current_buffer), len(audio_data) + ) if fade_length > 0: fade_out = np.linspace(1.0, 0.0, fade_length, dtype=np.float32) fade_in = np.linspace(0.0, 1.0, fade_length, dtype=np.float32) @@ -218,15 +287,14 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) logger.info("Sending buffered audio to client via event handler...") self.event_handler.on_speech_start() self.event_handler.send_audio_data(combined_audio.tobytes()) - if self.enable_wave_capture: + if self.enable_wave_capture and self.wave_file: try: self.wave_file.writeframes(combined_audio.tobytes()) except Exception as e: logger.error(f"Error writing to wave file: 
{e}") else: - logger.info("Sending audio to client via event handler...") self.event_handler.send_audio_data(audio_data.tobytes()) - if self.enable_wave_capture: + if self.enable_wave_capture and self.wave_file: try: self.wave_file.writeframes(audio_data.tobytes()) except Exception as e: @@ -238,7 +306,9 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) self.speech_started = False if self.vad: - self.buffer_pointer = self._update_buffer(audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size) + self.buffer_pointer = self._update_buffer( + audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size + ) return (None, pyaudio.paContinue) @@ -304,24 +374,9 @@ def close(self): """ Closes the audio capture stream and the wave file, releasing all resources. """ - try: - if hasattr(self, 'stream'): - self.stream.stop_stream() - self.stream.close() - logger.info("Audio input stream stopped and closed.") + self.stop(terminate=True) + logger.info("AudioCapture resources have been released.") - if hasattr(self, 'p'): - self.p.terminate() - logger.info("PyAudio terminated.") + def __del__(self): + self.close() - if self.enable_wave_capture and hasattr(self, 'wave_file'): - self.wave_file.close() - logger.info("Wave file saved successfully.") - - if self.keyword_recognizer: - self.keyword_recognizer.stop_recognition() - logger.info("Keyword recognizer stopped.") - - logger.info("AudioCapture resources have been released.") - except Exception as e: - logger.error(f"Error closing AudioCapture: {e}") diff --git a/samples/utils/audio_playback.py b/samples/utils/audio_playback.py index d44361c..82f7c92 100644 --- a/samples/utils/audio_playback.py +++ b/samples/utils/audio_playback.py @@ -5,6 +5,7 @@ import threading import time import wave +from typing import Optional # Constants for PyAudio Configuration FORMAT = pyaudio.paInt16 @@ -18,12 +19,22 @@ class AudioPlayer: """Handles audio playback for decoded audio data using 
PyAudio.""" - def __init__(self, min_buffer_fill=3, max_buffer_size=0, enable_wave_capture=False): + def __init__( + self, + min_buffer_fill=3, + max_buffer_size=0, + enable_wave_capture=False, + output_filename: Optional[str] = None, + output_device_index: Optional[int] = None + ): """ Initializes the AudioPlayer with a pre-fetch buffer threshold. - :param min_buffer_fill: Minimum number of buffers that should be filled before starting playback initially. - :param max_buffer_size: Maximum size of the buffer queue. + :param min_buffer_fill: Minimum number of buffers that should be filled before starting playback. + :param max_buffer_size: Maximum size of the buffer queue. 0 means unlimited. + :param enable_wave_capture: Flag to enable capturing played audio to wave files. + :param output_filename: Filename for the wave capture, defaults to 'playback_output.wav'. + :param output_device_index: Specific output device index to use. None for default. """ self.initial_min_buffer_fill = min_buffer_fill self.min_buffer_fill = min_buffer_fill @@ -31,48 +42,35 @@ def __init__(self, min_buffer_fill=3, max_buffer_size=0, enable_wave_capture=Fal self.pyaudio_instance = pyaudio.PyAudio() self.stream = None self.stop_event = threading.Event() - self.reset_event = threading.Event() self.playback_complete_event = threading.Event() self.buffer_lock = threading.Lock() self.enable_wave_capture = enable_wave_capture + self.output_filename = output_filename or "playback_output.wav" self.wave_file = None self.buffers_played = 0 + self.is_running = False + self.thread = None + self.output_device_index = output_device_index - # Fade-out related attributes - self.fade_out_event = threading.Event() - self.fade_out_duration = 100 # in milliseconds - self.fade_volume = 1.0 - self.fade_step = 0.0 - self.total_fade_steps = 0 + # Lock for thread-safe operations + self.lock = threading.RLock() - self._initialize_wave_file() - self._initialize_stream() - self._start_thread() + # Initialize wave 
file if capture is enabled + if self.enable_wave_capture: + self._initialize_wave_file() def _initialize_wave_file(self): + """Sets up the wave file for capturing playback if enabled.""" if self.enable_wave_capture: try: - self.wave_file = wave.open("playback_output.wav", "wb") + self.wave_file = wave.open(self.output_filename, "wb") self.wave_file.setnchannels(CHANNELS) self.wave_file.setsampwidth(self.pyaudio_instance.get_sample_size(FORMAT)) self.wave_file.setframerate(RATE) - logger.info("Wave file for playback capture initialized.") + logger.info(f"Wave file '{self.output_filename}' initialized for capture.") except Exception as e: logger.error(f"Error opening wave file for playback capture: {e}") - - def _initialize_stream(self): - """Initializes or reinitializes the PyAudio stream.""" - if self.stream: - self.stream.stop_stream() - self.stream.close() - self.stream = self.pyaudio_instance.open( - format=FORMAT, - channels=CHANNELS, - rate=RATE, - output=True, - frames_per_buffer=FRAMES_PER_BUFFER - ) - logger.info("PyAudio stream initialized.") + self.enable_wave_capture = False def _start_thread(self): """Starts the playback thread.""" @@ -80,105 +78,138 @@ def _start_thread(self): self.thread.start() logger.info("Playback thread started.") - def is_audio_playing(self): - """Checks if audio is currently playing.""" - with self.buffer_lock: - buffer_not_empty = not self.buffer.empty() - is_playing = buffer_not_empty - logger.debug(f"Checking if audio is playing: Buffer not empty = {buffer_not_empty}, " - f"Is playing = {is_playing}") - return is_playing + def start(self): + """ + Starts the audio playback stream and initializes necessary components. 
+ """ + with self.lock: + if self.is_running: + logger.warning("AudioPlayer is already running.") + return + + # ensure the pyaudio instance is initialized + if self.pyaudio_instance is None: + self.pyaudio_instance = pyaudio.PyAudio() + + try: + self.stream = self.pyaudio_instance.open( + format=FORMAT, + channels=CHANNELS, + rate=RATE, + output=True, + frames_per_buffer=FRAMES_PER_BUFFER, + output_device_index=self.output_device_index + ) + logger.info(f"AudioPlayer started, stream: {self.stream}") + + self.is_running = True + self.stop_event.clear() + self.playback_complete_event.clear() + self._start_thread() + logger.info("AudioPlayer started.") + + except Exception as e: + logger.error(f"Failed to start AudioPlayer: {e}") + self.is_running = False + self.stop_event.set() + self.playback_complete_event.set() + + def stop(self): + """ + Stops the audio playback stream and releases resources. + """ + with self.lock: + if not self.is_running: + logger.warning("AudioPlayer is already stopped.") + return + + self.stop_event.set() + logger.info("Stop event set. 
Attempting to stop playback thread.") + + # Wait for playback thread to finish + if self.thread and self.thread.is_alive(): + self.playback_complete_event.wait(timeout=5) + self.thread.join(timeout=5) + logger.info("Playback thread joined.") + + # Stop and close the stream + if self.stream is not None: + logger.info(f"Stopping stream: {self.stream}") + try: + if self.stream.is_active(): + self.stream.stop_stream() + self.stream.close() + self.stream = None # Ensure stream reference is cleared + logger.info("PyAudio stream stopped and closed.") + except Exception as e: + logger.error(f"Error stopping PyAudio stream: {e}") + + # Close the wave file if enabled + if self.enable_wave_capture and self.wave_file is not None: + try: + self.wave_file.close() + self.wave_file = None + logger.info(f"Wave file '{self.output_filename}' closed.") + except Exception as e: + logger.error(f"Error closing wave file: {e}") + + self.is_running = False + self.stop_event.clear() + self.playback_complete_event.set() + logger.info("AudioPlayer stopped and resources released.") def playback_loop(self): - """Main playback loop that handles audio streaming.""" + """Main loop that manages the audio playback.""" + logger.debug("Playback loop started.") self.playback_complete_event.clear() self.initial_buffer_fill() while not self.stop_event.is_set(): try: - if self.reset_event.is_set(): - if not self.fade_out_event.is_set(): - logger.debug("Reset event detected; initiating fade-out.") - self._initiate_fade_out() - self.reset_event.clear() - time.sleep(0.01) - continue - - try: - data = self.buffer.get(timeout=0.1) - if data is None: - break - except queue.Empty: - logger.debug("Playback queue empty, waiting for data.") - time.sleep(0.1) - continue - - if self.fade_out_event.is_set() and self.total_fade_steps > 0: - audio_data = np.frombuffer(data, dtype=np.int16).astype(np.float32) - audio_data *= self.fade_volume - self.fade_volume -= self.fade_step - self.total_fade_steps -= 1 - - if 
self.fade_volume < 0.0: - self.fade_volume = 0.0 - - data = audio_data.astype(np.int16).tobytes() - logger.debug(f"Applying fade-out: Remaining steps={self.total_fade_steps}, Current volume={self.fade_volume:.4f}") - - if self.total_fade_steps <= 0: - logger.debug("Fade-out complete; clearing buffers.") - self.fade_out_event.clear() - self._clear_buffer() - self._reset_playback_state() + data = self.buffer.get(timeout=0.1) + if data is None: + logger.debug("None received from buffer; breaking playback loop.") + break + # Write data to stream self._write_data_to_stream(data) with self.buffer_lock: self.buffers_played += 1 logger.debug(f"Audio played. Buffers played count: {self.buffers_played}") + except queue.Empty: + logger.debug("Playback queue empty; waiting for data.") + continue except Exception as e: logger.error(f"Unexpected error in playback loop: {e}") + break # Exit loop on unexpected errors - logger.info("Playback thread terminated.") + logger.info("Playback loop exiting.") self.playback_complete_event.set() - def _initiate_fade_out(self): - """Initiates the fade-out process.""" - self.fade_out_event.set() - fade_duration_sec = self.fade_out_duration / 1000.0 - self.total_fade_steps = int((fade_duration_sec * RATE) / FRAMES_PER_BUFFER) - if self.total_fade_steps <= 0: - self.total_fade_steps = 1 - self.fade_step = 1.0 / self.total_fade_steps - self.fade_volume = 1.0 - logger.debug(f"Fade-out initiated: Duration={self.fade_out_duration}ms, Total steps={self.total_fade_steps}") - - def _reset_playback_state(self): - """Resets the playback state after fade-out.""" - logger.debug("Resetting playback state.") - with self.buffer_lock: - self.buffers_played = 0 - self.min_buffer_fill = self.initial_min_buffer_fill - self.reset_event.clear() - logger.debug("Playback state has been reset.") - def _write_data_to_stream(self, data: bytes): """Writes audio data to the PyAudio stream and handles wave file capture if enabled.""" try: + if self.stream: + 
self.stream.write(data, exception_on_underflow=False) + logger.debug("Data written to PyAudio stream.") if self.enable_wave_capture and self.wave_file: self.wave_file.writeframes(data) - self.stream.write(data) + logger.debug("Data written to wave file.") except IOError as e: logger.error(f"I/O error during stream write: {e}") + # Attempt to restart the stream try: - self.stream.stop_stream() - self.stream.start_stream() + if self.stream and not self.stream.is_stopped(): + self.stream.stop_stream() + if self.stream: + self.stream.start_stream() logger.info("PyAudio stream restarted after I/O error.") except Exception as restart_error: logger.error(f"Failed to restart PyAudio stream: {restart_error}") except Exception as e: - logger.error(f"Unexpected error occurred while writing to stream: {e}") + logger.error(f"Unexpected error during stream write: {e}") def initial_buffer_fill(self): """Fills the buffer initially to ensure smooth playback start.""" @@ -187,58 +218,61 @@ def initial_buffer_fill(self): with self.buffer_lock: current_size = self.buffer.qsize() if current_size >= self.min_buffer_fill: + logger.debug("Initial buffer fill complete.") break - time.sleep(0.01) - logger.debug("Initial buffer fill complete.") + time.sleep(0.01) # Sleep briefly to yield control def enqueue_audio_data(self, audio_data: bytes): - """Enqueues audio data into the playback buffer.""" + """Queues data for playback.""" try: - with self.buffer_lock: - self.buffer.put(audio_data, timeout=1) - logger.debug(f"Enqueued audio data. Queue size: {self.buffer.qsize()}") + self.buffer.put_nowait(audio_data) + logger.debug(f"Enqueued audio data. 
Queue size: {self.buffer.qsize()}") except queue.Full: - logger.warning("Failed to enqueue audio data: Buffer full.") + logger.warning("Queue is full; dropping audio data.") - def _clear_buffer(self): - """Clears all pending audio data from the buffer.""" + def is_audio_playing(self) -> bool: + """Checks if audio is currently playing.""" with self.buffer_lock: - cleared_items = 0 - while not self.buffer.empty(): - try: - self.buffer.get_nowait() - cleared_items += 1 - except queue.Empty: - break - logger.debug(f"Cleared {cleared_items} items from the buffer.") + buffer_not_empty = not self.buffer.empty() + is_playing = buffer_not_empty + logger.debug(f"Checking if audio is playing: Buffer not empty = {buffer_not_empty}, " + f"Is playing = {is_playing}") + return is_playing def drain_and_restart(self): - """Configures the player to initiate a fade-out and reset playback.""" + """Resets the playback state and clears the audio buffer without stopping playback.""" with self.buffer_lock: - logger.debug("Prepare for fade-out and reset.") - self.fade_out_duration = 100 - self.reset_event.set() - logger.info("Configured to reset with fade-out.") + logger.info("Draining and restarting playback state.") + self._clear_buffer() + self.buffers_played = 0 + self.min_buffer_fill = self.initial_min_buffer_fill + logger.debug("Playback state reset and buffer cleared.") + + def _clear_buffer(self): + """Clears all pending audio data from the buffer.""" + cleared_items = 0 + while not self.buffer.empty(): + try: + self.buffer.get_nowait() + cleared_items += 1 + except queue.Empty: + break + logger.debug(f"Cleared {cleared_items} items from the buffer.") def close(self): - """Closes the AudioPlayer, stopping playback and releasing resources.""" + """Ensures resources are released by stopping playback and terminating PyAudio.""" logger.info("Closing AudioPlayer.") - self.stop_event.set() - self.buffer.put(None) - self.playback_complete_event.wait(timeout=5) - 
self.thread.join(timeout=5) - if self.stream and self.stream.is_active(): - self.stream.stop_stream() - logger.debug("PyAudio stream stopped.") - if self.stream: - self.stream.close() - logger.debug("PyAudio stream closed.") - self.pyaudio_instance.terminate() - logger.debug("PyAudio terminated.") - if self.enable_wave_capture and self.wave_file: + self.stop() + # Terminate PyAudio instance when closing + if self.pyaudio_instance is not None: try: - self.wave_file.close() - logger.info("Playback wave file saved successfully.") + self.pyaudio_instance.terminate() + self.pyaudio_instance = None + logger.info("PyAudio instance terminated.") except Exception as e: - logger.error(f"Error closing wave file for playback: {e}") - logger.info("AudioPlayer stopped and resources released.") \ No newline at end of file + logger.error(f"Error terminating PyAudio instance: {e}") + logger.info("AudioPlayer resources have been released.") + + def __del__(self): + """Ensures that resources are released upon deletion.""" + self.close() \ No newline at end of file diff --git a/setup.py b/setup.py index 768f3d2..847f7c6 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="realtime-ai", - version="0.1.0", + version="0.1.1", description="Python SDK for real-time audio processing with OpenAI's Realtime REST API.", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/src/realtime_ai/aio/audio_stream_manager.py b/src/realtime_ai/aio/audio_stream_manager.py index 2039213..0ca36f9 100644 --- a/src/realtime_ai/aio/audio_stream_manager.py +++ b/src/realtime_ai/aio/audio_stream_manager.py @@ -39,7 +39,7 @@ async def write_audio_buffer(self, audio_data: bytes): if not self.is_streaming: self._start_stream() logger.info("Enqueuing audio data for streaming.") - await self.audio_queue.put(audio_data) + self.audio_queue.put_nowait(audio_data) logger.info("Audio data enqueued for streaming.") async def _stream_audio(self): diff 
--git a/src/realtime_ai/aio/realtime_ai_client.py b/src/realtime_ai/aio/realtime_ai_client.py index 050449c..973907a 100644 --- a/src/realtime_ai/aio/realtime_ai_client.py +++ b/src/realtime_ai/aio/realtime_ai_client.py @@ -62,7 +62,7 @@ async def send_audio(self, audio_data: bytes): logger.info("RealtimeAIClient: Queuing audio data for streaming.") await self.audio_stream_manager.write_audio_buffer(audio_data) - async def send_text(self, text: str): + async def send_text(self, text: str, role: str = "user", generate_response: bool = True): """Sends text input to the service manager. """ event = { @@ -70,10 +70,10 @@ async def send_text(self, text: str): "type": "conversation.item.create", "item": { "type": "message", - "role": "user", + "role": role, "content": [ { - "type": "input_text", + "type": "text" if role == "assistant" else "input_text", "text": text } ] @@ -81,18 +81,43 @@ async def send_text(self, text: str): } await self.service_manager.send_event(event) logger.info("RealtimeAIClient: Sent text input to server.") - # Using server VAD; requesting the client to generate a response after text input. 
- if self._options.turn_detection: - await self.generate_response() - async def generate_response(self): - """Sends a response.create event to generate a response.""" - logger.info("RealtimeAIClient: Generating response.") - commit_event = { + # Generate a response if required + if generate_response: + await self.generate_response(commit_audio_buffer=False) + + async def update_session(self, options: RealtimeAIOptions): + """Updates the session configuration with the provided options.""" + event = { "event_id": self.service_manager._generate_event_id(), - "type": "input_audio_buffer.commit" + "type": "session.update", + "session": { + "modalities": options.modalities, + "instructions": options.instructions, + "voice": options.voice, + "input_audio_format": options.input_audio_format, + "output_audio_format": options.output_audio_format, + "input_audio_transcription": { + "model": options.input_audio_transcription_model + }, + "turn_detection": options.turn_detection, + "tools": options.tools, + "tool_choice": options.tool_choice, + "temperature": options.temperature + } } - await self.service_manager.send_event(commit_event) + await self.service_manager.send_event(event) + logger.info("RealtimeAIClient: Sent session update to server.") + + async def generate_response(self, commit_audio_buffer: bool = True): + """Sends a response.create event to generate a response.""" + logger.info("RealtimeAIClient: Generating response.") + if commit_audio_buffer: + commit_event = { + "event_id": self.service_manager._generate_event_id(), + "type": "input_audio_buffer.commit" + } + await self.service_manager.send_event(commit_event) response_create_event = { "event_id": self.service_manager._generate_event_id(), diff --git a/src/realtime_ai/aio/realtime_ai_service_manager.py b/src/realtime_ai/aio/realtime_ai_service_manager.py index 1a72cff..bf2258b 100644 --- a/src/realtime_ai/aio/realtime_ai_service_manager.py +++ b/src/realtime_ai/aio/realtime_ai_service_manager.py @@ -76,6 
+76,7 @@ async def connect(self): async def disconnect(self): try: + await self.event_queue.put(None) # Signal the event loop to stop await self.websocket_manager.disconnect() except asyncio.CancelledError: logger.info("RealtimeAIServiceManager: Disconnect was cancelled.") @@ -190,7 +191,3 @@ async def get_next_event(self) -> Optional[EventBase]: def _generate_event_id(self) -> str: return f"event_{uuid.uuid4()}" - - def enqueue_event(self, event: dict): - self.event_queue.put_nowait(event) - logger.debug(f"RealtimeAIServiceManager: Event enqueued: {event.get('type')}") diff --git a/src/realtime_ai/audio_stream_manager.py b/src/realtime_ai/audio_stream_manager.py index 73f4e38..b282209 100644 --- a/src/realtime_ai/audio_stream_manager.py +++ b/src/realtime_ai/audio_stream_manager.py @@ -41,7 +41,7 @@ def write_audio_buffer_sync(self, audio_data: bytes): if not self.is_streaming: self.start_stream() logger.info("Enqueuing audio data for streaming.") - self.audio_queue.put(audio_data) + self.audio_queue.put_nowait(audio_data) logger.info("Audio data enqueued for streaming.") def _stream_audio(self): diff --git a/src/realtime_ai/realtime_ai_client.py b/src/realtime_ai/realtime_ai_client.py index 6895736..00dad2f 100644 --- a/src/realtime_ai/realtime_ai_client.py +++ b/src/realtime_ai/realtime_ai_client.py @@ -22,48 +22,88 @@ def __init__(self, options: RealtimeAIOptions, stream_options: AudioStreamOption self.audio_stream_manager = AudioStreamManager(stream_options, self.service_manager) self.event_handler = event_handler self.is_running = False - self.event_queue = queue.Queue() + self._lock = threading.Lock() - # Thread for consuming events - self._consume_thread = threading.Thread(target=self._consume_events, daemon=True) - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10) + # Initialize the consume thread and executor as None + self._consume_thread = None + self.executor = None + self._stop_event = threading.Event() def start(self): """Starts 
the RealtimeAIClient.""" - self.is_running = True - try: - self.service_manager.connect() # Connect to the service - logger.info("RealtimeAIClient: Client started.") - self._consume_thread.start() # Start event consumption thread - except Exception as e: - logger.error(f"RealtimeAIClient: Error during client start: {e}") - - def stop(self): + with self._lock: + if self.is_running: + logger.warning("RealtimeAIClient: Client is already running.") + return + + self.is_running = True + try: + self.service_manager.connect() # Connect to the service + logger.info("RealtimeAIClient: Client started.") + + # Initialize and start the ThreadPoolExecutor here + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) + logger.info("RealtimeAIClient: ThreadPoolExecutor initialized.") + + # Initialize and start the consume thread + self._consume_thread = threading.Thread(target=self._consume_events, daemon=True, name="RealtimeAIClient_ConsumeThread") + self._consume_thread.start() + logger.info("RealtimeAIClient: Event consumption thread started.") + except Exception as e: + self.is_running = False + logger.error(f"RealtimeAIClient: Error during client start: {e}") + + def stop(self, timeout: float = 5.0): """Stops the RealtimeAIClient gracefully.""" - self.is_running = False - self.audio_stream_manager.stop_stream() - self.service_manager.disconnect() - self._consume_thread.join(timeout=5) - self.executor.shutdown(wait=True) # Gracefully shut down the executor - logger.info("RealtimeAIClient: Services stopped.") + with self._lock: + if not self.is_running: + logger.warning("RealtimeAIClient: Client is already stopped.") + return + + self.is_running = False + + # Signal stop event + self._stop_event.set() + + try: + self.audio_stream_manager.stop_stream() + self.service_manager.disconnect() + + if self._consume_thread is not None: + + # Attempt to join the consume thread within the timeout + self._consume_thread.join(timeout=timeout) + if 
self._consume_thread.is_alive(): + logger.warning("RealtimeAIClient: Consume thread did not terminate within the timeout.") + else: + logger.info("RealtimeAIClient: Consume thread terminated.") + self._consume_thread = None + + if self.executor is not None: + self.executor.shutdown(wait=True) + logger.info("RealtimeAIClient: ThreadPoolExecutor shut down.") + self.executor = None + + logger.info("RealtimeAIClient: Services stopped.") + except Exception as e: + logger.error(f"RealtimeAIClient: Error during client stop: {e}") def send_audio(self, audio_data: bytes): """Sends audio data to the audio stream manager for processing.""" logger.info("RealtimeAIClient: Queuing audio data for streaming.") - self.audio_stream_manager.write_audio_buffer_sync(audio_data) # Ensure this is a sync method in the audio_stream_manager + self.audio_stream_manager.write_audio_buffer_sync(audio_data) # Ensure this is a sync method - def send_text(self, text: str): - """Sends text input to the service manager. - """ + def send_text(self, text: str, role: str = "user", generate_response: bool = True): + """Sends text input to the service manager.""" event = { "event_id": self.service_manager._generate_event_id(), "type": "conversation.item.create", "item": { "type": "message", - "role": "user", + "role": role, "content": [ { - "type": "input_text", + "type": "text" if role == "assistant" else "input_text", "text": text } ] @@ -71,17 +111,42 @@ def send_text(self, text: str): } self._send_event_to_manager(event) logger.info("RealtimeAIClient: Sent text input to server.") - # Using server VAD; requesting the client to generate a response after text input. 
- if self._options.turn_detection: - self.generate_response() - def generate_response(self): + # Generate a response if required + if generate_response: + self.generate_response(commit_audio_buffer=False) + + def update_session(self, options: RealtimeAIOptions): + """Updates the session configuration with the provided options.""" + event = { + "event_id": self.service_manager._generate_event_id(), + "type": "session.update", + "session": { + "modalities": options.modalities, + "instructions": options.instructions, + "voice": options.voice, + "input_audio_format": options.input_audio_format, + "output_audio_format": options.output_audio_format, + "input_audio_transcription": { + "model": options.input_audio_transcription_model + }, + "turn_detection": options.turn_detection, + "tools": options.tools, + "tool_choice": options.tool_choice, + "temperature": options.temperature + } + } + self._send_event_to_manager(event) + logger.info("RealtimeAIClient: Sent session update to server.") + + def generate_response(self, commit_audio_buffer: bool = True): """Sends a response.create event to generate a response.""" logger.info("RealtimeAIClient: Generating response.") - self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), - "type": "input_audio_buffer.commit", - }) + if commit_audio_buffer: + self._send_event_to_manager({ + "event_id": self.service_manager._generate_event_id(), + "type": "input_audio_buffer.commit", + }) self._send_event_to_manager({ "event_id": self.service_manager._generate_event_id(), @@ -113,6 +178,7 @@ def truncate_response(self, item_id: str, content_index: int, audio_end_ms: int) logger.info("Client: Sent conversation.item.truncate event to server.") def clear_input_audio_buffer(self): + """Sends an input_audio_buffer.clear event to the server.""" self._send_event_to_manager({ "event_id": self.service_manager._generate_event_id(), "type": "input_audio_buffer.clear" @@ -124,11 +190,8 @@ def 
generate_response_from_function_call(self, call_id: str, function_output: st Sends a conversation.item.create message as a function call output and optionally triggers a model response. :param call_id: The ID of the function call. - :param name: The name of the function being called. - :param arguments: The arguments used for the function call, in stringified JSON. :param function_output: The output of the function call. """ - # Create the function call output event item_create_event = { "event_id": self.service_manager._generate_event_id(), @@ -145,6 +208,7 @@ def generate_response_from_function_call(self, call_id: str, function_output: st self._send_event_to_manager(item_create_event) logger.info("Function call output event sent.") + # Optionally trigger a response self._send_event_to_manager({ "event_id": self.service_manager._generate_event_id(), "type": "response.create", @@ -153,16 +217,25 @@ def generate_response_from_function_call(self, call_id: str, function_output: st def _consume_events(self): """Consume events from the service manager.""" - while self.is_running: + logger.info("Consume thread: Started consuming events.") + while not self._stop_event.is_set(): try: event = self.service_manager.get_next_event() - if event: + if event is None: + logger.info("Consume thread: Received sentinel, exiting.") + break + + if self.executor is not None: self.executor.submit(self._handle_event, event) else: - time.sleep(0.05) + logger.warning("RealtimeAIClient: Executor is not available or shutting down. 
Event cannot be handled.") + time.sleep(0.05) + except queue.Empty: + continue except Exception as e: logger.error(f"RealtimeAIClient: Error in consume_events: {e}") break + logger.info("Consume thread: Stopped consuming events.") def _handle_event(self, event: EventBase): """Handles the received event based on its type using the event handler.""" @@ -184,4 +257,9 @@ def _send_event_to_manager(self, event): @property def options(self): - return self._options \ No newline at end of file + return self._options + + # Optional: Ensure that threads are cleaned up if the object is deleted while running + def __del__(self): + if self.is_running: + self.stop() \ No newline at end of file diff --git a/src/realtime_ai/realtime_ai_service_manager.py b/src/realtime_ai/realtime_ai_service_manager.py index 51379fc..007811b 100644 --- a/src/realtime_ai/realtime_ai_service_manager.py +++ b/src/realtime_ai/realtime_ai_service_manager.py @@ -80,6 +80,7 @@ def connect(self): def disconnect(self): try: + self.event_queue.put(None) # Signal the event loop to stop self.websocket_manager.disconnect() self.is_connected = False logger.warning("RealtimeAIServiceManager: WebSocket disconnection started.") @@ -109,7 +110,7 @@ def on_message_received(self, message: str): json_object = json.loads(message) event = self.parse_realtime_event(json_object) if event: - self.event_queue.put(event) + self.event_queue.put_nowait(event) logger.debug(f"RealtimeAIServiceManager: Event queued: {event.type}") except json.JSONDecodeError as e: logger.error(f"RealtimeAIServiceManager: JSON parse error: {e}") @@ -186,11 +187,7 @@ def get_next_event(self, timeout=5.0) -> Optional[EventBase]: logger.info("RealtimeAIServiceManager: Waiting for next event...") return self.event_queue.get(timeout=timeout) except queue.Empty: - return None + raise def _generate_event_id(self) -> str: return f"event_{uuid.uuid4()}" - - def enqueue_event(self, event: dict): - self.event_queue.put_nowait(event) - 
logger.debug(f"RealtimeAIServiceManager: Event enqueued: {event.get('type')}")