diff --git a/src/dailyai/services/fal_ai_services.py b/src/dailyai/services/fal_ai_services.py index 4a7016011..8e6833ee8 100644 --- a/src/dailyai/services/fal_ai_services.py +++ b/src/dailyai/services/fal_ai_services.py @@ -3,6 +3,8 @@ import io import os from PIL import Image +from pydantic import BaseModel +from typing import Optional, Union, Dict from dailyai.services.ai_services import ImageGenService @@ -16,30 +18,44 @@ class FalImageGenService(ImageGenService): + class InputParams(BaseModel): + seed: Optional[int] = None + num_inference_steps: int = 4 + num_images: int = 1 + image_size: Union[str, Dict[str, int]] = "square_hd" + expand_prompt: bool = False + enable_safety_checker: bool = True + format: str = "png" + def __init__( self, *, - image_size, aiohttp_session: aiohttp.ClientSession, + params: InputParams, + model="fal-ai/fast-sdxl", key_id=None, key_secret=None ): - super().__init__(image_size) + super().__init__() + self._model = model + self._params = params self._aiohttp_session = aiohttp_session if key_id: os.environ["FAL_KEY_ID"] = key_id if key_secret: os.environ["FAL_KEY_SECRET"] = key_secret - async def run_image_gen(self, sentence) -> tuple[str, bytes, tuple[int, int]]: - def get_image_url(sentence, size): - handler = fal.apps.submit( - "110602490-fast-sdxl", - # "fal-ai/fast-sdxl", - arguments={"prompt": sentence}, + async def run_image_gen(self, prompt) -> tuple[str, bytes]: + def get_image_url(prompt): + handler = fal.apps.submit( # type: ignore + self._model, + arguments={ + "prompt": prompt, + **self._params.dict(), + }, ) for event in handler.iter_events(): - if isinstance(event, fal.apps.InProgress): + if isinstance(event, fal.apps.InProgress): # type: ignore pass result = handler.get() @@ -50,9 +66,10 @@ def get_image_url(sentence, size): return image_url - image_url = await asyncio.to_thread(get_image_url, sentence, self.image_size) + image_url = await asyncio.to_thread(get_image_url, prompt) + # Load the image from the url async with self._aiohttp_session.get(image_url) as response: image_stream = io.BytesIO(await response.content.read()) image = Image.open(image_stream) - return (image_url, image.tobytes(), image.size) + return (image_url, image.tobytes())