Skip to content

Commit

Permalink
transports(fastapi_http): make FastAPIHTTPTransport more modular
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed Sep 24, 2024
1 parent 3d95997 commit af13978
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 57 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ daily = [ "daily-python~=0.10.1" ]
deepgram = [ "deepgram-sdk~=3.5.0" ]
elevenlabs = [ "websockets~=12.0" ]
examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ]
http = [ "sse-starlette~=2.1.3" ]
fal = [ "fal-client~=0.4.1" ]
gladia = [ "websockets~=12.0" ]
google = [ "google-generativeai~=0.7.2" ]
Expand Down
113 changes: 56 additions & 57 deletions src/pipecat/transports/network/fastapi_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#


import asyncio
import json
import io
import wave

from typing import Awaitable, Callable
from pydantic.main import BaseModel
from abc import abstractmethod
from typing import AsyncGenerator, Callable

from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, Frame, StartFrame, StartInterruptionFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.frames.frames import AudioRawFrame, EndFrame, Frame, InputAudioRawFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
Expand All @@ -24,12 +20,9 @@

try:
from fastapi import Request, Response
from starlette.background import BackgroundTask
from sse_starlette.sse import EventSourceResponse
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use FastAPI HTTP SSE, you need to `pip install pipecat-ai[http]`.")
logger.error("In order to use FastAPI HTTP SSE, you need to `pip install pipecat-ai[http]`.")
raise Exception(f"Missing module: {e}")


Expand All @@ -38,80 +31,86 @@ class FastAPIHTTPParams(TransportParams):


class FastAPIHTTPInputTransport(BaseInputTransport):

def __init__(
self,
params: FastAPIHTTPParams,
**kwargs):
self,
generator: Callable[[Request], AsyncGenerator[str | bytes, None]],
params: FastAPIHTTPParams,
**kwargs,
):
super().__init__(params, **kwargs)

self._generator = generator
self._params = params
self._request = None

# todo: this should probably expect a list of frames, not just one frame
async def handle_request(self, request: Request):
self._request = request
frames_list = await request.json()
logger.debug(f"Received frames: {frames_list}")
for frame in frames_list:
logger.debug(f"Received frame: {frame}")
frame = self._params.serializer.deserialize(frame)
if frame and isinstance(frame, AudioRawFrame):
await self.push_audio_frame(frame)
async for data in self._generator(request):
frame = self._params.serializer.deserialize(data)
if not frame:
continue

if isinstance(frame, AudioRawFrame):
await self.push_audio_frame(
InputAudioRawFrame(
audio=frame.audio,
sample_rate=frame.sample_rate,
num_channels=frame.num_channels,
)
)
else:
await self.push_frame(frame)


class FastAPIHTTPOutputTransport(BaseOutputTransport):

def __init__(self, params: FastAPIHTTPParams, **kwargs):
super().__init__(params, **kwargs)

self._params = params
self._event_queue = asyncio.Queue()
self._response_queue = asyncio.Queue()

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._response_queue.put(None)

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self._write_frame(frame)

async def write_raw_audio_frames(self, frames: bytes):
pass

async def _write_frame(self, frame: Frame):
payload = self._params.serializer.serialize(frame)
await self._event_queue.put(payload)
if payload:
await self._response_queue.put(payload)

async def event_generator(self):
while True:
event = await self._event_queue.get()
logger.debug(f"Sending event {event}")
yield event
async def output_generator(self) -> AsyncGenerator[str | bytes, None]:
running = True
while running:
data = await self._response_queue.get()
running = data is not None
if data:
yield data


class FastAPIHTTPTransport(BaseTransport):

def __init__(
self,
params: FastAPIHTTPParams,
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None):
self,
params: FastAPIHTTPParams,
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None,
):
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
self._params = params
self._request = None

self._input = FastAPIHTTPInputTransport(
self._params, name=self._input_name)
self._output = FastAPIHTTPOutputTransport(
self._params, name=self._output_name)
generator=self.input_generator, params=self._params, name=self._input_name
)
self._output = FastAPIHTTPOutputTransport(params=self._params, name=self._output_name)

def input(self) -> FrameProcessor:
def input(self) -> FastAPIHTTPInputTransport:
return self._input

def output(self) -> FrameProcessor:
def output(self) -> FastAPIHTTPOutputTransport:
return self._output

async def handle_request(self, request: Request):
self._request = request
await self._input.handle_request(request)
return EventSourceResponse(self._output.event_generator())
@abstractmethod
def input_generator(self, request: Request) -> AsyncGenerator[str | bytes, None]:
pass

@abstractmethod
async def handle_request(self, request: Request) -> Response:
pass

0 comments on commit af13978

Please sign in to comment.