diff --git a/examples/foundational/03-still-frame.py b/examples/foundational/03-still-frame.py index cd89a041d..3e371da44 100644 --- a/examples/foundational/03-still-frame.py +++ b/examples/foundational/03-still-frame.py @@ -7,7 +7,6 @@ from dailyai.pipeline.pipeline import Pipeline from dailyai.transports.daily_transport import DailyTransport from dailyai.services.fal_ai_services import FalImageGenService -from dailyai.services.fireworks_ai_services import FireworksImageGenService from runner import configure @@ -31,20 +30,14 @@ async def main(room_url): duration_minutes=1 ) - # imagegen = FalImageGenService( - # params=FalImageGenService.InputParams( - # image_size="square_hd" - # ), - # aiohttp_session=session, - # key_id=os.getenv("FAL_KEY_ID"), - # key_secret=os.getenv("FAL_KEY_SECRET"), - # ) - - imagegen = FireworksImageGenService( + imagegen = FalImageGenService( + params=FalImageGenService.InputParams( + image_size="square_hd" + ), aiohttp_session=session, - api_key=os.getenv("FIREWORKS_API_KEY"), - model="accounts/fireworks/models/stable-diffusion-xl-1024-v1-0", - image_size="1024x1024") + key_id=os.getenv("FAL_KEY_ID"), + key_secret=os.getenv("FAL_KEY_SECRET"), + ) pipeline = Pipeline([imagegen]) diff --git a/src/dailyai/services/fireworks_ai_services.py b/src/dailyai/services/fireworks_ai_services.py index 2a9718441..df02d5605 100644 --- a/src/dailyai/services/fireworks_ai_services.py +++ b/src/dailyai/services/fireworks_ai_services.py @@ -22,43 +22,6 @@ def __init__(self, model="accounts/fireworks/models/firefunction-v1", *args, **k super().__init__(model, *args, **kwargs) -class FireworksImageGenService(ImageGenService): - - def __init__( - self, - *, - image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], - aiohttp_session: aiohttp.ClientSession, - api_key, - model="accounts/fireworks/models/stable-diffusion-xl-1024-v1-0", - ): - super().__init__() - self._model = model - self._image_size = image_size - self._client = AsyncOpenAI(api_key=api_key, - base_url="https://api.fireworks.ai/inference/v1") - self._aiohttp_session = aiohttp_session - - async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]]: - self.logger.info(f"Generating Fireworks image: {prompt}") - - image = await self._client.images.generate( - prompt=prompt, - model=self._model, - n=1, - size=self._image_size - ) - print(f"!!! image is {image}") - image_url = image.data[0].url - if not image_url: - raise Exception("No image provided in response", image) - - # Load the image from the url - async with self._aiohttp_session.get(image_url) as response: - image_stream = io.BytesIO(await response.content.read()) - image = Image.open(image_stream) - return (image_url, image.tobytes(), image.size) - class FireworksVisionService(OpenAIVisionService): def __init__(self, *, api_key, model="accounts/fireworks/models/firellava-13b"):