diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 17c81c08..af785312 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" cache: 'pip' - name: Install Dependencies from requirements.txt diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 00000000..e190c8ba --- /dev/null +++ b/app/common/__init__.py @@ -0,0 +1 @@ +from singleton import Singleton diff --git a/app/common/singleton.py b/app/common/singleton.py new file mode 100644 index 00000000..3776cb92 --- /dev/null +++ b/app/common/singleton.py @@ -0,0 +1,7 @@ +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/app/domain/__init__.py b/app/domain/__init__.py new file mode 100644 index 00000000..5ead9e65 --- /dev/null +++ b/app/domain/__init__.py @@ -0,0 +1 @@ +from message import IrisMessage diff --git a/app/domain/message.py b/app/domain/message.py new file mode 100644 index 00000000..b5fe26f7 --- /dev/null +++ b/app/domain/message.py @@ -0,0 +1,4 @@ +class IrisMessage: + def __init__(self, role, message_text): + self.role = role + self.message_text = message_text diff --git a/app/llm/__init__.py b/app/llm/__init__.py new file mode 100644 index 00000000..73b39dc6 --- /dev/null +++ b/app/llm/__init__.py @@ -0,0 +1,4 @@ +from generation_arguments import CompletionArguments +from request_handler_interface import RequestHandlerInterface +from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel +from llm_manager import LlmManager diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py new file mode 100644 index 00000000..0c284002 --- /dev/null +++ b/app/llm/basic_request_handler.py @@ -0,0 +1,36 @@ +from domain import IrisMessage +from llm import LlmManager +from llm import RequestHandlerInterface, CompletionArguments +from llm.wrapper import LlmCompletionWrapperInterface, LlmChatCompletionWrapperInterface, LlmEmbeddingWrapperInterface + +type BasicRequestHandlerModel = str + + +class BasicRequestHandler(RequestHandlerInterface): + model: BasicRequestHandlerModel + llm_manager: LlmManager + + def __init__(self, model: BasicRequestHandlerModel): + self.model = model + self.llm_manager = LlmManager() + + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, LlmCompletionWrapperInterface): + return llm.completion(prompt, arguments) + else: + raise NotImplementedError + + def chat_completion(self, messages: list[IrisMessage], arguments: CompletionArguments) -> IrisMessage: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, LlmChatCompletionWrapperInterface): + return llm.chat_completion(messages, arguments) + else: + raise NotImplementedError + + def create_embedding(self, text: str) -> list[float]: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, LlmEmbeddingWrapperInterface): + return llm.create_embedding(text) + else: + raise NotImplementedError diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py new file mode 100644 index 00000000..37a4af19 --- /dev/null +++ b/app/llm/generation_arguments.py @@ -0,0 +1,7 @@ +class CompletionArguments: + """Arguments for the completion request""" + + def __init__(self, max_tokens: int, temperature: float, stop: list[str]): + self.max_tokens = max_tokens + self.temperature = temperature + self.stop = stop diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py new file mode 100644 index 00000000..8dd649d8 --- /dev/null +++ b/app/llm/llm_manager.py @@ -0,0 +1,26 @@ +from common import Singleton +from llm.wrapper import LlmWrapperInterface + + +class LlmManagerEntry: + id: str + llm: LlmWrapperInterface + + def __init__(self, id: str, llm: LlmWrapperInterface): + self.id = id + self.llm = llm + + +class LlmManager(metaclass=Singleton): + llms: list[LlmManagerEntry] + + def __init__(self): + self.llms = [] + + def get_llm_by_id(self, llm_id): + for llm in self.llms: + if llm.id == llm_id: + return llm + + def add_llm(self, id: str, llm: LlmWrapperInterface): + self.llms.append(LlmManagerEntry(id, llm)) diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py new file mode 100644 index 00000000..5c15df30 --- /dev/null +++ b/app/llm/request_handler_interface.py @@ -0,0 +1,36 @@ +from abc import ABCMeta, abstractmethod + +from domain import IrisMessage +from llm.generation_arguments import CompletionArguments + + +class RequestHandlerInterface(metaclass=ABCMeta): + """Interface for the request handlers""" + + @classmethod + def __subclasshook__(cls, subclass): + return ( + hasattr(subclass, "completion") + and callable(subclass.completion) + and hasattr(subclass, "chat_completion") + and callable(subclass.chat_completion) + and hasattr(subclass, "create_embedding") + and callable(subclass.create_embedding) + ) + + @abstractmethod + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + """Create a completion from the prompt""" + raise NotImplementedError + + @abstractmethod + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> [IrisMessage]: + """Create a completion from the chat messages""" + raise NotImplementedError + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py new file mode 100644 index 00000000..6aef2481 --- /dev/null +++ b/app/llm/wrapper/__init__.py @@ -0,0 +1,3 @@ +from llm_wrapper_interface import * +from open_ai_chat_wrapper import * +from ollama_wrapper import OllamaWrapper diff --git a/app/llm/wrapper/azure_chat_wrapper.py b/app/llm/wrapper/azure_chat_wrapper.py new file mode 100644 index 00000000..9022e3ef --- /dev/null +++ b/app/llm/wrapper/azure_chat_wrapper.py @@ -0,0 +1,32 @@ +from openai.lib.azure import AzureOpenAI + +from llm import CompletionArguments +from llm.wrapper import LlmChatCompletionWrapperInterface, convert_to_open_ai_messages + + +class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + ): + self.client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + self.model = model + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> any: + response = self.client.chat.completions.create( + model=self.model, + messages=convert_to_open_ai_messages(messages), + ) + return response diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/llm_wrapper_interface.py new file mode 100644 index 00000000..e94f3e3b --- /dev/null +++ b/app/llm/wrapper/llm_wrapper_interface.py @@ -0,0 +1,47 @@ +from abc import ABCMeta, abstractmethod + +from llm import CompletionArguments + +type LlmWrapperInterface = LlmCompletionWrapperInterface | LlmChatCompletionWrapperInterface | LlmEmbeddingWrapperInterface + + +class LlmCompletionWrapperInterface(metaclass=ABCMeta): + """Interface for the llm completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'completion') and + callable(subclass.completion)) + + @abstractmethod + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + """Create a completion from the prompt""" + raise NotImplementedError + + +class LlmChatCompletionWrapperInterface(metaclass=ABCMeta): + """Interface for the llm chat completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'chat_completion') and + callable(subclass.chat_completion)) + + @abstractmethod + def chat_completion(self, messages: list[any], arguments: CompletionArguments) -> any: + """Create a completion from the chat messages""" + raise NotImplementedError + + +class LlmEmbeddingWrapperInterface(metaclass=ABCMeta): + """Interface for the llm embedding wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'create_embedding') and + callable(subclass.create_embedding)) + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py new file mode 100644 index 00000000..d4ca69bb --- /dev/null +++ b/app/llm/wrapper/ollama_wrapper.py @@ -0,0 +1,48 @@ +from ollama import Client, Message +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from domain import IrisMessage +from llm import CompletionArguments +from llm.wrapper import ( + LlmChatCompletionWrapperInterface, + LlmCompletionWrapperInterface, + LlmEmbeddingWrapperInterface, +) + + +def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: + return [ + Message(role=message.role, content=message.message_text) for message in messages + ] + + +def convert_to_iris_message(message: Message) -> IrisMessage: + return IrisMessage(role=message.role, message_text=message.content) + + +class OllamaWrapper( + LlmCompletionWrapperInterface, + LlmChatCompletionWrapperInterface, + LlmEmbeddingWrapperInterface, +): + + def __init__(self, model: str, host: str): + self.client = Client(host=host) # TODO: Add authentication (httpx auth?) + self.model = model + + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + response = self.client.generate(model=self.model, prompt=prompt) + return response["response"] + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> any: + response = self.client.chat( + model=self.model, messages=convert_to_ollama_messages(messages) + ) + return convert_to_iris_message(response["message"]) + + def create_embedding(self, text: str) -> list[float]: + response = self.client.embeddings(model=self.model, prompt=text) + return response diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py new file mode 100644 index 00000000..db04de18 --- /dev/null +++ b/app/llm/wrapper/open_ai_chat_wrapper.py @@ -0,0 +1,35 @@ +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from domain import IrisMessage +from llm import CompletionArguments +from llm.wrapper import LlmChatCompletionWrapperInterface + + +def convert_to_open_ai_messages( + messages: list[IrisMessage], +) -> list[ChatCompletionMessageParam]: + return [ + ChatCompletionMessageParam(role=message.role, content=message.message_text) + for message in messages + ] + + +def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage: + return IrisMessage(role=message.role, message_text=message.content) + + +class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface): + + def __init__(self, model: str, api_key: str): + self.client = OpenAI(api_key=api_key) + self.model = model + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> any: + response = self.client.chat.completions.create( + model=self.model, + messages=convert_to_open_ai_messages(messages), + ) + return response