Skip to content

Commit

Permalink
added params and model attribute to fal service
Browse files Browse the repository at this point in the history
  • Loading branch information
Jon Taylor committed Apr 10, 2024
1 parent 4bd29b0 commit 7b44a79
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/dailyai/services/fal_ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import io
import os
from PIL import Image
from pydantic import BaseModel
from typing import Optional, Union, Dict

from dailyai.services.ai_services import ImageGenService

Expand All @@ -16,30 +18,44 @@


class FalImageGenService(ImageGenService):
class InputParams(BaseModel):
seed: Optional[int] = None
num_inference_steps: int = 4
num_images: int = 1
image_size: Union[str, Dict[str, int]] = "square_hd"
expand_prompt: bool = False
enable_safety_checker: bool = True
format: str = "png"

def __init__(
self,
*,
image_size,
aiohttp_session: aiohttp.ClientSession,
params: InputParams,
model="fal-ai/fast-sdxl",
key_id=None,
key_secret=None
):
super().__init__(image_size)
super().__init__()
self._model = model
self._params = params
self._aiohttp_session = aiohttp_session
if key_id:
os.environ["FAL_KEY_ID"] = key_id
if key_secret:
os.environ["FAL_KEY_SECRET"] = key_secret

async def run_image_gen(self, sentence) -> tuple[str, bytes, tuple[int, int]]:
def get_image_url(sentence, size):
handler = fal.apps.submit(
"110602490-fast-sdxl",
# "fal-ai/fast-sdxl",
arguments={"prompt": sentence},
async def run_image_gen(self, prompt) -> tuple[str, bytes]:
def get_image_url(prompt):
handler = fal.apps.submit( # type: ignore
self._model,
arguments={
"prompt": prompt,
**self._params.dict(),
},
)
for event in handler.iter_events():
if isinstance(event, fal.apps.InProgress):
if isinstance(event, fal.apps.InProgress): # type: ignore
pass

result = handler.get()
Expand All @@ -50,9 +66,10 @@ def get_image_url(sentence, size):

return image_url

image_url = await asyncio.to_thread(get_image_url, sentence, self.image_size)
image_url = await asyncio.to_thread(get_image_url, prompt)

# Load the image from the url
async with self._aiohttp_session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
return (image_url, image.tobytes(), image.size)
return (image_url, image.tobytes())

0 comments on commit 7b44a79

Please sign in to comment.