From af139786586b7e4095f44121fb4ef339e1260c91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 23 Sep 2024 23:38:01 -0700 Subject: [PATCH] transports(fastapi_http): make FastAPIHTTPTransport more modular --- pyproject.toml | 1 + .../transports/network/fastapi_http.py | 113 +++++++++--------- 2 files changed, 57 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aebccda2f..852a0d2ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ daily = [ "daily-python~=0.10.1" ] deepgram = [ "deepgram-sdk~=3.5.0" ] elevenlabs = [ "websockets~=12.0" ] examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ] +http = [ "sse-starlette~=2.1.3" ] fal = [ "fal-client~=0.4.1" ] gladia = [ "websockets~=12.0" ] google = [ "google-generativeai~=0.7.2" ] diff --git a/src/pipecat/transports/network/fastapi_http.py b/src/pipecat/transports/network/fastapi_http.py index 9c58fa0d2..b9284e25d 100644 --- a/src/pipecat/transports/network/fastapi_http.py +++ b/src/pipecat/transports/network/fastapi_http.py @@ -4,17 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # - import asyncio -import json -import io -import wave -from typing import Awaitable, Callable -from pydantic.main import BaseModel +from abc import abstractmethod +from typing import AsyncGenerator, Callable -from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, Frame, StartFrame, StartInterruptionFrame -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import AudioRawFrame, EndFrame, Frame, InputAudioRawFrame +from pipecat.processors.frame_processor import FrameDirection from pipecat.serializers.base_serializer import FrameSerializer from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -24,12 +20,9 @@ try: from fastapi import Request, Response - from starlette.background import BackgroundTask - from sse_starlette.sse import EventSourceResponse except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use FastAPI HTTP SSE, you need to `pip install pipecat-ai[http]`.") + logger.error("In order to use FastAPI HTTP SSE, you need to `pip install pipecat-ai[http]`.") raise Exception(f"Missing module: {e}") @@ -38,80 +31,86 @@ class FastAPIHTTPParams(TransportParams): class FastAPIHTTPInputTransport(BaseInputTransport): - def __init__( - self, - params: FastAPIHTTPParams, - **kwargs): + self, + generator: Callable[[Request], AsyncGenerator[str | bytes, None]], + params: FastAPIHTTPParams, + **kwargs, + ): super().__init__(params, **kwargs) - + self._generator = generator self._params = params - self._request = None - # todo: this should probably expect a list of frames, not just one frame async def handle_request(self, request: Request): - self._request = request - frames_list = await request.json() - logger.debug(f"Received frames: {frames_list}") - for frame in frames_list: - logger.debug(f"Received frame: {frame}") - frame = self._params.serializer.deserialize(frame) - if frame and isinstance(frame, AudioRawFrame): - await self.push_audio_frame(frame) + async for data in self._generator(request): + frame = self._params.serializer.deserialize(data) + if not frame: + continue + + if isinstance(frame, AudioRawFrame): + await self.push_audio_frame( + InputAudioRawFrame( + audio=frame.audio, + sample_rate=frame.sample_rate, + num_channels=frame.num_channels, + ) + ) else: await self.push_frame(frame) class FastAPIHTTPOutputTransport(BaseOutputTransport): - def __init__(self, params: FastAPIHTTPParams, **kwargs): super().__init__(params, **kwargs) self._params = params - self._event_queue = asyncio.Queue() + self._response_queue = asyncio.Queue() + + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._response_queue.put(None) async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - await self._write_frame(frame) - - async def write_raw_audio_frames(self, frames: bytes): - pass - - async def _write_frame(self, frame: Frame): payload = self._params.serializer.serialize(frame) - await self._event_queue.put(payload) + if payload: + await self._response_queue.put(payload) - async def event_generator(self): - while True: - event = await self._event_queue.get() - logger.debug(f"Sending event {event}") - yield event + async def output_generator(self) -> AsyncGenerator[str | bytes, None]: + running = True + while running: + data = await self._response_queue.get() + running = data is not None + if data: + yield data class FastAPIHTTPTransport(BaseTransport): - def __init__( - self, - params: FastAPIHTTPParams, - input_name: str | None = None, - output_name: str | None = None, - loop: asyncio.AbstractEventLoop | None = None): + self, + params: FastAPIHTTPParams, + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + ): super().__init__(input_name=input_name, output_name=output_name, loop=loop) self._params = params - self._request = None self._input = FastAPIHTTPInputTransport( - self._params, name=self._input_name) - self._output = FastAPIHTTPOutputTransport( - self._params, name=self._output_name) + generator=self.input_generator, params=self._params, name=self._input_name + ) + self._output = FastAPIHTTPOutputTransport(params=self._params, name=self._output_name) - def input(self) -> FrameProcessor: + def input(self) -> FastAPIHTTPInputTransport: return self._input - def output(self) -> FrameProcessor: + def output(self) -> FastAPIHTTPOutputTransport: return self._output - async def handle_request(self, request: Request): - self._request = request - await self._input.handle_request(request) - return EventSourceResponse(self._output.event_generator()) + @abstractmethod + def input_generator(self, request: Request) -> AsyncGenerator[str | bytes, None]: + pass + + @abstractmethod + async def handle_request(self, request: Request) -> Response: + pass