From c262b272fa672afe8c592db4afa075e1ac9c3161 Mon Sep 17 00:00:00 2001 From: chadbailey59 Date: Mon, 23 Sep 2024 14:51:17 -0500 Subject: [PATCH] Added RTVIActionFrame (#464) * added RTVIActionFrame * server-sent events * reverted log changes * fixup --- src/pipecat/processors/frameworks/rtvi.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 0450102a7..820ea716c 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -8,12 +8,14 @@ from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, ValidationError +from dataclasses import dataclass from pipecat.frames.frames import ( BotInterruptionFrame, BotStartedSpeakingFrame, BotStoppedSpeakingFrame, CancelFrame, + DataFrame, EndFrame, ErrorFrame, Frame, @@ -119,6 +121,12 @@ class RTVIActionRun(BaseModel): arguments: Optional[List[RTVIActionRunArgument]] = None +@dataclass +class RTVIActionFrame(DataFrame): + rtvi_action_run: RTVIActionRun + message_id: Optional[str] = None + + class RTVIMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" type: str @@ -376,6 +384,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) elif isinstance(frame, TransportMessageFrame): await self._message_queue.put(frame) + elif isinstance(frame, RTVIActionFrame): + await self._handle_action(frame.message_id, frame.rtvi_action_run) # Other frames else: await self.push_frame(frame, direction) @@ -548,7 +558,7 @@ async def _handle_function_call_result(self, data): ) await self.push_frame(frame) - async def _handle_action(self, request_id: str, data: RTVIActionRun): + async def _handle_action(self, request_id: str | None, data: RTVIActionRun): action_id = self._action_id(data.service, data.action) if action_id not in self._registered_actions: await self._send_error_response(request_id, f"Action {action_id} not registered") @@ -559,8 +569,11 @@ async def _handle_action(self, request_id: str, data: RTVIActionRun): for arg in data.arguments: arguments[arg.name] = arg.value result = await action.handler(self, action.service, arguments) - message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result)) - await self._push_transport_message(message) + # Only send a response if request_id is present. Things that don't care about + # action responses (such as webhooks) don't set a request_id + if request_id: + message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result)) + await self._push_transport_message(message) async def _maybe_send_bot_ready(self): if self._pipeline_started and self._client_ready: