From 0831a9494aa3cce418954c31c7590143898cdf1b Mon Sep 17 00:00:00 2001 From: jhakulin Date: Fri, 8 Nov 2024 14:47:37 -0800 Subject: [PATCH] update samples --- .../sample_realtime_ai_with_local_vad.py | 118 +++++------ ...sample_realtime_ai_with_keyword_and_vad.py | 116 +++++------ samples/sample_realtime_ai_with_local_vad.py | 92 ++++----- samples/user_functions.py | 187 +++++++++++++++--- samples/utils/audio_capture.py | 158 +++++++++------ samples/utils/azure_keyword_recognizer.py | 61 +++++- src/realtime_ai/aio/realtime_ai_client.py | 23 +++ src/realtime_ai/realtime_ai_client.py | 3 + 8 files changed, 502 insertions(+), 256 deletions(-) diff --git a/samples/async/sample_realtime_ai_with_local_vad.py b/samples/async/sample_realtime_ai_with_local_vad.py index f2e5d24..8930b44 100644 --- a/samples/async/sample_realtime_ai_with_local_vad.py +++ b/samples/async/sample_realtime_ai_with_local_vad.py @@ -21,13 +21,13 @@ # Configure logging logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] # Streaming logs to the console ) # Specific loggers for mentioned packages -logging.getLogger("utils.audio_playback").setLevel(logging.DEBUG) +logging.getLogger("utils.audio_playback").setLevel(logging.ERROR) logging.getLogger("utils.audio_capture").setLevel(logging.ERROR) logging.getLogger("utils.vad").setLevel(logging.ERROR) logging.getLogger("realtime_ai").setLevel(logging.ERROR) @@ -46,10 +46,9 @@ def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHand :param event_handler: Instance of MyRealtimeEventHandler :param event_loop: The asyncio event loop. """ - self.client = client - self.event_handler = event_handler - self.event_loop = event_loop - self.cancelled = False + self._client = client + self._event_handler = event_handler + self._event_loop = event_loop def send_audio_data(self, audio_data: bytes): """ @@ -57,43 +56,33 @@ def send_audio_data(self, audio_data: bytes): :param audio_data: Raw audio data in bytes. """ - logger.info("Sending audio data to the client.") - asyncio.run_coroutine_threadsafe(self.client.send_audio(audio_data), self.event_loop) + logger.debug("Sending audio data to the client.") + asyncio.run_coroutine_threadsafe(self._client.send_audio(audio_data), self._event_loop) def on_speech_start(self): """ Handles actions to perform when speech starts. 
- """ - logger.info("Speech has started.") - if self.event_handler.is_audio_playing(): - logger.info(f"User started speaking while audio is playing.") - - logger.info("Clearing input audio buffer.") - asyncio.run_coroutine_threadsafe(self.client.clear_input_audio_buffer(), self.event_loop) - - logger.info("Cancelling response.") - asyncio.run_coroutine_threadsafe(self.client.cancel_response(), self.event_loop) - self.cancelled = True - - #current_item_id = self.event_handler.get_current_conversation_item_id() - #current_audio_content_index = self.event_handler.get_current_audio_content_id() - #logger.info(f"Truncate the current audio, current item ID: {current_item_id}, current audio content index: {current_audio_content_index}") - #asyncio.run_coroutine_threadsafe(self.client.truncate_response(item_id=current_item_id, content_index=current_audio_content_index, audio_end_ms=1000), self.event_loop) - - # Restart the audio player - self.event_handler.audio_player.drain_and_restart() - else: - logger.info("Assistant is not speaking, cancelling response is not required.") - self.cancelled = False + logger.info("Local VAD: User speech started") + if (self._client.options.turn_detection is None and + self._event_handler.is_audio_playing()): + logger.info("User started speaking while assistant is responding; interrupting the assistant's response.") + asyncio.run_coroutine_threadsafe(self._client.clear_input_audio_buffer(), self._event_loop) + asyncio.run_coroutine_threadsafe(self._client.cancel_response(), self._event_loop) + self._event_handler.audio_player.drain_and_restart() def on_speech_end(self): """ Handles actions to perform when speech ends. """ - logger.info("Speech has ended") - logger.info("Requesting the client to generate a response.") - asyncio.run_coroutine_threadsafe(self.client.generate_response(), self.event_loop) + logger.info("Local VAD: User speech ended") + + if self._client.options.turn_detection is None: + logger.debug("Using local VAD; requesting the client to generate a response after speech ends.") + asyncio.run_coroutine_threadsafe(self._client.generate_response(), self._event_loop) + + def on_keyword_detected(self, result): + pass class MyRealtimeEventHandler(RealtimeAIEventHandler): @@ -106,6 +95,7 @@ def __init__(self, audio_player: AudioPlayer, functions: FunctionTool): self._current_audio_content_index = None self._call_id_to_function_name = {} self._functions = functions + self._function_processing = False @property def audio_player(self): @@ -120,6 +110,9 @@ def get_current_audio_content_id(self): def is_audio_playing(self): return self._audio_player.is_audio_playing() + def is_function_processing(self): + return self._function_processing + def set_client(self, client: RealtimeAIClient): self._client = client @@ -127,53 +120,58 @@ async def on_error(self, event: ErrorEvent) -> None: logger.error(f"Error occurred: {event.error.message}") async def on_input_audio_buffer_speech_stopped(self, event: InputAudioBufferSpeechStopped) -> None: - logger.info(f"Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") + logger.info(f"Server VAD: Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") async def on_input_audio_buffer_committed(self, event: InputAudioBufferCommitted) -> None: - logger.info(f"Audio Buffer Committed: {event.item_id}") + logger.debug(f"Audio Buffer Committed: {event.item_id}") async def on_conversation_item_created(self, event: ConversationItemCreated) -> None: - logger.info(f"New Conversation Item: {event.item}") + 
logger.debug(f"New Conversation Item: {event.item}") async def on_response_created(self, event: ResponseCreated) -> None: - logger.info(f"Response Created: {event.response}") + logger.debug(f"Response Created: {event.response}") async def on_response_content_part_added(self, event: ResponseContentPartAdded) -> None: - logger.info(f"New Part Added: {event.part}") + logger.debug(f"New Part Added: {event.part}") async def on_response_audio_delta(self, event: ResponseAudioDelta) -> None: - logger.info(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") + logger.debug(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") self._current_item_id = event.item_id self._current_audio_content_index = event.content_index self.handle_audio_delta(event) async def on_response_audio_transcript_delta(self, event: ResponseAudioTranscriptDelta) -> None: - logger.info(f"Transcript Delta: {event.delta}") + logger.info(f"Assistant transcription delta: {event.delta}") async def on_rate_limits_updated(self, event: RateLimitsUpdated) -> None: for rate in event.rate_limits: - logger.info(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") + logger.debug(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") async def on_conversation_item_input_audio_transcription_completed(self, event: ConversationItemInputAudioTranscriptionCompleted) -> None: - logger.info(f"Transcription completed for item {event.item_id}: {event.transcript}") + logger.info(f"User transcription complete: {event.transcript}") async def on_response_audio_done(self, event: ResponseAudioDone) -> None: - logger.info(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") + logger.debug(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") async def on_response_audio_transcript_done(self, event: ResponseAudioTranscriptDone) -> None: - logger.info(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") + logger.debug(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") async def on_response_content_part_done(self, event: ResponseContentPartDone) -> None: part_type = event.part.get("type") part_text = event.part.get("text", "") - logger.info(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") + logger.debug(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") async def on_response_output_item_done(self, event: ResponseOutputItemDone) -> None: item_content = event.item.get("content", []) - logger.info(f"Output item done for response ID {event.response_id} with content: {item_content}") + if item_content: + for item in item_content: + if item.get("type") == "audio": + transcript = item.get("transcript") + if transcript: + logger.info(f"Assistant transcription complete: {transcript}") async def on_response_done(self, event: ResponseDone) -> None: - logger.info(f"Response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") + logger.debug(f"Assistant's response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") async def on_session_created(self, event: SessionCreated) -> None: logger.info(f"Session created: {event.session}") @@ -182,15 +180,17 @@ async def on_session_updated(self, event: SessionUpdated) -> None: logger.info(f"Session updated: 
{event.session}") async def on_input_audio_buffer_speech_started(self, event: InputAudioBufferSpeechStarted) -> None: - logger.info(f"Speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + logger.info(f"Server VAD: User speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + if self._client.options.turn_detection is not None: + await self._client.clear_input_audio_buffer() + await self._client.cancel_response() + await asyncio.threads.to_thread(self._audio_player.drain_and_restart) async def on_response_output_item_added(self, event: ResponseOutputItemAdded) -> None: - logger.info(f"Output item added for response ID {event.response_id} with item: {event.item}") - + logger.debug(f"Output item added for response ID {event.response_id} with item: {event.item}") if event.item.get("type") == "function_call": call_id = event.item.get("call_id") function_name = event.item.get("name") - if call_id and function_name: # Properly acquire the lock with 'await' and spread the usage over two lines await self._lock.acquire() # Wait until the lock is available, then acquire it @@ -204,11 +204,9 @@ async def on_response_output_item_added(self, event: ResponseOutputItemAdded) -> logger.warning("Function call item missing 'call_id' or 'name' fields.") async def on_response_function_call_arguments_delta(self, event: ResponseFunctionCallArgumentsDelta) -> None: - logger.info(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") + logger.debug(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") async def on_response_function_call_arguments_done(self, event: ResponseFunctionCallArgumentsDone) -> None: - logger.info(f"Function call arguments done for call ID {event.call_id} with arguments: {event.arguments}") - call_id = event.call_id arguments_str = event.arguments @@ -225,14 +223,16 @@ async def on_response_function_call_arguments_done(self, event: ResponseFunction return try: + self._function_processing = True + logger.info(f"Executing function '{function_name}' with arguments: {arguments_str} for call ID {call_id}") function_output = await asyncio.threads.to_thread(self._functions.execute, function_name, arguments_str) logger.info(f"Function output for call ID {call_id}: {function_output}") - - # Assuming generate_response_from_function_call is an async method await self._client.generate_response_from_function_call(call_id, function_output) except json.JSONDecodeError as e: logger.error(f"Failed to parse arguments for call ID {call_id}: {e}") return + finally: + self._function_processing = False def on_unhandled_event(self, event_type: str, event_data: Dict[str, Any]): logger.warning(f"Unhandled Event: {event_type} - {event_data}") @@ -291,7 +291,7 @@ async def main(): model="gpt-4o-realtime-preview-2024-10-01", modalities=["audio", "text"], instructions="You are a helpful assistant. Respond concisely. 
If user asks to tell story, tell story very shortly.", - turn_detection=get_vad_configuration(), + turn_detection=get_vad_configuration(use_server_vad=False), tools=functions.definitions, tool_choice="auto", temperature=0.8, @@ -315,7 +315,7 @@ async def main(): event_handler.set_client(client) await client.start() - loop = asyncio.get_running_loop() # Get the current event loop + loop = asyncio.get_running_loop() audio_capture_event_handler = MyAudioCaptureEventHandler( client=client, @@ -334,7 +334,7 @@ async def main(): vad_parameters={ "sample_rate": 24000, "chunk_size": 1024, - "window_duration": 1.0, + "window_duration": 1.5, "silence_ratio": 1.5, "min_speech_duration": 0.3, "min_silence_duration": 1.0 diff --git a/samples/sample_realtime_ai_with_keyword_and_vad.py b/samples/sample_realtime_ai_with_keyword_and_vad.py index d8bf0d1..3076ba9 100644 --- a/samples/sample_realtime_ai_with_keyword_and_vad.py +++ b/samples/sample_realtime_ai_with_keyword_and_vad.py @@ -16,13 +16,13 @@ # Configure logging logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] # Streaming logs to the console ) # Specific loggers for mentioned packages -logging.getLogger("utils.audio_playback").setLevel(logging.DEBUG) +logging.getLogger("utils.audio_playback").setLevel(logging.ERROR) logging.getLogger("utils.audio_capture").setLevel(logging.ERROR) logging.getLogger("utils.vad").setLevel(logging.ERROR) logging.getLogger("realtime_ai").setLevel(logging.ERROR) @@ -37,14 +37,13 @@ def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHand Initializes the event handler. :param client: Instance of RealtimeClient. - :param event_handler: Instance of MyRealtimeEventHandler - :param event_loop: The asyncio event loop. + :param event_handler: Instance of MyRealtimeEventHandler. """ self._client = client self._event_handler = event_handler self._keyword_detected = False self._conversation_active = False - self._silence_timeout = 10 # Configurable silence timeout in seconds for rearming keyword detection + self._silence_timeout = 10 # Silence timeout in seconds for rearming keyword detection self._silence_timer = None def send_audio_data(self, audio_data: bytes): @@ -62,39 +61,33 @@ def on_speech_start(self): Handles actions to perform when speech starts. 
""" - logger.info(f"Speech has started, keyword detected: {self._keyword_detected}, conversation active: {self._conversation_active}") + logger.info("Local VAD: User speech started") + logger.info(f"on_speech_start: Keyword detected: {self._keyword_detected}, Conversation active: {self._conversation_active}") if self._keyword_detected: self._conversation_active = True if self._silence_timer: self._silence_timer.cancel() - if self._event_handler.is_audio_playing(): - logger.info(f"User started speaking while audio is playing.") - - logger.info("Clearing input audio buffer.") + if (self._client.options.turn_detection is None and + self._event_handler.is_audio_playing() and + self._conversation_active): + logger.info("User started speaking while assistant is responding; interrupting the assistant's response.") self._client.clear_input_audio_buffer() - - logger.info("Cancelling response.") self._client.cancel_response() - - #current_item_id = self._event_handler.get_current_conversation_item_id() - #current_audio_content_index = self._event_handler.get_current_audio_content_id() - #logger.info(f"Truncate the current audio, current item ID: {current_item_id}, current audio content index: {current_audio_content_index}") - #self._client.truncate_response(item_id=current_item_id, content_index=current_audio_content_index, audio_end_ms=1000) - - # Restart the audio player self._event_handler.audio_player.drain_and_restart() def on_speech_end(self): """ Handles actions to perform when speech ends. """ - logger.info(f"Speech has ended, keyword detected: {self._keyword_detected}, conversation active: {self._conversation_active}") + logger.info("Local VAD: User speech ended") + logger.info(f"on_speech_end: Keyword detected: {self._keyword_detected}, Conversation active: {self._conversation_active}") - if self._conversation_active: - logger.info("Requesting the client to generate a response.") + if self._conversation_active and self._client.options.turn_detection is None: + logger.debug("Using local VAD; requesting the client to generate a response after speech ends.") self._client.generate_response() + logger.debug("Conversation is active. Starting silence timer.") self._start_silence_timer() def on_keyword_detected(self, result): @@ -103,7 +96,7 @@ def on_keyword_detected(self, result): :param result: The recognition result containing details about the detected keyword. """ - logger.info(f"Keyword detected: {result}") + logger.info(f"Local Keyword: User keyword detected: {result}") self._client.send_text("Hello") self._keyword_detected = True self._conversation_active = True @@ -115,16 +108,13 @@ def _start_silence_timer(self): self._silence_timer.start() def _reset_keyword_detection(self): - # if assistant is speaking, wait for it to finish - if self._event_handler.is_audio_playing(): - logger.info("Audio is playing. Waiting for audio to finish before resetting keyword detection.") + if self._event_handler.is_audio_playing() or self._event_handler.is_function_processing(): + logger.info("Assistant is responding or processing a function. Waiting to reset keyword detection.") self._start_silence_timer() return - logger.info("Silence timeout reached. Resetting keyword detection.") - - # Clear the input audio buffer on the server - logger.info("Clearing input audio buffer.") + logger.info("Silence timeout reached. 
Rearming keyword detection.") + logger.debug("Clearing input audio buffer.") self._client.clear_input_audio_buffer() self._keyword_detected = False @@ -141,6 +131,7 @@ def __init__(self, audio_player: AudioPlayer, functions: FunctionTool): self._call_id_to_function_name = {} self._lock = threading.Lock() self._client = None + self._function_processing = False @property def audio_player(self): @@ -155,6 +146,9 @@ def get_current_audio_content_id(self): def is_audio_playing(self): return self._audio_player.is_audio_playing() + def is_function_processing(self): + return self._function_processing + def set_client(self, client: RealtimeAIClient): self._client = client @@ -162,53 +156,58 @@ def on_error(self, event: ErrorEvent): logger.error(f"Error occurred: {event.error.message}") def on_input_audio_buffer_speech_stopped(self, event: InputAudioBufferSpeechStopped): - logger.info(f"Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") + logger.info(f"Server VAD: Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") def on_input_audio_buffer_committed(self, event: InputAudioBufferCommitted): - logger.info(f"Audio Buffer Committed: {event.item_id}") + logger.debug(f"Audio Buffer Committed: {event.item_id}") def on_conversation_item_created(self, event: ConversationItemCreated): - logger.info(f"New Conversation Item: {event.item}") + logger.debug(f"New Conversation Item: {event.item}") def on_response_created(self, event: ResponseCreated): - logger.info(f"Response Created: {event.response}") + logger.debug(f"Response Created: {event.response}") def on_response_content_part_added(self, event: ResponseContentPartAdded): - logger.info(f"New Part Added: {event.part}") + logger.debug(f"New Part Added: {event.part}") def on_response_audio_delta(self, event: ResponseAudioDelta): - logger.info(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") + logger.debug(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") self._current_item_id = event.item_id self._current_audio_content_index = event.content_index self.handle_audio_delta(event) def on_response_audio_transcript_delta(self, event: ResponseAudioTranscriptDelta): - logger.info(f"Transcript Delta: {event.delta}") + logger.info(f"Assistant transcription delta: {event.delta}") def on_rate_limits_updated(self, event: RateLimitsUpdated): for rate in event.rate_limits: - logger.info(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") + logger.debug(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") def on_conversation_item_input_audio_transcription_completed(self, event: ConversationItemInputAudioTranscriptionCompleted): - logger.info(f"Transcription completed for item {event.item_id}: {event.transcript}") + logger.info(f"User transcription complete: {event.transcript}") def on_response_audio_done(self, event: ResponseAudioDone): - logger.info(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") + logger.debug(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") def on_response_audio_transcript_done(self, event: ResponseAudioTranscriptDone): - logger.info(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") + logger.debug(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") def on_response_content_part_done(self, event: ResponseContentPartDone): part_type = 
event.part.get("type") part_text = event.part.get("text", "") - logger.info(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") + logger.debug(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") def on_response_output_item_done(self, event: ResponseOutputItemDone): item_content = event.item.get("content", []) - logger.info(f"Output item done for response ID {event.response_id} with content: {item_content}") + if item_content: + for item in item_content: + if item.get("type") == "audio": + transcript = item.get("transcript") + if transcript: + logger.info(f"Assistant transcription complete: {transcript}") def on_response_done(self, event: ResponseDone): - logger.info(f"Response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") + logger.debug(f"Assistant's response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") def on_session_created(self, event: SessionCreated): logger.info(f"Session created: {event.session}") @@ -217,10 +216,14 @@ def on_session_updated(self, event: SessionUpdated): logger.info(f"Session updated: {event.session}") def on_input_audio_buffer_speech_started(self, event: InputAudioBufferSpeechStarted): - logger.info(f"Speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + logger.info(f"Server VAD: User speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + if self._client.options.turn_detection is not None: + self._client.clear_input_audio_buffer() + self._client.cancel_response() + self._audio_player.drain_and_restart() def on_response_output_item_added(self, event: ResponseOutputItemAdded): - logger.info(f"Output item added for response ID {event.response_id} with item: {event.item}") + logger.debug(f"Output item added for response ID {event.response_id} with item: {event.item}") if event.item.get("type") == "function_call": call_id = event.item.get("call_id") function_name = event.item.get("name") @@ -232,11 +235,9 @@ def on_response_output_item_added(self, event: ResponseOutputItemAdded): logger.warning("Function call item missing 'call_id' or 'name' fields.") def on_response_function_call_arguments_delta(self, event: ResponseFunctionCallArgumentsDelta): - logger.info(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") + logger.debug(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") def on_response_function_call_arguments_done(self, event: ResponseFunctionCallArgumentsDone): - logger.info(f"Function call arguments done for call ID {event.call_id} with arguments: {event.arguments}") - call_id = event.call_id arguments_str = event.arguments @@ -248,12 +249,16 @@ def on_response_function_call_arguments_done(self, event: ResponseFunctionCallAr return try: + self._function_processing = True + logger.info(f"Executing function '{function_name}' with arguments: {arguments_str} for call ID {call_id}") function_output = self._functions.execute(function_name, arguments_str) logger.info(f"Function output for call ID {call_id}: {function_output}") self._client.generate_response_from_function_call(call_id, function_output) except json.JSONDecodeError as e: logger.error(f"Failed to parse arguments for call ID {call_id}: {e}") return + finally: + self._function_processing = False def on_unhandled_event(self, event_type: str, event_data: Dict[str, Any]): logger.warning(f"Unhandled Event: {event_type} - {event_data}") @@ -311,13 
+316,13 @@ def main(): api_key=api_key, model="gpt-4o-realtime-preview-2024-10-01", modalities=["audio", "text"], - instructions="You are a helpful assistant. Respond concisely. You have following functions to take screenshot 1) take_screenshot_and_save 2) take_screenshot_and_analyze", - turn_detection=get_vad_configuration(), + instructions="You are a helpful assistant. Respond concisely. You have access to a variety of tools to analyze, translate and review text and code.", + turn_detection=get_vad_configuration(use_server_vad=False), tools=functions.definitions, tool_choice="auto", temperature=0.8, max_output_tokens=None, - voice="ballad", + voice="sage", ) # Define AudioStreamOptions @@ -344,21 +349,20 @@ def main(): # Initialize AudioCapture with the event handler audio_capture = AudioCapture( event_handler=audio_capture_event_handler, - sample_rate=16000, + sample_rate=24000, channels=1, frames_per_buffer=1024, buffer_duration_sec=1.0, cross_fade_duration_ms=20, vad_parameters={ - "sample_rate": 16000, + "sample_rate": 24000, "chunk_size": 1024, - "window_duration": 1.0, + "window_duration": 1.5, "silence_ratio": 1.5, "min_speech_duration": 0.3, "min_silence_duration": 1.0 }, enable_wave_capture=False, - keyword_model_file="resources/kws.table", ) diff --git a/samples/sample_realtime_ai_with_local_vad.py b/samples/sample_realtime_ai_with_local_vad.py index 87feb10..fe6db18 100644 --- a/samples/sample_realtime_ai_with_local_vad.py +++ b/samples/sample_realtime_ai_with_local_vad.py @@ -16,13 +16,13 @@ # Configure logging logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] # Streaming logs to the console ) # Specific loggers for mentioned packages -logging.getLogger("utils.audio_playback").setLevel(logging.DEBUG) +logging.getLogger("utils.audio_playback").setLevel(logging.ERROR) logging.getLogger("utils.audio_capture").setLevel(logging.ERROR) logging.getLogger("utils.vad").setLevel(logging.ERROR) logging.getLogger("realtime_ai").setLevel(logging.ERROR) @@ -37,8 +37,7 @@ def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHand Initializes the event handler. :param client: Instance of RealtimeClient. - :param event_handler: Instance of MyRealtimeEventHandler - :param event_loop: The asyncio event loop. + :param event_handler: Instance of MyRealtimeEventHandler. """ self._client = client self._event_handler = event_handler @@ -49,40 +48,30 @@ def send_audio_data(self, audio_data: bytes): :param audio_data: Raw audio data in bytes. """ - logger.info("Sending audio data to the client.") + logger.debug("Sending audio data to the client.") self._client.send_audio(audio_data) def on_speech_start(self): """ Handles actions to perform when speech starts. 
- """ - logger.info("Speech has started.") - if self._event_handler.is_audio_playing(): - logger.info(f"User started speaking while audio is playing.") - - logger.info("Clearing input audio buffer.") + logger.info("Local VAD: User speech started") + if (self._client.options.turn_detection is None and + self._event_handler.is_audio_playing()): + logger.info("User started speaking while assistant is responding; interrupting the assistant's response.") self._client.clear_input_audio_buffer() - - logger.info("Cancelling response.") self._client.cancel_response() - - #current_item_id = self._event_handler.get_current_conversation_item_id() - #current_audio_content_index = self._event_handler.get_current_audio_content_id() - #logger.info(f"Truncate the current audio, current item ID: {current_item_id}, current audio content index: {current_audio_content_index}") - #self._client.truncate_response(item_id=current_item_id, content_index=current_audio_content_index, audio_end_ms=1000) - - # Restart the audio player self._event_handler.audio_player.drain_and_restart() def on_speech_end(self): """ Handles actions to perform when speech ends. """ - logger.info("Speech has ended") + logger.info("Local VAD: User speech ended") - logger.info("Requesting the client to generate a response.") - self._client.generate_response() + if self._client.options.turn_detection is None: + logger.debug("Using local VAD; requesting the client to generate a response after speech ends.") + self._client.generate_response() def on_keyword_detected(self, result): pass @@ -98,6 +87,7 @@ def __init__(self, audio_player: AudioPlayer, functions: FunctionTool): self._call_id_to_function_name = {} self._lock = threading.Lock() self._client = None + self._function_processing = False @property def audio_player(self): @@ -112,6 +102,9 @@ def get_current_audio_content_id(self): def is_audio_playing(self): return self._audio_player.is_audio_playing() + def is_function_processing(self): + return self._function_processing + def set_client(self, client: RealtimeAIClient): self._client = client @@ -119,53 +112,58 @@ def on_error(self, event: ErrorEvent): logger.error(f"Error occurred: {event.error.message}") def on_input_audio_buffer_speech_stopped(self, event: InputAudioBufferSpeechStopped): - logger.info(f"Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") + logger.info(f"Server VAD: Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}") def on_input_audio_buffer_committed(self, event: InputAudioBufferCommitted): - logger.info(f"Audio Buffer Committed: {event.item_id}") + logger.debug(f"Audio Buffer Committed: {event.item_id}") def on_conversation_item_created(self, event: ConversationItemCreated): - logger.info(f"New Conversation Item: {event.item}") + logger.debug(f"New Conversation Item: {event.item}") def on_response_created(self, event: ResponseCreated): - logger.info(f"Response Created: {event.response}") + logger.debug(f"Response Created: {event.response}") def on_response_content_part_added(self, event: ResponseContentPartAdded): - logger.info(f"New Part Added: {event.part}") + logger.debug(f"New Part Added: {event.part}") def on_response_audio_delta(self, event: ResponseAudioDelta): - logger.info(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") + logger.debug(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}") self._current_item_id = event.item_id 
self._current_audio_content_index = event.content_index self.handle_audio_delta(event) def on_response_audio_transcript_delta(self, event: ResponseAudioTranscriptDelta): - logger.info(f"Transcript Delta: {event.delta}") + logger.info(f"Assistant transcription delta: {event.delta}") def on_rate_limits_updated(self, event: RateLimitsUpdated): for rate in event.rate_limits: - logger.info(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") + logger.debug(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}") def on_conversation_item_input_audio_transcription_completed(self, event: ConversationItemInputAudioTranscriptionCompleted): - logger.info(f"Transcription completed for item {event.item_id}: {event.transcript}") + logger.info(f"User transcription complete: {event.transcript}") def on_response_audio_done(self, event: ResponseAudioDone): - logger.info(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") + logger.debug(f"Audio done for response ID {event.response_id}, item ID {event.item_id}") def on_response_audio_transcript_done(self, event: ResponseAudioTranscriptDone): - logger.info(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") + logger.debug(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}") def on_response_content_part_done(self, event: ResponseContentPartDone): part_type = event.part.get("type") part_text = event.part.get("text", "") - logger.info(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") + logger.debug(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}") def on_response_output_item_done(self, event: ResponseOutputItemDone): item_content = event.item.get("content", []) - logger.info(f"Output item done for response ID {event.response_id} with content: {item_content}") + if item_content: + for item in item_content: + if item.get("type") == "audio": + transcript = item.get("transcript") + if transcript: + logger.info(f"Assistant transcription complete: {transcript}") def on_response_done(self, event: ResponseDone): - logger.info(f"Response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") + logger.debug(f"Assistant's response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'") def on_session_created(self, event: SessionCreated): logger.info(f"Session created: {event.session}") @@ -174,10 +172,14 @@ def on_session_updated(self, event: SessionUpdated): logger.info(f"Session updated: {event.session}") def on_input_audio_buffer_speech_started(self, event: InputAudioBufferSpeechStarted): - logger.info(f"Speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + logger.info(f"Server VAD: User speech started at {event.audio_start_ms}ms for item ID {event.item_id}") + if self._client.options.turn_detection is not None: + self._client.clear_input_audio_buffer() + self._client.cancel_response() + self._audio_player.drain_and_restart() def on_response_output_item_added(self, event: ResponseOutputItemAdded): - logger.info(f"Output item added for response ID {event.response_id} with item: {event.item}") + logger.debug(f"Output item added for response ID {event.response_id} with item: {event.item}") if event.item.get("type") == "function_call": call_id = event.item.get("call_id") function_name = event.item.get("name") @@ -189,11 +191,9 @@ def on_response_output_item_added(self, event: ResponseOutputItemAdded): 
logger.warning("Function call item missing 'call_id' or 'name' fields.") def on_response_function_call_arguments_delta(self, event: ResponseFunctionCallArgumentsDelta): - logger.info(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") + logger.debug(f"Function call arguments delta for call ID {event.call_id}: {event.delta}") def on_response_function_call_arguments_done(self, event: ResponseFunctionCallArgumentsDone): - logger.info(f"Function call arguments done for call ID {event.call_id} with arguments: {event.arguments}") - call_id = event.call_id arguments_str = event.arguments @@ -205,12 +205,16 @@ def on_response_function_call_arguments_done(self, event: ResponseFunctionCallAr return try: + self._function_processing = True + logger.info(f"Executing function '{function_name}' with arguments: {arguments_str} for call ID {call_id}") function_output = self._functions.execute(function_name, arguments_str) logger.info(f"Function output for call ID {call_id}: {function_output}") self._client.generate_response_from_function_call(call_id, function_output) except json.JSONDecodeError as e: logger.error(f"Failed to parse arguments for call ID {call_id}: {e}") return + finally: + self._function_processing = False def on_unhandled_event(self, event_type: str, event_data: Dict[str, Any]): logger.warning(f"Unhandled Event: {event_type} - {event_data}") @@ -269,7 +273,7 @@ def main(): model="gpt-4o-realtime-preview-2024-10-01", modalities=["audio", "text"], instructions="You are a helpful assistant. Respond concisely. If user asks to tell story, tell story very shortly.", - turn_detection=get_vad_configuration(), + turn_detection=get_vad_configuration(use_server_vad=False), tools=functions.definitions, tool_choice="auto", temperature=0.8, @@ -309,7 +313,7 @@ def main(): vad_parameters={ "sample_rate": 24000, "chunk_size": 1024, - "window_duration": 1.0, + "window_duration": 1.5, "silence_ratio": 1.5, "min_speech_duration": 0.3, "min_silence_duration": 1.0 diff --git a/samples/user_functions.py b/samples/user_functions.py index ae652e6..2aa6845 100644 --- a/samples/user_functions.py +++ b/samples/user_functions.py @@ -56,26 +56,41 @@ def send_email(recipient: str, subject: str, body: str) -> str: return message_json -def take_screenshot_and_analyze(user_input: str) -> str: +def _generate_chat_completion(ai_client, model, messages): + print(f"generate_chat_completion, messages: {messages}") + print(f"generate_chat_completion, model: {model}") + + try: + # Generate the chat completion + response = ai_client.chat.completions.create( + model=model, + messages=messages + ) + print(f"generate_chat_completion, response: {response}") + + # Extract the content of the first choice + if response.choices and response.choices[0].message: + message_content = response.choices[0].message.content + else: + message_content = "No response" + + return json.dumps({"result": message_content}) + except Exception as e: + error_message = f"Failed to generate chat completion: {str(e)}" + print(error_message) + return json.dumps({"function_error": error_message}) + + +def _screenshot_to_bytes() -> bytes: """ - Captures a screenshot, sends it to the specified OpenAI model for analysis, - and returns the analysis result. + Captures a screenshot and returns it as binary data. - :param user_input (str): User input request as it was given by user for screenshot analysis and actions. - - :return: The analysis result as a JSON string. - :rtype: str + :return: The screenshot as binary data. 
+ :rtype: bytes """ from PIL import Image import mss - try: - openai_client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) - except Exception as e: - print(f"Error initializing OpenAI client: {e}") - return None - - # Capture a screenshot with mss.mss() as sct: monitor = sct.monitors[0] # 0 is the first monitor; adjust if multiple monitors are used screenshot = sct.grab(monitor) @@ -86,19 +101,53 @@ def take_screenshot_and_analyze(user_input: str) -> str: img.save(img_byte_arr, format='PNG') img_byte_arr.seek(0) img_bytes = img_byte_arr.read() + return img_bytes - # Encode the image in base64 - img_base64 = base64.b64encode(img_bytes).decode('utf-8') - # Prepare the payload for OpenAI API - content = [] - content.append({"role": "system", "content": "Provide an answer that focuses solely on the information requested, avoiding personal references or perspectives. Ensure the response is objective and directly addresses the question or topic"}) - content.append({"type": "text", "text": user_input}) - content = [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}", "detail": "high"}}] - messages = [{"role": "user", "content": content}] +def _analyze_image(img_base64: str, system_input: str, user_input: str, filename: str) -> str: + """ + Analyzes the given image and returns the analysis result. + + :param img_base64 (str): Base64 encoded image data. + :param system_input (str): System input for the analysis. + :param user_input (str): User input for the analysis. + :return: The analysis result. + :rtype: str + """ + + try: + openai_client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) + except Exception as e: + print(f"Error initializing OpenAI client: {e}") + return None + + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": system_input + } + ], + "role": "user", + "content": [ + { + "type": "text", + "text": user_input + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{img_base64}", + "detail": "high" + } + }, + ], + } + ] try: - # Call OpenAI's ChatCompletion API with the image response = openai_client.chat.completions.create( model="gpt-4o", messages=messages, @@ -111,16 +160,93 @@ def take_screenshot_and_analyze(user_input: str) -> str: print(f"User input: {user_input}") print(f"Analysis: {analysis}") + # Create user message using the user input and the analysis result + #user_message = f"User input: {user_input}\nAnalysis: {analysis}" + #o1_messages = [{"role": "user", "content": user_message}] + #o1_response = _generate_chat_completion(openai_client, "o1-mini", o1_messages) + #print(f"O1 response: {o1_response}") + # Show the analysis result in code - with open("analysis.md", "w") as f: + with open(filename, "w") as f: f.write(analysis) - os.system("code analysis.md") + os.system(f"code {filename}") return json.dumps({"analysis": analysis}) - + except Exception as e: - print(f"An error occurred: {e}") - return None + error_message = f"An error occurred: {e}" + print(error_message) + return json.dumps({"function_error": error_message}) + + +def review_highlighted_code() -> str: + """ + Captures a screenshot, sends it to the specified OpenAI model for analysis, + and returns the analysis result. + + :return: The analysis result as a JSON string. 
+ :rtype: str + """ + # Capture a screenshot and convert it to base64 + img_bytes = _screenshot_to_bytes() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + return _analyze_image(img_base64=img_base64, + system_input="You are expert in analyzing images to text. If the image contains highlighted part, focus on that.", + user_input="Review the highlighted code and provide detailed feedback.", + filename="highlighted_code_analysis.md") + + +def translate_highlighted_text(language: str) -> str: + """ + Captures a screenshot, sends it to the specified OpenAI model for analysis, + and returns the analysis result. + + :return: The analysis result as a JSON string. + :rtype: str + """ + # Capture a screenshot and convert it to base64 + img_bytes = _screenshot_to_bytes() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + return _analyze_image(img_base64=img_base64, + system_input=f"You are expert in translating text to different languages.", + user_input=f"Translate the highlighted text to {language}.", + filename="highlighted_text_translation.md") + + +def explain_highlighted_text() -> str: + """ + Captures a screenshot, sends it to the specified OpenAI model for analysis, + and returns the analysis result. + + :return: The analysis result as a JSON string. + :rtype: str + """ + # Capture a screenshot and convert it to base64 + img_bytes = _screenshot_to_bytes() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + return _analyze_image(img_base64=img_base64, + system_input="You are expert in explaining text. If the image contains highlighted text, provide the explanation of that.", + user_input="Explain the highlighted text in detail, in understandable language.", + filename="highlighted_text_explanation.md") + + +def take_screenshot_and_analyze(user_input: str) -> str: + """ + Captures a screenshot, sends it to the specified OpenAI model for analysis, + and returns the analysis result. + + :param user_input (str): User input request as it was given by user for screenshot analysis and actions. + + :return: The analysis result as a JSON string. + :rtype: str + """ + # Capture a screenshot and convert it to base64 + img_bytes = _screenshot_to_bytes() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + return _analyze_image(img_base64=img_base64, + system_input="Analyze the screenshot and provide all details from it. If the image contains e.g. code or highlighted parts, provide the exact analysis of that.", + user_input=user_input, + filename="screenshot_analysis.md") def take_screenshot_and_show() -> str: @@ -151,5 +277,8 @@ def take_screenshot_and_show() -> str: fetch_weather, send_email, take_screenshot_and_analyze, - take_screenshot_and_show + take_screenshot_and_show, + review_highlighted_code, + translate_highlighted_text, + explain_highlighted_text } diff --git a/samples/utils/audio_capture.py b/samples/utils/audio_capture.py index 93a33aa..b97ef68 100644 --- a/samples/utils/audio_capture.py +++ b/samples/utils/audio_capture.py @@ -84,58 +84,57 @@ def __init__( :param buffer_duration_sec: Duration of the internal audio buffer in seconds. :param cross_fade_duration_ms: Duration for cross-fading in milliseconds. :param vad_parameters: Parameters for VoiceActivityDetector. + :param enable_wave_capture: Flag to enable wave file capture. + :param keyword_model_file: Path to the keyword recognition model file. 
""" self.event_handler = event_handler self.sample_rate = sample_rate self.channels = channels self.frames_per_buffer = frames_per_buffer - self.buffer_duration_sec = buffer_duration_sec - self.buffer_size = int(self.buffer_duration_sec * self.sample_rate) - self.audio_buffer = np.zeros(self.buffer_size, dtype=np.int16) - self.buffer_pointer = 0 self.cross_fade_duration_ms = cross_fade_duration_ms - self.cross_fade_samples = int((self.cross_fade_duration_ms / 1000) * self.sample_rate) - self.speech_started = False self.enable_wave_capture = enable_wave_capture + self.vad = None + self.speech_started = False + if self.enable_wave_capture: try: self.wave_file = wave.open("microphone_output.wav", "wb") - self.wave_file.setnchannels(channels) + self.wave_file.setnchannels(self.channels) self.wave_file.setsampwidth(pyaudio.PyAudio().get_sample_size(FORMAT)) - self.wave_file.setframerate(sample_rate) + self.wave_file.setframerate(self.sample_rate) + logger.info("Wave file initialized for capture.") except Exception as e: logger.error(f"Error opening wave file: {e}") + self.enable_wave_capture = False # Initialize VAD - if vad_parameters is None: - vad_parameters = { - "sample_rate": self.sample_rate, - "chunk_size": self.frames_per_buffer, - "window_duration": 1.0, - "silence_ratio": 1.5, - "min_speech_duration": 0.3, - "min_silence_duration": 0.3 - } - try: - self.vad = VoiceActivityDetector(**vad_parameters) - logger.info("VoiceActivityDetector initialized with parameters: " - f"{vad_parameters}") - except Exception as e: - logger.error(f"Failed to initialize VoiceActivityDetector: {e}") - raise + if vad_parameters is not None: + try: + self.vad = VoiceActivityDetector(**vad_parameters) + logger.info(f"VoiceActivityDetector initialized with parameters: {vad_parameters}") + self.buffer_duration_sec = buffer_duration_sec + self.buffer_size = int(self.buffer_duration_sec * self.sample_rate) + self.audio_buffer = np.zeros(self.buffer_size, dtype=np.int16) + self.buffer_pointer = 0 + self.cross_fade_samples = int((self.cross_fade_duration_ms / 1000) * self.sample_rate) + except Exception as e: + logger.error(f"Failed to initialize VoiceActivityDetector: {e}") + self.vad = None - # Initialize keyword recognizer if model file is provided self.keyword_recognizer = None if keyword_model_file: - self.keyword_recognizer = AzureKeywordRecognizer( - model_file=keyword_model_file, - callback=self._on_keyword_detected, - sample_rate=sample_rate, - channels=channels - ) - self.keyword_recognizer.start_recognition() - logger.info("Keyword recognizer initialized.") + try: + self.keyword_recognizer = AzureKeywordRecognizer( + model_file=keyword_model_file, + callback=self._on_keyword_detected, + sample_rate=sample_rate, + channels=channels + ) + self.keyword_recognizer.start_recognition() + logger.info("Keyword recognizer initialized.") + except Exception as e: + logger.error(f"Failed to initialize AzureKeywordRecognizer: {e}") # Initialize PyAudio for input self.p = pyaudio.PyAudio() @@ -167,66 +166,80 @@ def handle_input_audio(self, indata: bytes, frame_count: int, time_info, status) """ if status: logger.warning(f"Input Stream Status: {status}") + + try: + audio_data = np.frombuffer(indata, dtype=np.int16).copy() + except ValueError as e: + logger.error(f"Error converting audio data: {e}") + return (None, pyaudio.paContinue) + + if self.vad is None: + self.event_handler.send_audio_data(indata) + if self.enable_wave_capture: + try: + self.wave_file.writeframes(indata) + except Exception as e: + 
logger.error(f"Error writing to wave file: {e}") + return (None, pyaudio.paContinue) - # Convert bytes to numpy int16 and make sure the array is writable - audio_data = np.frombuffer(indata, dtype=np.int16).copy() - - # Update internal audio buffer - self.buffer_pointer = self._update_buffer(audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size) - current_buffer = self._get_buffer_content(self.audio_buffer, self.buffer_pointer, self.buffer_size).copy() - - # Process VAD to detect speech try: speech_detected, is_speech = self.vad.process_audio_chunk(audio_data) - # Push audio data to keyword recognizer if self.keyword_recognizer: - self.keyword_recognizer.push_audio(audio_data.tobytes()) + self.keyword_recognizer.push_audio(audio_data) except Exception as e: logger.error(f"Error processing VAD: {e}") speech_detected, is_speech = False, False - # Synchronously handle audio if speech_detected or self.speech_started: if is_speech: if not self.speech_started: logger.info("Speech started") + self.buffer_pointer = self._update_buffer(audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size) + current_buffer = self._get_buffer_content(self.audio_buffer, self.buffer_pointer, self.buffer_size).copy() - # Determine fade length for crossfading fade_length = min(self.cross_fade_samples, len(current_buffer), len(audio_data)) - fade_out = np.linspace(1.0, 0.0, fade_length, dtype=np.float32) - fade_in = np.linspace(0.0, 1.0, fade_length, dtype=np.float32) - if fade_length > 0: + fade_out = np.linspace(1.0, 0.0, fade_length, dtype=np.float32) + fade_in = np.linspace(0.0, 1.0, fade_length, dtype=np.float32) + buffer_fade_section = current_buffer[-fade_length:].astype(np.float32) audio_fade_section = audio_data[:fade_length].astype(np.float32) faded_buffer_section = buffer_fade_section * fade_out faded_audio_section = audio_fade_section * fade_in - # Ensure that the slices are writable current_buffer[-fade_length:] = np.round(faded_buffer_section).astype(np.int16) audio_data[:fade_length] = np.round(faded_audio_section).astype(np.int16) - # Combine buffered and current audio - combined_audio = np.concatenate((current_buffer, audio_data)) + combined_audio = np.concatenate((current_buffer, audio_data)) + else: + combined_audio = audio_data logger.info("Sending buffered audio to client via event handler...") self.event_handler.on_speech_start() self.event_handler.send_audio_data(combined_audio.tobytes()) if self.enable_wave_capture: - self.wave_file.writeframes(combined_audio.tobytes()) + try: + self.wave_file.writeframes(combined_audio.tobytes()) + except Exception as e: + logger.error(f"Error writing to wave file: {e}") else: logger.info("Sending audio to client via event handler...") self.event_handler.send_audio_data(audio_data.tobytes()) if self.enable_wave_capture: - self.wave_file.writeframes(audio_data.tobytes()) + try: + self.wave_file.writeframes(audio_data.tobytes()) + except Exception as e: + logger.error(f"Error writing to wave file: {e}") self.speech_started = True else: logger.info("Speech ended") self.event_handler.on_speech_end() - #self.vad.reset() # Reset VAD if necessary self.speech_started = False + if self.vad: + self.buffer_pointer = self._update_buffer(audio_data, self.audio_buffer, self.buffer_pointer, self.buffer_size) + return (None, pyaudio.paContinue) def _update_buffer(self, new_audio: np.ndarray, buffer: np.ndarray, pointer: int, buffer_size: int) -> int: @@ -241,18 +254,21 @@ def _update_buffer(self, new_audio: np.ndarray, buffer: np.ndarray, pointer: 
int """ new_length = len(new_audio) if new_length >= buffer_size: - buffer[:] = new_audio[-buffer_size:] # Keep only last BUFFER_SIZE samples + buffer[:] = new_audio[-buffer_size:] pointer = 0 + logger.debug("Buffer overwritten with new audio data.") else: end_space = buffer_size - pointer if new_length <= end_space: buffer[pointer:pointer + new_length] = new_audio pointer += new_length + logger.debug(f"Buffer updated. New pointer position: {pointer}") else: buffer[pointer:] = new_audio[:end_space] remaining = new_length - end_space buffer[:remaining] = new_audio[end_space:] pointer = remaining + logger.debug(f"Buffer wrapped around. New pointer position: {pointer}") return pointer def _get_buffer_content(self, buffer: np.ndarray, pointer: int, buffer_size: int) -> np.ndarray: @@ -265,7 +281,9 @@ def _get_buffer_content(self, buffer: np.ndarray, pointer: int, buffer_size: int :return: Ordered audio data as a NumPy array. """ if pointer == 0: + logger.debug("Buffer content retrieved without wrapping.") return buffer.copy() + logger.debug("Buffer content retrieved with wrapping.") return np.concatenate((buffer[pointer:], buffer[:pointer])) def _on_keyword_detected(self, result): @@ -273,23 +291,37 @@ def _on_keyword_detected(self, result): Internal callback when a keyword is detected. """ logger.info("Keyword detected") - self.keyword_recognizer.stop_recognition() - self.event_handler.on_keyword_detected(result) - self.keyword_recognizer.start_recognition() + if self.keyword_recognizer: + try: + self.keyword_recognizer.stop_recognition() + self.event_handler.on_keyword_detected(result) + self.keyword_recognizer.start_recognition() + logger.debug("Keyword recognizer restarted after detection.") + except Exception as e: + logger.error(f"Error handling keyword detection: {e}") def close(self): """ Closes the audio capture stream and the wave file, releasing all resources. """ try: - self.stream.stop_stream() - self.stream.close() - self.p.terminate() - logger.info("AudioCapture stopped and input stream closed.") + if hasattr(self, 'stream'): + self.stream.stop_stream() + self.stream.close() + logger.info("Audio input stream stopped and closed.") + + if hasattr(self, 'p'): + self.p.terminate() + logger.info("PyAudio terminated.") - if self.enable_wave_capture and self.wave_file: + if self.enable_wave_capture and hasattr(self, 'wave_file'): self.wave_file.close() logger.info("Wave file saved successfully.") + if self.keyword_recognizer: + self.keyword_recognizer.stop_recognition() + logger.info("Keyword recognizer stopped.") + + logger.info("AudioCapture resources have been released.") except Exception as e: logger.error(f"Error closing AudioCapture: {e}") diff --git a/samples/utils/azure_keyword_recognizer.py b/samples/utils/azure_keyword_recognizer.py index d5d6dda..10485c4 100644 --- a/samples/utils/azure_keyword_recognizer.py +++ b/samples/utils/azure_keyword_recognizer.py @@ -1,4 +1,42 @@ import azure.cognitiveservices.speech as speechsdk +from scipy.signal import resample_poly +import numpy as np + + +def convert_sample_rate(audio_data: np.ndarray, orig_sr: int = 24000, target_sr: int = 16000) -> np.ndarray: + """ + Converts the sample rate of the given audio data from orig_sr to target_sr using polyphase filtering. + + Parameters: + - audio_data: np.ndarray + The input audio data as a NumPy array of type int16. + - orig_sr: int + Original sample rate of the audio data. + - target_sr: int + Desired sample rate after conversion. 
+ + Returns: + - np.ndarray + The resampled audio data as a NumPy array of type int16. + """ + from math import gcd + divisor = gcd(orig_sr, target_sr) + up = target_sr // divisor + down = orig_sr // divisor + + # Convert to float for high-precision processing + audio_float = audio_data.astype(np.float32) + + # Perform resampling + resampled_float = resample_poly(audio_float, up, down) + + # Ensure the resampled data is within int16 range + resampled_float = np.clip(resampled_float, -32768, 32767) + + # Convert back to int16 + resampled_int16 = resampled_float.astype(np.int16) + + return resampled_int16 class AzureKeywordRecognizer: @@ -6,7 +44,7 @@ class AzureKeywordRecognizer: A class to recognize specific keywords from PCM audio streams using Azure Cognitive Services. """ - def __init__(self, model_file: str, callback, sample_rate: int = 16000, bits_per_sample: int = 16, channels: int = 1): + def __init__(self, model_file: str, callback, sample_rate: int = 16000, channels: int = 1): """ Initializes the AzureKeywordRecognizer. @@ -15,8 +53,17 @@ def __init__(self, model_file: str, callback, sample_rate: int = 16000, bits_per """ # Create a push stream to which we'll write PCM audio data - audio_stream_format = speechsdk.audio.AudioStreamFormat(samples_per_second=sample_rate, bits_per_sample=bits_per_sample, channels=channels) - self.audio_stream = speechsdk.audio.PushAudioInputStream(stream_format=audio_stream_format) + self.sample_rate = sample_rate + self.channels = channels + + # Validate the sample rate is either 16000 or 24000 + if sample_rate not in [16000, 24000]: + raise ValueError("Invalid sample rate. Supported rates are 16000 and 24000.") + # Validate the number of channels is 1 + if channels != 1: + raise ValueError("Invalid number of channels. Only mono audio is supported.") + + self.audio_stream = speechsdk.audio.PushAudioInputStream() self.audio_config = speechsdk.audio.AudioConfig(stream=self.audio_stream) # Initialize the speech recognizer @@ -56,9 +103,13 @@ def push_audio(self, pcm_data): """ Pushes PCM audio data to the recognizer. - :param pcm_data: Bytes of PCM audio data. + :param pcm_data: Numpy array of PCM audio samples. """ - self.audio_stream.write(pcm_data) + if self.sample_rate == 24000: + converted_audio = convert_sample_rate(pcm_data, orig_sr=24000, target_sr=16000) + self.audio_stream.write(converted_audio.tobytes()) + else: + self.audio_stream.write(pcm_data.tobytes()) def _on_recognized(self, event: speechsdk.SpeechRecognitionEventArgs): """ diff --git a/src/realtime_ai/aio/realtime_ai_client.py b/src/realtime_ai/aio/realtime_ai_client.py index 23a803b..70b8273 100644 --- a/src/realtime_ai/aio/realtime_ai_client.py +++ b/src/realtime_ai/aio/realtime_ai_client.py @@ -62,6 +62,29 @@ async def send_audio(self, audio_data: bytes): logger.info("RealtimeAIClient: Queuing audio data for streaming.") await self.audio_stream_manager.write_audio_buffer(audio_data) + async def send_text(self, text: str): + """Sends text input to the service manager. + """ + event = { + "event_id": self.service_manager._generate_event_id(), + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": text + } + ] + } + } + await self.service_manager.send_event(event) + logger.info("RealtimeAIClient: Sent text input to server.") + # Using server VAD; requesting the client to generate a response after text input. 
+ if self._options.turn_detection: + await self.generate_response() + async def generate_response(self): """Sends a response.create event to generate a response.""" logger.info("RealtimeAIClient: Generating response.") diff --git a/src/realtime_ai/realtime_ai_client.py b/src/realtime_ai/realtime_ai_client.py index a1b98ed..dd988de 100644 --- a/src/realtime_ai/realtime_ai_client.py +++ b/src/realtime_ai/realtime_ai_client.py @@ -71,6 +71,9 @@ def send_text(self, text: str): } self._send_event_to_manager(event) logger.info("RealtimeAIClient: Sent text input to server.") + # Using server VAD; requesting the client to generate a response after text input. + if self._options.turn_detection: + self.generate_response() def generate_response(self): """Sends a response.create event to generate a response."""
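
For reference, a minimal sketch (not part of this patch; all names are taken from the samples above) of the capture event handler contract the updated samples rely on. AudioCapture invokes send_audio_data, on_speech_start, on_speech_end and on_keyword_detected, and the handler branches on client.options.turn_detection to decide between local and server VAD, as the keyword sample does.

class MinimalCaptureHandler:
    """Sketch of the callbacks AudioCapture expects; `client` is assumed to be a
    configured RealtimeAIClient and `audio_player` an AudioPlayer instance."""

    def __init__(self, client, audio_player):
        self._client = client
        self._audio_player = audio_player

    def send_audio_data(self, audio_data: bytes):
        # Forward raw PCM chunks to the realtime client.
        self._client.send_audio(audio_data)

    def on_keyword_detected(self, result):
        # With server VAD configured, send_text() now also requests a response.
        self._client.send_text("Hello")

    def on_speech_start(self):
        # Local VAD only: interrupt the assistant if it is still speaking.
        if self._client.options.turn_detection is None and self._audio_player.is_audio_playing():
            self._client.clear_input_audio_buffer()
            self._client.cancel_response()
            self._audio_player.drain_and_restart()

    def on_speech_end(self):
        # Local VAD only: ask for a response once the user stops talking.
        if self._client.options.turn_detection is None:
            self._client.generate_response()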