-
Notifications
You must be signed in to change notification settings - Fork 435
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2ecec1c
commit 3d95997
Showing
1 changed file
with
117 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# | ||
# Copyright (c) 2024, Daily | ||
# | ||
# SPDX-License-Identifier: BSD 2-Clause License | ||
# | ||
|
||
|
||
import asyncio | ||
import json | ||
import io | ||
import wave | ||
|
||
from typing import Awaitable, Callable | ||
from pydantic.main import BaseModel | ||
|
||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, Frame, StartFrame, StartInterruptionFrame | ||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor | ||
from pipecat.serializers.base_serializer import FrameSerializer | ||
from pipecat.transports.base_input import BaseInputTransport | ||
from pipecat.transports.base_output import BaseOutputTransport | ||
from pipecat.transports.base_transport import BaseTransport, TransportParams | ||
|
||
from loguru import logger | ||
|
||
try: | ||
from fastapi import Request, Response | ||
from starlette.background import BackgroundTask | ||
from sse_starlette.sse import EventSourceResponse | ||
except ModuleNotFoundError as e: | ||
logger.error(f"Exception: {e}") | ||
logger.error( | ||
"In order to use FastAPI HTTP SSE, you need to `pip install pipecat-ai[http]`.") | ||
raise Exception(f"Missing module: {e}") | ||
|
||
|
||
class FastAPIHTTPParams(TransportParams): | ||
serializer: FrameSerializer | ||
|
||
|
||
class FastAPIHTTPInputTransport(BaseInputTransport): | ||
|
||
def __init__( | ||
self, | ||
params: FastAPIHTTPParams, | ||
**kwargs): | ||
super().__init__(params, **kwargs) | ||
|
||
self._params = params | ||
self._request = None | ||
|
||
# todo: this should probably expect a list of frames, not just one frame | ||
async def handle_request(self, request: Request): | ||
self._request = request | ||
frames_list = await request.json() | ||
logger.debug(f"Received frames: {frames_list}") | ||
for frame in frames_list: | ||
logger.debug(f"Received frame: {frame}") | ||
frame = self._params.serializer.deserialize(frame) | ||
if frame and isinstance(frame, AudioRawFrame): | ||
await self.push_audio_frame(frame) | ||
else: | ||
await self.push_frame(frame) | ||
|
||
|
||
class FastAPIHTTPOutputTransport(BaseOutputTransport): | ||
|
||
def __init__(self, params: FastAPIHTTPParams, **kwargs): | ||
super().__init__(params, **kwargs) | ||
|
||
self._params = params | ||
self._event_queue = asyncio.Queue() | ||
|
||
async def process_frame(self, frame: Frame, direction: FrameDirection): | ||
await super().process_frame(frame, direction) | ||
await self._write_frame(frame) | ||
|
||
async def write_raw_audio_frames(self, frames: bytes): | ||
pass | ||
|
||
async def _write_frame(self, frame: Frame): | ||
payload = self._params.serializer.serialize(frame) | ||
await self._event_queue.put(payload) | ||
|
||
async def event_generator(self): | ||
while True: | ||
event = await self._event_queue.get() | ||
logger.debug(f"Sending event {event}") | ||
yield event | ||
|
||
|
||
class FastAPIHTTPTransport(BaseTransport): | ||
|
||
def __init__( | ||
self, | ||
params: FastAPIHTTPParams, | ||
input_name: str | None = None, | ||
output_name: str | None = None, | ||
loop: asyncio.AbstractEventLoop | None = None): | ||
super().__init__(input_name=input_name, output_name=output_name, loop=loop) | ||
self._params = params | ||
self._request = None | ||
|
||
self._input = FastAPIHTTPInputTransport( | ||
self._params, name=self._input_name) | ||
self._output = FastAPIHTTPOutputTransport( | ||
self._params, name=self._output_name) | ||
|
||
def input(self) -> FrameProcessor: | ||
return self._input | ||
|
||
def output(self) -> FrameProcessor: | ||
return self._output | ||
|
||
async def handle_request(self, request: Request): | ||
self._request = request | ||
await self._input.handle_request(request) | ||
return EventSourceResponse(self._output.event_generator()) |