Merge pull request #8 from jhakulin/jhakulin/sample-updates
Update samples, logs, keyword detector, audio capture
jhakulin authored Nov 8, 2024
2 parents 5f9e7a2 + 0831a94 commit 188d613
Showing 8 changed files with 502 additions and 256 deletions.
118 changes: 59 additions & 59 deletions samples/async/sample_realtime_ai_with_local_vad.py
@@ -21,13 +21,13 @@

# Configure logging
logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()] # Streaming logs to the console
)

# Specific loggers for mentioned packages
logging.getLogger("utils.audio_playback").setLevel(logging.DEBUG)
logging.getLogger("utils.audio_playback").setLevel(logging.ERROR)
logging.getLogger("utils.audio_capture").setLevel(logging.ERROR)
logging.getLogger("utils.vad").setLevel(logging.ERROR)
logging.getLogger("realtime_ai").setLevel(logging.ERROR)
@@ -46,54 +46,43 @@ def __init__(self, client: RealtimeAIClient, event_handler: "MyRealtimeEventHand
:param event_handler: Instance of MyRealtimeEventHandler
:param event_loop: The asyncio event loop.
"""
self.client = client
self.event_handler = event_handler
self.event_loop = event_loop
self.cancelled = False
self._client = client
self._event_handler = event_handler
self._event_loop = event_loop

def send_audio_data(self, audio_data: bytes):
"""
Sends audio data to the RealtimeClient.
:param audio_data: Raw audio data in bytes.
"""
logger.info("Sending audio data to the client.")
asyncio.run_coroutine_threadsafe(self.client.send_audio(audio_data), self.event_loop)
logger.debug("Sending audio data to the client.")
asyncio.run_coroutine_threadsafe(self._client.send_audio(audio_data), self._event_loop)

def on_speech_start(self):
"""
Handles actions to perform when speech starts.
"""
logger.info("Speech has started.")
if self.event_handler.is_audio_playing():
logger.info(f"User started speaking while audio is playing.")

logger.info("Clearing input audio buffer.")
asyncio.run_coroutine_threadsafe(self.client.clear_input_audio_buffer(), self.event_loop)

logger.info("Cancelling response.")
asyncio.run_coroutine_threadsafe(self.client.cancel_response(), self.event_loop)
self.cancelled = True

#current_item_id = self.event_handler.get_current_conversation_item_id()
#current_audio_content_index = self.event_handler.get_current_audio_content_id()
#logger.info(f"Truncate the current audio, current item ID: {current_item_id}, current audio content index: {current_audio_content_index}")
#asyncio.run_coroutine_threadsafe(self.client.truncate_response(item_id=current_item_id, content_index=current_audio_content_index, audio_end_ms=1000), self.event_loop)

# Restart the audio player
self.event_handler.audio_player.drain_and_restart()
else:
logger.info("Assistant is not speaking, cancelling response is not required.")
self.cancelled = False
logger.info("Local VAD: User speech started")
if (self._client.options.turn_detection is None and
self._event_handler.is_audio_playing()):
logger.info("User started speaking while assistant is responding; interrupting the assistant's response.")
asyncio.run_coroutine_threadsafe(self._client.clear_input_audio_buffer(), self._event_loop)
asyncio.run_coroutine_threadsafe(self._client.cancel_response(), self._event_loop)
self._event_handler.audio_player.drain_and_restart()

def on_speech_end(self):
"""
Handles actions to perform when speech ends.
"""
logger.info("Speech has ended")
logger.info("Requesting the client to generate a response.")
asyncio.run_coroutine_threadsafe(self.client.generate_response(), self.event_loop)
logger.info("Local VAD: User speech ended")

if self._client.options.turn_detection is None:
logger.debug("Using local VAD; requesting the client to generate a response after speech ends.")
asyncio.run_coroutine_threadsafe(self._client.generate_response(), self._event_loop)

def on_keyword_detected(self, result):
pass


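MyAudioCaptureEventHandler runs on the audio-capture thread, which is why every call into the client goes through `asyncio.run_coroutine_threadsafe` with the loop captured in `main()` rather than a direct `await`. A self-contained sketch of that thread-to-loop handoff, with placeholder names (not the repo's API):

```python
import asyncio
import threading

async def send_audio(chunk: bytes) -> None:
    # Stand-in for RealtimeAIClient.send_audio(); simulates a bit of async I/O.
    await asyncio.sleep(0)
    print(f"sent {len(chunk)} bytes")

def capture_thread(loop: asyncio.AbstractEventLoop) -> None:
    # Runs outside the event loop, like an audio-capture callback would.
    for _ in range(3):
        future = asyncio.run_coroutine_threadsafe(send_audio(b"\x00" * 1024), loop)
        future.result()  # optionally block this thread until the coroutine has run

async def main() -> None:
    loop = asyncio.get_running_loop()          # same pattern as the sample's main()
    worker = threading.Thread(target=capture_thread, args=(loop,))
    worker.start()
    await asyncio.to_thread(worker.join)       # don't block the loop while waiting

asyncio.run(main())
```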
class MyRealtimeEventHandler(RealtimeAIEventHandler):
@@ -106,6 +95,7 @@ def __init__(self, audio_player: AudioPlayer, functions: FunctionTool):
self._current_audio_content_index = None
self._call_id_to_function_name = {}
self._functions = functions
self._function_processing = False

@property
def audio_player(self):
@@ -120,60 +110,68 @@ def get_current_audio_content_id(self):
def is_audio_playing(self):
return self._audio_player.is_audio_playing()

def is_function_processing(self):
return self._function_processing

def set_client(self, client: RealtimeAIClient):
self._client = client

async def on_error(self, event: ErrorEvent) -> None:
logger.error(f"Error occurred: {event.error.message}")

async def on_input_audio_buffer_speech_stopped(self, event: InputAudioBufferSpeechStopped) -> None:
logger.info(f"Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}")
logger.info(f"Server VAD: Speech stopped at {event.audio_end_ms}ms, Item ID: {event.item_id}")

async def on_input_audio_buffer_committed(self, event: InputAudioBufferCommitted) -> None:
logger.info(f"Audio Buffer Committed: {event.item_id}")
logger.debug(f"Audio Buffer Committed: {event.item_id}")

async def on_conversation_item_created(self, event: ConversationItemCreated) -> None:
logger.info(f"New Conversation Item: {event.item}")
logger.debug(f"New Conversation Item: {event.item}")

async def on_response_created(self, event: ResponseCreated) -> None:
logger.info(f"Response Created: {event.response}")
logger.debug(f"Response Created: {event.response}")

async def on_response_content_part_added(self, event: ResponseContentPartAdded) -> None:
logger.info(f"New Part Added: {event.part}")
logger.debug(f"New Part Added: {event.part}")

async def on_response_audio_delta(self, event: ResponseAudioDelta) -> None:
logger.info(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}")
logger.debug(f"Received audio delta for Response ID {event.response_id}, Item ID {event.item_id}, Content Index {event.content_index}")
self._current_item_id = event.item_id
self._current_audio_content_index = event.content_index
self.handle_audio_delta(event)

async def on_response_audio_transcript_delta(self, event: ResponseAudioTranscriptDelta) -> None:
logger.info(f"Transcript Delta: {event.delta}")
logger.info(f"Assistant transcription delta: {event.delta}")

async def on_rate_limits_updated(self, event: RateLimitsUpdated) -> None:
for rate in event.rate_limits:
logger.info(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}")
logger.debug(f"Rate Limit: {rate.name}, Remaining: {rate.remaining}")

async def on_conversation_item_input_audio_transcription_completed(self, event: ConversationItemInputAudioTranscriptionCompleted) -> None:
logger.info(f"Transcription completed for item {event.item_id}: {event.transcript}")
logger.info(f"User transcription complete: {event.transcript}")

async def on_response_audio_done(self, event: ResponseAudioDone) -> None:
logger.info(f"Audio done for response ID {event.response_id}, item ID {event.item_id}")
logger.debug(f"Audio done for response ID {event.response_id}, item ID {event.item_id}")

async def on_response_audio_transcript_done(self, event: ResponseAudioTranscriptDone) -> None:
logger.info(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}")
logger.debug(f"Audio transcript done: '{event.transcript}' for response ID {event.response_id}")

async def on_response_content_part_done(self, event: ResponseContentPartDone) -> None:
part_type = event.part.get("type")
part_text = event.part.get("text", "")
logger.info(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}")
logger.debug(f"Content part done: '{part_text}' of type '{part_type}' for response ID {event.response_id}")

async def on_response_output_item_done(self, event: ResponseOutputItemDone) -> None:
item_content = event.item.get("content", [])
logger.info(f"Output item done for response ID {event.response_id} with content: {item_content}")
if item_content:
for item in item_content:
if item.get("type") == "audio":
transcript = item.get("transcript")
if transcript:
logger.info(f"Assistant transcription complete: {transcript}")

async def on_response_done(self, event: ResponseDone) -> None:
logger.info(f"Response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'")
logger.debug(f"Assistant's response completed with status '{event.response.get('status')}' and ID '{event.response.get('id')}'")

async def on_session_created(self, event: SessionCreated) -> None:
logger.info(f"Session created: {event.session}")
@@ -182,15 +180,17 @@ async def on_session_updated(self, event: SessionUpdated) -> None:
logger.info(f"Session updated: {event.session}")

async def on_input_audio_buffer_speech_started(self, event: InputAudioBufferSpeechStarted) -> None:
logger.info(f"Speech started at {event.audio_start_ms}ms for item ID {event.item_id}")
logger.info(f"Server VAD: User speech started at {event.audio_start_ms}ms for item ID {event.item_id}")
if self._client.options.turn_detection is not None:
await self._client.clear_input_audio_buffer()
await self._client.cancel_response()
await asyncio.threads.to_thread(self._audio_player.drain_and_restart)

async def on_response_output_item_added(self, event: ResponseOutputItemAdded) -> None:
logger.info(f"Output item added for response ID {event.response_id} with item: {event.item}")

logger.debug(f"Output item added for response ID {event.response_id} with item: {event.item}")
if event.item.get("type") == "function_call":
call_id = event.item.get("call_id")
function_name = event.item.get("name")

if call_id and function_name:
# Properly acquire the lock with 'await' and spread the usage over two lines
await self._lock.acquire() # Wait until the lock is available, then acquire it
@@ -204,11 +204,9 @@ async def on_response_output_item_added(self, event: ResponseOutputItemAdded) ->
logger.warning("Function call item missing 'call_id' or 'name' fields.")

async def on_response_function_call_arguments_delta(self, event: ResponseFunctionCallArgumentsDelta) -> None:
logger.info(f"Function call arguments delta for call ID {event.call_id}: {event.delta}")
logger.debug(f"Function call arguments delta for call ID {event.call_id}: {event.delta}")

async def on_response_function_call_arguments_done(self, event: ResponseFunctionCallArgumentsDone) -> None:
logger.info(f"Function call arguments done for call ID {event.call_id} with arguments: {event.arguments}")

call_id = event.call_id
arguments_str = event.arguments

@@ -225,14 +223,16 @@ async def on_response_function_call_arguments_done(self, event: ResponseFunction
return

try:
self._function_processing = True
logger.info(f"Executing function '{function_name}' with arguments: {arguments_str} for call ID {call_id}")
function_output = await asyncio.threads.to_thread(self._functions.execute, function_name, arguments_str)
logger.info(f"Function output for call ID {call_id}: {function_output}")

# Assuming generate_response_from_function_call is an async method
await self._client.generate_response_from_function_call(call_id, function_output)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse arguments for call ID {call_id}: {e}")
return
finally:
self._function_processing = False

def on_unhandled_event(self, event_type: str, event_data: Dict[str, Any]):
logger.warning(f"Unhandled Event: {event_type} - {event_data}")
@@ -291,7 +291,7 @@ async def main():
model="gpt-4o-realtime-preview-2024-10-01",
modalities=["audio", "text"],
instructions="You are a helpful assistant. Respond concisely. If the user asks for a story, keep it very short.",
turn_detection=get_vad_configuration(),
turn_detection=get_vad_configuration(use_server_vad=False),
tools=functions.definitions,
tool_choice="auto",
temperature=0.8,
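With `use_server_vad=False` the session's `turn_detection` is expected to be None, which is the condition `on_speech_start`/`on_speech_end` check before cancelling or generating responses locally. The helper's body is not part of this hunk; a plausible sketch of such a switch is shown below, with the server-side values purely illustrative:

```python
from typing import Any, Dict, Optional

def get_vad_configuration(use_server_vad: bool = False) -> Optional[Dict[str, Any]]:
    """Return a turn_detection config, or None to let the local VAD drive turn taking."""
    if not use_server_vad:
        return None  # turn_detection is None -> the local VAD handlers above take over
    return {
        "type": "server_vad",
        "threshold": 0.5,            # illustrative tuning values
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500,
    }
```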
@@ -315,7 +315,7 @@ async def main():
event_handler.set_client(client)
await client.start()

loop = asyncio.get_running_loop() # Get the current event loop
loop = asyncio.get_running_loop()

audio_capture_event_handler = MyAudioCaptureEventHandler(
client=client,
@@ -334,7 +334,7 @@ async def main():
vad_parameters={
"sample_rate": 24000,
"chunk_size": 1024,
"window_duration": 1.0,
"window_duration": 1.5,
"silence_ratio": 1.5,
"min_speech_duration": 0.3,
"min_silence_duration": 1.0
