From f315f69c3f25644b9f8475d83148d23a9df1e58f Mon Sep 17 00:00:00 2001 From: Chad Bailey Date: Fri, 12 Apr 2024 03:22:05 +0000 Subject: [PATCH] Added OpenAI and Fireworks vision --- examples/foundational/12a-fireworks-describe-video.py | 8 ++++++-- src/dailyai/services/fireworks_ai_services.py | 8 +++++--- src/dailyai/services/open_ai_services.py | 7 ++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/foundational/12a-fireworks-describe-video.py b/examples/foundational/12a-fireworks-describe-video.py index 14e3b6d42..4eee71384 100644 --- a/examples/foundational/12a-fireworks-describe-video.py +++ b/examples/foundational/12a-fireworks-describe-video.py @@ -11,6 +11,7 @@ from dailyai.pipeline.pipeline import Pipeline from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService from dailyai.services.open_ai_services import OpenAIVisionService +from dailyai.services.fireworks_ai_services import FireworksVisionService from dailyai.transports.daily_transport import DailyTransport from runner import configure @@ -62,8 +63,11 @@ async def main(room_url: str, token): vision_aggregator = VisionImageFrameAggregator() # If you run into weird description, try with use_cpu=True - img_desc = OpenAIVisionService( - api_key=os.getenv("OPENAI_API_KEY") + # img_desc = OpenAIVisionService( + # api_key=os.getenv("OPENAI_API_KEY") + # ) + img_desc = FireworksVisionService( + api_key=os.getenv("FIREWORKS_API_KEY") ) tts = ElevenLabsTTSService( diff --git a/src/dailyai/services/fireworks_ai_services.py b/src/dailyai/services/fireworks_ai_services.py index a2dcc4462..2a9718441 100644 --- a/src/dailyai/services/fireworks_ai_services.py +++ b/src/dailyai/services/fireworks_ai_services.py @@ -5,7 +5,7 @@ from dailyai.services.ai_services import ImageGenService, VisionService from dailyai.services.openai_api_llm_service import BaseOpenAILLMService - +from dailyai.services.open_ai_services import OpenAIVisionService try: from openai import AsyncOpenAI @@ -59,5 +59,7 @@ async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]] image = Image.open(image_stream) return (image_url, image.tobytes(), image.size) -class FireworksVisionService(VisionService): - \ No newline at end of file + +class FireworksVisionService(OpenAIVisionService): + def __init__(self, *, api_key, model="accounts/fireworks/models/firellava-13b"): + super().__init__(model=model, api_key=api_key, base_url="https://api.fireworks.ai/inference/v1") diff --git a/src/dailyai/services/open_ai_services.py b/src/dailyai/services/open_ai_services.py index 4683770b1..bb6103606 100644 --- a/src/dailyai/services/open_ai_services.py +++ b/src/dailyai/services/open_ai_services.py @@ -72,9 +72,14 @@ def __init__( *, model="gpt-4-vision-preview", api_key, + base_url=None, ): self._model = model - self._client = AsyncOpenAI(api_key=api_key) + if base_url: + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + else: + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + super().__init__() async def run_vision_async(self, frame):