Skip to content

Commit

Permalink
Added OpenAI and Fireworks vision
Browse files Browse the repository at this point in the history
  • Loading branch information
chadbailey59 committed Apr 12, 2024
1 parent 7d49391 commit 6c53402
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
8 changes: 6 additions & 2 deletions examples/foundational/12a-fireworks-describe-video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.open_ai_services import OpenAIVisionService
from dailyai.services.fireworks_ai_services import FireworksVisionService
from dailyai.transports.daily_transport import DailyTransport

from runner import configure
Expand Down Expand Up @@ -62,8 +63,11 @@ async def main(room_url: str, token):
vision_aggregator = VisionImageFrameAggregator()

# If you run into weird description, try with use_cpu=True
img_desc = OpenAIVisionService(
api_key=os.getenv("OPENAI_API_KEY")
# img_desc = OpenAIVisionService(
# api_key=os.getenv("OPENAI_API_KEY")
# )
img_desc = FireworksVisionService(
api_key=os.getenv("FIREWORKS_API_KEY")
)

tts = ElevenLabsTTSService(
Expand Down
8 changes: 5 additions & 3 deletions src/dailyai/services/fireworks_ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dailyai.services.ai_services import ImageGenService, VisionService
from dailyai.services.openai_api_llm_service import BaseOpenAILLMService

from dailyai.services.open_ai_services import OpenAIVisionService

try:
from openai import AsyncOpenAI
Expand Down Expand Up @@ -59,5 +59,7 @@ async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]]
image = Image.open(image_stream)
return (image_url, image.tobytes(), image.size)

class FireworksVisionService(VisionService):


class FireworksVisionService(OpenAIVisionService):
def __init__(self, *, api_key, model="accounts/fireworks/models/firellava-13b"):
super().__init__(model=model, api_key=api_key, base_url="https://api.fireworks.ai/inference/v1")
7 changes: 6 additions & 1 deletion src/dailyai/services/open_ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ def __init__(
*,
model="gpt-4-vision-preview",
api_key,
base_url=None,
):
self._model = model
self._client = AsyncOpenAI(api_key=api_key)
if base_url:
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

super().__init__()

async def run_vision_async(self, frame):
Expand Down

0 comments on commit 6c53402

Please sign in to comment.