diff --git a/app/domain/__init__.py b/app/domain/__init__.py index b73080e7..9074528d 100644 --- a/app/domain/__init__.py +++ b/app/domain/__init__.py @@ -1 +1,2 @@ +from domain.image import IrisImage from domain.message import IrisMessage, IrisMessageRole diff --git a/app/domain/image.py b/app/domain/image.py new file mode 100644 index 00000000..2fd0175d --- /dev/null +++ b/app/domain/image.py @@ -0,0 +1,23 @@ +from datetime import datetime + + +class IrisImage: + prompt: str + base64: str + url: str | None + timestamp: datetime + _raw_data: any + + def __init__( + self, + prompt: str, + base64: str, + url: str, + timestamp: datetime, + raw_data: any = None, + ): + self.prompt = prompt + self.base64 = base64 + self.url = url + self.timestamp = timestamp + self._raw_data = raw_data diff --git a/app/domain/message.py b/app/domain/message.py index b1f521cc..d33347ac 100644 --- a/app/domain/message.py +++ b/app/domain/message.py @@ -1,5 +1,7 @@ from enum import Enum +from domain import IrisImage + class IrisMessageRole(Enum): USER = "user" @@ -10,10 +12,14 @@ class IrisMessageRole(Enum): class IrisMessage: role: IrisMessageRole text: str + images: list[IrisImage] | None - def __init__(self, role: IrisMessageRole, text: str): + def __init__( + self, role: IrisMessageRole, text: str, images: list[IrisImage] | None = None + ): self.role = role self.text = text + self.images = images def __str__(self): return f"IrisMessage(role={self.role.value}, text='{self.text}')" diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py index af593d32..9769cf2c 100644 --- a/app/llm/llm_manager.py +++ b/app/llm/llm_manager.py @@ -5,7 +5,7 @@ import yaml from common import Singleton -from llm.wrapper import AbstractLlmWrapper, LlmWrapper +from llm.wrapper import LlmWrapper # Small workaround to get pydantic discriminators working @@ -14,7 +14,7 @@ class LlmList(BaseModel): class LlmManager(metaclass=Singleton): - entries: list[AbstractLlmWrapper] + entries: list[LlmWrapper] def __init__(self): self.entries = [] diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index c4807ec5..7cd2184c 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,4 +1,3 @@ -from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper from llm.wrapper.open_ai_completion_wrapper import ( OpenAICompletionWrapper, AzureCompletionWrapper, @@ -7,6 +6,7 @@ OpenAIChatCompletionWrapper, AzureChatCompletionWrapper, ) +from llm.wrapper.open_ai_dalle_wrapper import OpenAIDalleWrapper from llm.wrapper.open_ai_embedding_wrapper import ( OpenAIEmbeddingWrapper, AzureEmbeddingWrapper, @@ -21,4 +21,5 @@ | OpenAIEmbeddingWrapper | AzureEmbeddingWrapper | OllamaWrapper + | OpenAIDalleWrapper ) diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py index 057b3aca..9386f828 100644 --- a/app/llm/wrapper/abstract_llm_wrapper.py +++ b/app/llm/wrapper/abstract_llm_wrapper.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from pydantic import BaseModel -from domain import IrisMessage +from domain import IrisMessage, IrisImage from llm import CompletionArguments @@ -56,3 +56,18 @@ def __subclasshook__(cls, subclass): def create_embedding(self, text: str) -> list[float]: """Create an embedding from the text""" raise NotImplementedError + + +class AbstractLlmImageGenerationWrapper(AbstractLlmWrapper, metaclass=ABCMeta): + """Abstract class for the llm image generation wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "generate_images") and callable( + subclass.generate_images + ) + + @abstractmethod + def generate_images(self, prompt: str, n: int, **kwargs) -> list[IrisImage]: + """Generate images from the prompt""" + raise NotImplementedError diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 4ea0e9b0..1938bcc8 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -1,8 +1,9 @@ +import base64 from typing import Literal, Any from ollama import Client, Message -from domain import IrisMessage, IrisMessageRole +from domain import IrisMessage, IrisMessageRole, IrisImage from llm import CompletionArguments from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmChatCompletionWrapper, @@ -11,9 +12,20 @@ ) +def convert_to_ollama_images(images: list[IrisImage]) -> list[bytes] | None: + if not images: + return None + return [base64.b64decode(image.base64) for image in images] + + def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: return [ - Message(role=message.role.value, content=message.text) for message in messages + Message( + role=message.role.value, + content=message.text, + images=convert_to_ollama_images(message.images), + ) + for message in messages ] @@ -34,8 +46,12 @@ class OllamaWrapper( def model_post_init(self, __context: Any) -> None: self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?) - def completion(self, prompt: str, arguments: CompletionArguments) -> str: - response = self._client.generate(model=self.model, prompt=prompt) + def completion( + self, prompt: str, arguments: CompletionArguments, images: [IrisImage] = None + ) -> str: + response = self._client.generate( + model=self.model, prompt=prompt, images=convert_to_ollama_images(images) + ) return response["response"] def chat_completion( diff --git a/app/llm/wrapper/open_ai_dalle_wrapper.py b/app/llm/wrapper/open_ai_dalle_wrapper.py new file mode 100644 index 00000000..2ceda456 --- /dev/null +++ b/app/llm/wrapper/open_ai_dalle_wrapper.py @@ -0,0 +1,61 @@ +import base64 +from datetime import datetime +from typing import Literal, Any + +import requests +from openai import OpenAI + +from domain import IrisImage +from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper + + +class OpenAIDalleWrapper(AbstractLlmWrapper): + type: Literal["openai_dalle"] + model: str + _client: OpenAI + + def model_post_init(self, __context: Any) -> None: + self._client = OpenAI(api_key=self.api_key) + + def generate_images( + self, + prompt: str, + n: int = 1, + size: Literal[ + "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792" + ] = "256x256", + quality: Literal["standard", "hd"] = "standard", + **kwargs + ) -> [IrisImage]: + response = self._client.images.generate( + model=self.model, + prompt=prompt, + size=size, + quality=quality, + n=n, + response_format="url", + **kwargs + ) + + images = response.data + iris_images = [] + for image in images: + if image.revised_prompt is None: + image.revised_prompt = prompt + if image.b64_json is None: + image_response = requests.get(image.url) + image.b64_json = base64.b64encode(image_response.content).decode( + "utf-8" + ) + + iris_images.append( + IrisImage( + prompt=image.revised_prompt, + base64=image.b64_json, + url=image.url, + timestamp=datetime.fromtimestamp(response.created), + raw_data=image, + ) + ) + + return iris_images diff --git a/requirements.txt b/requirements.txt index 71c9e37e..17529333 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ black==24.1.1 flake8==7.0.0 pre-commit==3.6.0 pydantic==2.6.1 +requests==2.31.0