First draft of LLM subsystem
Hialus committed Feb 10, 2024
1 parent 281b411 commit 429a4a0
Showing 15 changed files with 288 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-         python-version: "3.11"
+         python-version: "3.12"
          cache: 'pip'

      - name: Install Dependencies from requirements.txt
1 change: 1 addition & 0 deletions app/common/__init__.py
@@ -0,0 +1 @@
from common.singleton import Singleton
7 changes: 7 additions & 0 deletions app/common/singleton.py
@@ -0,0 +1,7 @@
class Singleton(type):
    """Metaclass that creates at most one instance of each class using it."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
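Usage sketch for the metaclass (the Config class below is hypothetical, not part of this commit): any class that sets metaclass=Singleton is constructed once and cached, so repeated constructor calls return the same object.

class Config(metaclass=Singleton):
    def __init__(self):
        self.settings = {}

a = Config()
b = Config()
assert a is b  # __call__ returned the cached instance the second time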
1 change: 1 addition & 0 deletions app/domain/__init__.py
@@ -0,0 +1 @@
from domain.message import IrisMessage
4 changes: 4 additions & 0 deletions app/domain/message.py
@@ -0,0 +1,4 @@
class IrisMessage:
    """A single chat message: a role (e.g. "user" or "assistant") plus its text."""

    def __init__(self, role: str, message_text: str):
        self.role = role
        self.message_text = message_text
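For illustration (roles assumed to follow the OpenAI "system"/"user"/"assistant" convention), a chat history is just a list of these messages:

history = [
    IrisMessage(role="system", message_text="You are a helpful tutor."),
    IrisMessage(role="user", message_text="Explain quicksort briefly."),
]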
4 changes: 4 additions & 0 deletions app/llm/__init__.py
@@ -0,0 +1,4 @@
from llm.generation_arguments import CompletionArguments
from llm.request_handler_interface import RequestHandlerInterface
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
from llm.llm_manager import LlmManager
36 changes: 36 additions & 0 deletions app/llm/basic_request_handler.py
@@ -0,0 +1,36 @@
from domain import IrisMessage
from llm.llm_manager import LlmManager
from llm.request_handler_interface import RequestHandlerInterface
from llm.generation_arguments import CompletionArguments
from llm.wrapper import (
    LlmCompletionWrapperInterface,
    LlmChatCompletionWrapperInterface,
    LlmEmbeddingWrapperInterface,
)

type BasicRequestHandlerModel = str


class BasicRequestHandler(RequestHandlerInterface):
    model: BasicRequestHandlerModel
    llm_manager: LlmManager

    def __init__(self, model: BasicRequestHandlerModel):
        self.model = model
        self.llm_manager = LlmManager()

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, LlmCompletionWrapperInterface):
            return llm.completion(prompt, arguments)
        else:
            raise NotImplementedError

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, LlmChatCompletionWrapperInterface):
            return llm.chat_completion(messages, arguments)
        else:
            raise NotImplementedError

    def create_embedding(self, text: str) -> list[float]:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, LlmEmbeddingWrapperInterface):
            return llm.create_embedding(text)
        else:
            raise NotImplementedError
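How the pieces are meant to compose, as a sketch (the id "ollama-llama2" and the wrapper settings are made up for illustration):

from llm import BasicRequestHandler, CompletionArguments, LlmManager
from llm.wrapper import OllamaWrapper

# Register a wrapper once at startup; LlmManager is a process-wide singleton.
LlmManager().add_llm("ollama-llama2", OllamaWrapper(model="llama2", host="http://localhost:11434"))

handler = BasicRequestHandler(model="ollama-llama2")
args = CompletionArguments(max_tokens=256, temperature=0.2, stop=["\n\n"])
print(handler.completion("Say hello.", args))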
7 changes: 7 additions & 0 deletions app/llm/generation_arguments.py
@@ -0,0 +1,7 @@
class CompletionArguments:
    """Arguments for the completion request"""

    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.stop = stop
26 changes: 26 additions & 0 deletions app/llm/llm_manager.py
@@ -0,0 +1,26 @@
from common import Singleton
from llm.wrapper import LlmWrapperInterface


class LlmManagerEntry:
    id: str
    llm: LlmWrapperInterface

    def __init__(self, id: str, llm: LlmWrapperInterface):
        self.id = id
        self.llm = llm


class LlmManager(metaclass=Singleton):
    llms: list[LlmManagerEntry]

    def __init__(self):
        self.llms = []

    def get_llm_by_id(self, llm_id: str) -> LlmManagerEntry | None:
        # Linear scan over the registry; returns None if no entry matches.
        for llm in self.llms:
            if llm.id == llm_id:
                return llm
        return None

    def add_llm(self, id: str, llm: LlmWrapperInterface):
        self.llms.append(LlmManagerEntry(id, llm))
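Because of the Singleton metaclass, every LlmManager() call yields the same registry, so a wrapper registered at startup is visible from any module (the names below are hypothetical):

manager = LlmManager()
manager.add_llm("my-model", my_wrapper)  # my_wrapper: any LlmWrapperInterface

# Elsewhere in the codebase:
assert LlmManager() is manager
entry = LlmManager().get_llm_by_id("my-model")  # None if the id is unknown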
36 changes: 36 additions & 0 deletions app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments


class RequestHandlerInterface(metaclass=ABCMeta):
    """Interface for the request handlers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return (
            hasattr(subclass, "completion")
            and callable(subclass.completion)
            and hasattr(subclass, "chat_completion")
            and callable(subclass.chat_completion)
            and hasattr(subclass, "create_embedding")
            and callable(subclass.create_embedding)
        )

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError

    @abstractmethod
    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
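Note that the __subclasshook__ makes the interface structural: isinstance succeeds for any object exposing the three methods, even without inheriting from the class. A minimal sketch:

class DummyHandler:  # no explicit inheritance
    def completion(self, prompt, arguments): ...
    def chat_completion(self, messages, arguments): ...
    def create_embedding(self, text): ...

assert isinstance(DummyHandler(), RequestHandlerInterface)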
3 changes: 3 additions & 0 deletions app/llm/wrapper/__init__.py
@@ -0,0 +1,3 @@
from llm.wrapper.llm_wrapper_interface import *
from llm.wrapper.open_ai_chat_wrapper import *
from llm.wrapper.ollama_wrapper import OllamaWrapper
32 changes: 32 additions & 0 deletions app/llm/wrapper/azure_chat_wrapper.py
@@ -0,0 +1,32 @@
from openai.lib.azure import AzureOpenAI

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments
from llm.wrapper.llm_wrapper_interface import LlmChatCompletionWrapperInterface
from llm.wrapper.open_ai_chat_wrapper import (
    convert_to_open_ai_messages,
    convert_to_iris_message,
)


class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface):

    def __init__(
        self,
        model: str,
        endpoint: str,
        azure_deployment: str,
        api_version: str,
        api_key: str,
    ):
        self.client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_deployment=azure_deployment,
            api_version=api_version,
            api_key=api_key,
        )
        self.model = model

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=convert_to_open_ai_messages(messages),
            max_tokens=arguments.max_tokens,
            temperature=arguments.temperature,
            stop=arguments.stop,
        )
        return convert_to_iris_message(response.choices[0].message)
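Construction sketch (every configuration value below is a placeholder, not a real endpoint, deployment, or key):

from llm import LlmManager

azure_llm = AzureChatCompletionWrapper(
    model="gpt-35-turbo",
    endpoint="https://my-resource.openai.azure.com/",
    azure_deployment="my-deployment",
    api_version="2023-12-01-preview",
    api_key="<azure-api-key>",
)
LlmManager().add_llm("azure-gpt35", azure_llm)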
47 changes: 47 additions & 0 deletions app/llm/wrapper/llm_wrapper_interface.py
@@ -0,0 +1,47 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments

type LlmWrapperInterface = (
    LlmCompletionWrapperInterface
    | LlmChatCompletionWrapperInterface
    | LlmEmbeddingWrapperInterface
)


class LlmCompletionWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm completion wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "completion") and callable(subclass.completion)

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError


class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm chat completion wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "chat_completion") and callable(
            subclass.chat_completion
        )

    @abstractmethod
    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError


class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm embedding wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "create_embedding") and callable(
            subclass.create_embedding
        )

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
48 changes: 48 additions & 0 deletions app/llm/wrapper/ollama_wrapper.py
@@ -0,0 +1,48 @@
from ollama import Client, Message

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments
from llm.wrapper.llm_wrapper_interface import (
    LlmChatCompletionWrapperInterface,
    LlmCompletionWrapperInterface,
    LlmEmbeddingWrapperInterface,
)


def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
    return [
        Message(role=message.role, content=message.message_text) for message in messages
    ]


def convert_to_iris_message(message: Message) -> IrisMessage:
    # The Ollama client returns messages as dicts, so use key access here.
    return IrisMessage(role=message["role"], message_text=message["content"])


class OllamaWrapper(
    LlmCompletionWrapperInterface,
    LlmChatCompletionWrapperInterface,
    LlmEmbeddingWrapperInterface,
):

    def __init__(self, model: str, host: str):
        self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
        self.model = model

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        response = self.client.generate(model=self.model, prompt=prompt)
        return response["response"]

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        response = self.client.chat(
            model=self.model, messages=convert_to_ollama_messages(messages)
        )
        return convert_to_iris_message(response["message"])

    def create_embedding(self, text: str) -> list[float]:
        response = self.client.embeddings(model=self.model, prompt=text)
        # The response is a dict; the vector lives under the "embedding" key.
        return response["embedding"]
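Usage sketch (the host and model are illustrative and assume a locally running Ollama server; note that CompletionArguments are accepted but not yet forwarded to the Ollama API in this draft):

llama = OllamaWrapper(model="llama2", host="http://localhost:11434")
args = CompletionArguments(max_tokens=128, temperature=0.7, stop=[])
print(llama.completion("The quick brown fox", args))
print(len(llama.create_embedding("hello")))  # dimensionality of the embedding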
35 changes: 35 additions & 0 deletions app/llm/wrapper/open_ai_chat_wrapper.py
@@ -0,0 +1,35 @@
from openai import OpenAI
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments
from llm.wrapper.llm_wrapper_interface import LlmChatCompletionWrapperInterface


def convert_to_open_ai_messages(
    messages: list[IrisMessage],
) -> list[ChatCompletionMessageParam]:
    # ChatCompletionMessageParam is a union of TypedDicts, so plain dicts conform;
    # it cannot be instantiated directly.
    return [
        {"role": message.role, "content": message.message_text} for message in messages
    ]


def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
    return IrisMessage(role=message.role, message_text=message.content)


class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):

    def __init__(self, model: str, api_key: str):
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=convert_to_open_ai_messages(messages),
            max_tokens=arguments.max_tokens,
            temperature=arguments.temperature,
            stop=arguments.stop,
        )
        return convert_to_iris_message(response.choices[0].message)
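Usage sketch (the model name and API key are placeholders):

from domain import IrisMessage
from llm import CompletionArguments

gpt = OpenAIChatCompletionWrapper(model="gpt-3.5-turbo", api_key="sk-...")
args = CompletionArguments(max_tokens=100, temperature=0.0, stop=[])
reply = gpt.chat_completion([IrisMessage(role="user", message_text="Hi!")], args)
print(reply.message_text)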
