Add basic support for Dall-E image generation and Ollama image recognition
Hialus committed Feb 11, 2024
1 parent b06d20f commit 4cf118f
Showing 9 changed files with 133 additions and 9 deletions.
1 change: 1 addition & 0 deletions app/domain/__init__.py
@@ -1 +1,2 @@
+from domain.image import IrisImage
 from domain.message import IrisMessage, IrisMessageRole
23 changes: 23 additions & 0 deletions app/domain/image.py
@@ -0,0 +1,23 @@
+from datetime import datetime
+
+
+class IrisImage:
+    prompt: str
+    base64: str
+    url: str | None
+    timestamp: datetime
+    _raw_data: any
+
+    def __init__(
+        self,
+        prompt: str,
+        base64: str,
+        url: str,
+        timestamp: datetime,
+        raw_data: any = None,
+    ):
+        self.prompt = prompt
+        self.base64 = base64
+        self.url = url
+        self.timestamp = timestamp
+        self._raw_data = raw_data
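The new IrisImage type carries a prompt, the base64-encoded payload, an optional URL and a timestamp. As a rough sketch (not part of the commit), an image read from a local file could be wrapped like this; the file name and prompt text are purely illustrative:

import base64
from datetime import datetime

from domain import IrisImage

# "diagram.png" is a hypothetical local file; only the base64 payload matters.
with open("diagram.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

image = IrisImage(
    prompt="user-uploaded diagram",  # free-form description, not a generation prompt here
    base64=encoded,
    url=None,  # no remote URL for a locally supplied image
    timestamp=datetime.now(),
)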
8 changes: 7 additions & 1 deletion app/domain/message.py
@@ -1,5 +1,7 @@
 from enum import Enum
 
+from domain import IrisImage
+
 
 class IrisMessageRole(Enum):
     USER = "user"
@@ -10,10 +12,14 @@ class IrisMessageRole(Enum):
 class IrisMessage:
     role: IrisMessageRole
     text: str
+    images: list[IrisImage] | None
 
-    def __init__(self, role: IrisMessageRole, text: str):
+    def __init__(
+        self, role: IrisMessageRole, text: str, images: list[IrisImage] | None = None
+    ):
         self.role = role
         self.text = text
+        self.images = images
 
     def __str__(self):
         return f"IrisMessage(role={self.role.value}, text='{self.text}')"
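With the new optional images field, a message can carry attachments alongside its text. A minimal sketch of constructing such a message (values are illustrative; note that __str__ still prints only role and text):

from datetime import datetime

from domain import IrisImage, IrisMessage, IrisMessageRole

img = IrisImage(
    prompt="",
    base64="<base64-encoded PNG bytes>",  # placeholder payload
    url=None,
    timestamp=datetime.now(),
)
msg = IrisMessage(
    role=IrisMessageRole.USER,
    text="What does this diagram show?",
    images=[img],
)
print(msg)  # -> IrisMessage(role=user, text='What does this diagram show?')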
4 changes: 2 additions & 2 deletions app/llm/llm_manager.py
@@ -5,7 +5,7 @@
 import yaml
 
 from common import Singleton
-from llm.wrapper import AbstractLlmWrapper, LlmWrapper
+from llm.wrapper import LlmWrapper
 
 
 # Small workaround to get pydantic discriminators working
@@ -14,7 +14,7 @@ class LlmList(BaseModel):


 class LlmManager(metaclass=Singleton):
-    entries: list[AbstractLlmWrapper]
+    entries: list[LlmWrapper]
 
     def __init__(self):
         self.entries = []
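The LlmList wrapper model exists so that pydantic can use each wrapper's type literal as a discriminator when parsing the YAML entries. A simplified sketch of that pattern, with made-up config classes rather than the repository's actual wrappers:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class OllamaConfig(BaseModel):
    type: Literal["ollama"]
    host: str


class DalleConfig(BaseModel):
    type: Literal["openai_dalle"]
    model: str


# The "type" field decides which model pydantic instantiates for each entry.
LlmEntry = Annotated[Union[OllamaConfig, DalleConfig], Field(discriminator="type")]


class LlmList(BaseModel):
    llms: list[LlmEntry]


parsed = LlmList(llms=[{"type": "openai_dalle", "model": "dall-e-3"}])
print(type(parsed.llms[0]).__name__)  # DalleConfig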
3 changes: 2 additions & 1 deletion app/llm/wrapper/__init__.py
@@ -1,4 +1,3 @@
-from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
 from llm.wrapper.open_ai_completion_wrapper import (
     OpenAICompletionWrapper,
     AzureCompletionWrapper,
@@ -7,6 +6,7 @@
     OpenAIChatCompletionWrapper,
     AzureChatCompletionWrapper,
 )
+from llm.wrapper.open_ai_dalle_wrapper import OpenAIDalleWrapper
 from llm.wrapper.open_ai_embedding_wrapper import (
     OpenAIEmbeddingWrapper,
     AzureEmbeddingWrapper,
@@ -21,4 +21,5 @@
     | OpenAIEmbeddingWrapper
     | AzureEmbeddingWrapper
     | OllamaWrapper
+    | OpenAIDalleWrapper
 )
17 changes: 16 additions & 1 deletion app/llm/wrapper/abstract_llm_wrapper.py
@@ -1,7 +1,7 @@
 from abc import ABCMeta, abstractmethod
 from pydantic import BaseModel
 
-from domain import IrisMessage
+from domain import IrisMessage, IrisImage
 from llm import CompletionArguments


@@ -56,3 +56,18 @@ def __subclasshook__(cls, subclass):
     def create_embedding(self, text: str) -> list[float]:
         """Create an embedding from the text"""
         raise NotImplementedError
+
+
+class AbstractLlmImageGenerationWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm image generation wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return hasattr(subclass, "generate_images") and callable(
+            subclass.generate_images
+        )
+
+    @abstractmethod
+    def generate_images(self, prompt: str, n: int, **kwargs) -> list[IrisImage]:
+        """Generate images from the prompt"""
+        raise NotImplementedError
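Because of the __subclasshook__, subclass checks against AbstractLlmImageGenerationWrapper are duck-typed: any class exposing a callable generate_images passes, which appears to be why OpenAIDalleWrapper below can extend AbstractLlmWrapper directly. A small sketch of that behaviour (the stub class is purely illustrative):

from llm.wrapper.abstract_llm_wrapper import AbstractLlmImageGenerationWrapper


class StubImageBackend:
    """Not a real wrapper; it only exists to show the duck-typed check."""

    def generate_images(self, prompt: str, n: int, **kwargs):
        return []


# Expected to print True: the hook only checks for a callable generate_images.
print(issubclass(StubImageBackend, AbstractLlmImageGenerationWrapper))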
24 changes: 20 additions & 4 deletions app/llm/wrapper/ollama_wrapper.py
@@ -1,8 +1,9 @@
+import base64
 from typing import Literal, Any
 
 from ollama import Client, Message
 
-from domain import IrisMessage, IrisMessageRole
+from domain import IrisMessage, IrisMessageRole, IrisImage
 from llm import CompletionArguments
 from llm.wrapper.abstract_llm_wrapper import (
     AbstractLlmChatCompletionWrapper,
@@ -11,9 +12,20 @@
 )
 
 
+def convert_to_ollama_images(images: list[IrisImage]) -> list[bytes] | None:
+    if not images:
+        return None
+    return [base64.b64decode(image.base64) for image in images]
+
+
 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
     return [
-        Message(role=message.role.value, content=message.text) for message in messages
+        Message(
+            role=message.role.value,
+            content=message.text,
+            images=convert_to_ollama_images(message.images),
+        )
+        for message in messages
     ]


@@ -34,8 +46,12 @@ class OllamaWrapper(
     def model_post_init(self, __context: Any) -> None:
         self._client = Client(host=self.host)  # TODO: Add authentication (httpx auth?)
 
-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        response = self._client.generate(model=self.model, prompt=prompt)
+    def completion(
+        self, prompt: str, arguments: CompletionArguments, images: [IrisImage] = None
+    ) -> str:
+        response = self._client.generate(
+            model=self.model, prompt=prompt, images=convert_to_ollama_images(images)
+        )
         return response["response"]
 
     def chat_completion(
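Put together, a caller can now hand base64-encoded images to an Ollama-backed completion. A hedged sketch, assuming llm is an OllamaWrapper obtained from the loaded config, that the underlying model is multimodal (e.g. llava), and that CompletionArguments can be constructed with defaults:

import base64
from datetime import datetime

from domain import IrisImage
from llm import CompletionArguments

# "screenshot.png" is a hypothetical input file.
with open("screenshot.png", "rb") as f:
    img = IrisImage(
        prompt="",
        base64=base64.b64encode(f.read()).decode("utf-8"),
        url=None,
        timestamp=datetime.now(),
    )

answer = llm.completion(
    prompt="Describe what is visible in this screenshot.",
    arguments=CompletionArguments(),  # assumed to be default-constructible
    images=[img],
)
print(answer)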
61 changes: 61 additions & 0 deletions app/llm/wrapper/open_ai_dalle_wrapper.py
@@ -0,0 +1,61 @@
+import base64
+from datetime import datetime
+from typing import Literal, Any
+
+import requests
+from openai import OpenAI
+
+from domain import IrisImage
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
+
+
+class OpenAIDalleWrapper(AbstractLlmWrapper):
+    type: Literal["openai_dalle"]
+    model: str
+    _client: OpenAI
+
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)
+
+    def generate_images(
+        self,
+        prompt: str,
+        n: int = 1,
+        size: Literal[
+            "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"
+        ] = "256x256",
+        quality: Literal["standard", "hd"] = "standard",
+        **kwargs
+    ) -> [IrisImage]:
+        response = self._client.images.generate(
+            model=self.model,
+            prompt=prompt,
+            size=size,
+            quality=quality,
+            n=n,
+            response_format="url",
+            **kwargs
+        )
+
+        images = response.data
+        iris_images = []
+        for image in images:
+            if image.revised_prompt is None:
+                image.revised_prompt = prompt
+            if image.b64_json is None:
+                image_response = requests.get(image.url)
+                image.b64_json = base64.b64encode(image_response.content).decode(
+                    "utf-8"
+                )
+
+            iris_images.append(
+                IrisImage(
+                    prompt=image.revised_prompt,
+                    base64=image.b64_json,
+                    url=image.url,
+                    timestamp=datetime.fromtimestamp(response.created),
+                    raw_data=image,
+                )
+            )
+
+        return iris_images
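A rough usage sketch for the new wrapper, assuming dalle is an OpenAIDalleWrapper resolved from the LLM config (its inherited constructor fields are not shown in this diff):

import base64

images = dalle.generate_images(
    prompt="A watercolor painting of a university campus",
    n=1,
    size="1024x1024",
)

# Every result carries a base64 payload, even when the API only returned a URL,
# because the wrapper downloads and encodes the image itself.
with open("generated.png", "wb") as f:
    f.write(base64.b64decode(images[0].base64))

print(images[0].prompt)  # DALL-E 3 may substitute a revised prompt here
print(images[0].url)  # short-lived URL returned by the API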
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ black==24.1.1
 flake8==7.0.0
 pre-commit==3.6.0
 pydantic==2.6.1
+requests==2.31.0
