Skip to content

Commit

Permalink
Added RTVIActionFrame (#464)
Browse files Browse the repository at this point in the history
* added RTVIActionFrame

* server-sent events

* reverted log changes

* fixup
  • Loading branch information
chadbailey59 authored Sep 23, 2024
1 parent 9ef9c1c commit c262b27
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/pipecat/processors/frameworks/rtvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
from dataclasses import dataclass

from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
DataFrame,
EndFrame,
ErrorFrame,
Frame,
Expand Down Expand Up @@ -119,6 +121,12 @@ class RTVIActionRun(BaseModel):
arguments: Optional[List[RTVIActionRunArgument]] = None


@dataclass
class RTVIActionFrame(DataFrame):
rtvi_action_run: RTVIActionRun
message_id: Optional[str] = None


class RTVIMessage(BaseModel):
label: Literal["rtvi-ai"] = "rtvi-ai"
type: str
Expand Down Expand Up @@ -376,6 +384,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.push_frame(frame, direction)
elif isinstance(frame, TransportMessageFrame):
await self._message_queue.put(frame)
elif isinstance(frame, RTVIActionFrame):
await self._handle_action(frame.message_id, frame.rtvi_action_run)
# Other frames
else:
await self.push_frame(frame, direction)
Expand Down Expand Up @@ -548,7 +558,7 @@ async def _handle_function_call_result(self, data):
)
await self.push_frame(frame)

async def _handle_action(self, request_id: str, data: RTVIActionRun):
async def _handle_action(self, request_id: str | None, data: RTVIActionRun):
action_id = self._action_id(data.service, data.action)
if action_id not in self._registered_actions:
await self._send_error_response(request_id, f"Action {action_id} not registered")
Expand All @@ -559,8 +569,11 @@ async def _handle_action(self, request_id: str, data: RTVIActionRun):
for arg in data.arguments:
arguments[arg.name] = arg.value
result = await action.handler(self, action.service, arguments)
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
await self._push_transport_message(message)
# Only send a response if request_id is present. Things that don't care about
# action responses (such as webhooks) don't set a request_id
if request_id:
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
await self._push_transport_message(message)

async def _maybe_send_bot_ready(self):
if self._pipeline_started and self._client_ready:
Expand Down

0 comments on commit c262b27

Please sign in to comment.