Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

LLM: Add basic support for images #53

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions app/domain/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from domain.image import IrisImage
from domain.message import IrisMessage, IrisMessageRole
23 changes: 23 additions & 0 deletions app/domain/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import datetime
from typing import Any


class IrisImage:
prompt: str
base64: str
url: str | None
timestamp: datetime
_raw_data: any

def __init__(
self,
prompt: str,
base64: str,
url: str,
timestamp: datetime,
raw_data: any = None,
):
self.prompt = prompt
self.base64 = base64
self.url = url
self.timestamp = timestamp
self._raw_data = raw_data
8 changes: 7 additions & 1 deletion app/domain/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from enum import Enum

from domain import IrisImage


class IrisMessageRole(Enum):
USER = "user"
Expand All @@ -10,10 +12,14 @@ class IrisMessageRole(Enum):
class IrisMessage:
    """A single chat message exchanged with an LLM, optionally carrying images."""

    # Author role of the message (user / assistant / ...).
    role: IrisMessageRole
    # Plain-text content of the message.
    text: str
    # Optional image attachments; None when the message is text-only.
    images: list[IrisImage] | None

    def __init__(
        self, role: IrisMessageRole, text: str, images: list[IrisImage] | None = None
    ):
        self.role = role
        self.text = text
        self.images = images

    def __str__(self):
        # NOTE: images are intentionally not included in the string form.
        return f"IrisMessage(role={self.role.value}, text='{self.text}')"
4 changes: 2 additions & 2 deletions app/llm/llm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import yaml

from common import Singleton
from llm.wrapper import AbstractLlmWrapper, LlmWrapper
from llm.wrapper import LlmWrapper


# Small workaround to get pydantic discriminators working
Expand All @@ -14,7 +14,7 @@ class LlmList(BaseModel):


class LlmManager(metaclass=Singleton):
entries: list[AbstractLlmWrapper]
entries: list[LlmWrapper]

def __init__(self):
self.entries = []
Expand Down
3 changes: 2 additions & 1 deletion app/llm/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
from llm.wrapper.open_ai_completion_wrapper import (
OpenAICompletionWrapper,
AzureCompletionWrapper,
Expand All @@ -7,6 +6,7 @@
OpenAIChatCompletionWrapper,
AzureChatCompletionWrapper,
)
from llm.wrapper.open_ai_dalle_wrapper import OpenAIDalleWrapper
from llm.wrapper.open_ai_embedding_wrapper import (
OpenAIEmbeddingWrapper,
AzureEmbeddingWrapper,
Expand All @@ -21,4 +21,5 @@
| OpenAIEmbeddingWrapper
| AzureEmbeddingWrapper
| OllamaWrapper
| OpenAIDalleWrapper
)
17 changes: 16 additions & 1 deletion app/llm/wrapper/abstract_llm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from domain import IrisMessage
from domain import IrisMessage, IrisImage
from llm import CompletionArguments


Expand Down Expand Up @@ -56,3 +56,18 @@ def __subclasshook__(cls, subclass):
def create_embedding(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError


class AbstractLlmImageGenerationWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm image generation wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check, mirroring the sibling wrapper ABCs in this module:
        # any class exposing a callable `generate_images` counts as a subclass.
        return hasattr(subclass, "generate_images") and callable(
            subclass.generate_images
        )

    @abstractmethod
    def generate_images(self, prompt: str, n: int, **kwargs) -> list[IrisImage]:
        """Generate images from the prompt"""
        raise NotImplementedError
24 changes: 20 additions & 4 deletions app/llm/wrapper/ollama_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import base64
from typing import Literal, Any

from ollama import Client, Message

from domain import IrisMessage, IrisMessageRole
from domain import IrisMessage, IrisMessageRole, IrisImage
from llm import CompletionArguments
from llm.wrapper.abstract_llm_wrapper import (
AbstractLlmChatCompletionWrapper,
Expand All @@ -11,9 +12,20 @@
)


def convert_to_ollama_images(images: list[IrisImage]) -> list[bytes] | None:
    """Decode each image's base64 payload into raw bytes for the Ollama client.

    Returns None when *images* is empty or None, since the Ollama message
    field is optional.
    """
    if not images:
        return None
    decoded = []
    for image in images:
        decoded.append(base64.b64decode(image.base64))
    return decoded


def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
    """Convert IrisMessages into Ollama Message objects, including any images.

    :param messages: Chat history to convert.
    :return: One Ollama Message per IrisMessage, in the same order.
    """
    return [
        Message(
            role=message.role.value,
            content=message.text,
            # None when the message carries no images (field is optional).
            images=convert_to_ollama_images(message.images),
        )
        for message in messages
    ]


Expand All @@ -34,8 +46,12 @@ class OllamaWrapper(
def model_post_init(self, __context: Any) -> None:
    """Pydantic hook: build the Ollama client once the fields are validated."""
    self._client = Client(host=self.host)  # TODO: Add authentication (httpx auth?)

def completion(
    self,
    prompt: str,
    arguments: CompletionArguments,
    images: list[IrisImage] | None = None,
) -> str:
    """Create a text completion for *prompt* via the Ollama generate endpoint.

    :param prompt: The prompt to complete.
    :param arguments: Completion arguments.
    :param images: Optional images to send alongside the prompt.
    :return: The generated completion text.
    """
    # NOTE(review): `arguments` is accepted but never forwarded to the
    # client — confirm whether temperature/max_tokens should be passed on.
    response = self._client.generate(
        model=self.model, prompt=prompt, images=convert_to_ollama_images(images)
    )
    return response["response"]

def chat_completion(
Expand Down
61 changes: 61 additions & 0 deletions app/llm/wrapper/open_ai_dalle_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import base64
from datetime import datetime
from typing import Literal, Any

import requests
from openai import OpenAI

from domain import IrisImage
from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper


class OpenAIDalleWrapper(AbstractLlmWrapper):
    """LLM wrapper that generates images through the OpenAI DALL-E API."""

    type: Literal["openai_dalle"]
    # Model identifier, e.g. a DALL-E model name.
    model: str
    _client: OpenAI

    def model_post_init(self, __context: Any) -> None:
        """Pydantic hook: build the OpenAI client once the fields are validated."""
        self._client = OpenAI(api_key=self.api_key)

    def generate_images(
        self,
        prompt: str,
        n: int = 1,
        size: Literal[
            "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"
        ] = "256x256",
        quality: Literal["standard", "hd"] = "standard",
        **kwargs
    ) -> list[IrisImage]:
        """Generate `n` images for `prompt` and wrap them as IrisImage objects.

        :param prompt: Text description of the desired image(s).
        :param n: Number of images to generate.
        :param size: Output resolution (valid values depend on the model).
        :param quality: "standard" or "hd".
        :return: One IrisImage per generated image.
        :raises requests.HTTPError: If downloading a generated image fails.
        """
        response = self._client.images.generate(
            model=self.model,
            prompt=prompt,
            size=size,
            quality=quality,
            n=n,
            response_format="url",
            **kwargs
        )

        iris_images = []
        for image in response.data:
            # The API may rewrite the prompt; fall back to the original.
            # Read values locally instead of mutating the response object.
            prompt_used = prompt if image.revised_prompt is None else image.revised_prompt
            b64 = image.b64_json
            if b64 is None:
                # We requested response_format="url", so fetch and encode the
                # payload ourselves. The timeout and status check prevent a
                # hung request or an error page being stored as image data.
                image_response = requests.get(image.url, timeout=30)
                image_response.raise_for_status()
                b64 = base64.b64encode(image_response.content).decode("utf-8")

            iris_images.append(
                IrisImage(
                    prompt=prompt_used,
                    base64=b64,
                    url=image.url,
                    timestamp=datetime.fromtimestamp(response.created),
                    raw_data=image,
                )
            )

        return iris_images
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ black==24.1.1
flake8==7.0.0
pre-commit==3.6.0
pydantic==2.6.1
requests==2.31.0
Loading