diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..a5c48807 --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +max-line-length = 120 +exclude = + .git, + __pycache__, + .idea +per-file-ignores = + # imported but unused + __init__.py: F401, F403 + open_ai_chat_wrapper.py: F811 + open_ai_completion_wrapper.py: F811 + open_ai_embedding_wrapper.py: F811 + diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 00000000..62d31784 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,3 @@ +"component:LLM": + - changed-files: + - any-glob-to-any-file: app/llm/** diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 17c81c08..af785312 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" cache: 'pip' - name: Install Dependencies from requirements.txt diff --git a/.github/workflows/pullrequest-labeler.yml b/.github/workflows/pullrequest-labeler.yml new file mode 100644 index 00000000..f7739956 --- /dev/null +++ b/.github/workflows/pullrequest-labeler.yml @@ -0,0 +1,10 @@ +name: Pull Request Labeler +on: pull_request_target + +jobs: + label: + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v5 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42fd914f..b47d541b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,8 +5,9 @@ rev: stable hooks: - id: black - language_version: python3.11 + language_version: python3.12 - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.0.0 hooks: - - id: flake8 \ No newline at end of file + - id: flake8 + language_version: python3.12 \ No newline at end of file diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 00000000..97e30c68 --- /dev/null +++ b/app/common/__init__.py @@ -0,0 +1 @@ +from common.singleton import Singleton diff --git a/app/common/singleton.py b/app/common/singleton.py new file mode 100644 index 00000000..3776cb92 --- /dev/null +++ b/app/common/singleton.py @@ -0,0 +1,7 @@ +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/app/domain/__init__.py b/app/domain/__init__.py new file mode 100644 index 00000000..b73080e7 --- /dev/null +++ b/app/domain/__init__.py @@ -0,0 +1 @@ +from domain.message import IrisMessage, IrisMessageRole diff --git a/app/domain/message.py b/app/domain/message.py new file mode 100644 index 00000000..b1f521cc --- /dev/null +++ b/app/domain/message.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class IrisMessageRole(Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + +class IrisMessage: + role: IrisMessageRole + text: str + + def __init__(self, role: IrisMessageRole, text: str): + self.role = role + self.text = text + + def __str__(self): + return f"IrisMessage(role={self.role.value}, text='{self.text}')" diff --git a/app/llm/__init__.py b/app/llm/__init__.py new file mode 100644 index 00000000..33542f1c --- /dev/null +++ b/app/llm/__init__.py @@ -0,0 +1,3 @@ +from llm.request_handler_interface import RequestHandlerInterface +from llm.generation_arguments import * +from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py new file mode 100644 index 00000000..f348da1f --- /dev/null +++ b/app/llm/basic_request_handler.py @@ -0,0 +1,48 @@ +from domain import IrisMessage +from llm import RequestHandlerInterface, CompletionArguments +from llm.llm_manager import LlmManager +from llm.wrapper import ( + AbstractLlmCompletionWrapper, + AbstractLlmChatCompletionWrapper, + AbstractLlmEmbeddingWrapper, +) + +type BasicRequestHandlerModel = str + + +class BasicRequestHandler(RequestHandlerInterface): + model: BasicRequestHandlerModel + llm_manager: LlmManager + + def __init__(self, model: BasicRequestHandlerModel): + self.model = model + self.llm_manager = LlmManager() + + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, AbstractLlmCompletionWrapper): + return llm.completion(prompt, arguments) + else: + raise NotImplementedError( + f"The LLM {llm.__str__()} does not support completion" + ) + + def chat_completion( + self, messages: list[IrisMessage], arguments: CompletionArguments + ) -> IrisMessage: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, AbstractLlmChatCompletionWrapper): + return llm.chat_completion(messages, arguments) + else: + raise NotImplementedError( + f"The LLM {llm.__str__()} does not support chat completion" + ) + + def create_embedding(self, text: str) -> list[float]: + llm = self.llm_manager.get_llm_by_id(self.model).llm + if isinstance(llm, AbstractLlmEmbeddingWrapper): + return llm.create_embedding(text) + else: + raise NotImplementedError( + f"The LLM {llm.__str__()} does not support embedding" + ) diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py new file mode 100644 index 00000000..a540e144 --- /dev/null +++ b/app/llm/generation_arguments.py @@ -0,0 +1,9 @@ +class CompletionArguments: + """Arguments for the completion request""" + + def __init__( + self, max_tokens: int = None, temperature: float = None, stop: list[str] = None + ): + self.max_tokens = max_tokens + self.temperature = temperature + self.stop = stop diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py new file mode 100644 index 00000000..49a56f30 --- /dev/null +++ b/app/llm/llm_manager.py @@ -0,0 +1,102 @@ +import os + +import yaml + +from common import Singleton +from llm.wrapper import AbstractLlmWrapper + + +# TODO: Replace with pydantic in a future PR +def create_llm_wrapper(config: dict) -> AbstractLlmWrapper: + if config["type"] == "openai": + from llm.wrapper import OpenAICompletionWrapper + + return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"]) + elif config["type"] == "azure": + from llm.wrapper import AzureCompletionWrapper + + return AzureCompletionWrapper( + id=config["id"], + name=config["name"], + description=config["description"], + model=config["model"], + endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "openai_chat": + from llm.wrapper import OpenAIChatCompletionWrapper + + return OpenAIChatCompletionWrapper( + id=config["id"], + name=config["name"], + description=config["description"], + model=config["model"], + api_key=config["api_key"], + ) + elif config["type"] == "azure_chat": + from llm.wrapper import AzureChatCompletionWrapper + + return AzureChatCompletionWrapper( + id=config["id"], + name=config["name"], + description=config["description"], + model=config["model"], + endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "openai_embedding": + from llm.wrapper import OpenAIEmbeddingWrapper + + return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"]) + elif config["type"] == "azure_embedding": + from llm.wrapper import AzureEmbeddingWrapper + + return AzureEmbeddingWrapper( + id=config["id"], + name=config["name"], + description=config["description"], + model=config["model"], + endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "ollama": + from llm.wrapper import OllamaWrapper + + return OllamaWrapper( + id=config["id"], + name=config["name"], + description=config["description"], + model=config["model"], + host=config["host"], + ) + else: + raise Exception(f"Unknown LLM type: {config['type']}") + + +class LlmManager(metaclass=Singleton): + entries: list[AbstractLlmWrapper] + + def __init__(self): + self.entries = [] + self.load_llms() + + def get_llm_by_id(self, llm_id): + for llm in self.entries: + if llm.id == llm_id: + return llm + + def load_llms(self): + path = os.environ.get("LLM_CONFIG_PATH") + if not path: + raise Exception("LLM_CONFIG_PATH not set") + + with open(path, "r") as file: + loaded_llms = yaml.safe_load(file) + + self.entries = [create_llm_wrapper(llm) for llm in loaded_llms] diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py new file mode 100644 index 00000000..5c15df30 --- /dev/null +++ b/app/llm/request_handler_interface.py @@ -0,0 +1,36 @@ +from abc import ABCMeta, abstractmethod + +from domain import IrisMessage +from llm.generation_arguments import CompletionArguments + + +class RequestHandlerInterface(metaclass=ABCMeta): + """Interface for the request handlers""" + + @classmethod + def __subclasshook__(cls, subclass): + return ( + hasattr(subclass, "completion") + and callable(subclass.completion) + and hasattr(subclass, "chat_completion") + and callable(subclass.chat_completion) + and hasattr(subclass, "create_embedding") + and callable(subclass.create_embedding) + ) + + @abstractmethod + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + """Create a completion from the prompt""" + raise NotImplementedError + + @abstractmethod + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> [IrisMessage]: + """Create a completion from the chat messages""" + raise NotImplementedError + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py new file mode 100644 index 00000000..7e0dabff --- /dev/null +++ b/app/llm/wrapper/__init__.py @@ -0,0 +1,5 @@ +from llm.wrapper.abstract_llm_wrapper import * +from llm.wrapper.open_ai_completion_wrapper import * +from llm.wrapper.open_ai_chat_wrapper import * +from llm.wrapper.open_ai_embedding_wrapper import * +from llm.wrapper.ollama_wrapper import OllamaWrapper diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py new file mode 100644 index 00000000..6d5e353e --- /dev/null +++ b/app/llm/wrapper/abstract_llm_wrapper.py @@ -0,0 +1,62 @@ +from abc import ABCMeta, abstractmethod + +from domain import IrisMessage +from llm import CompletionArguments + + +class AbstractLlmWrapper(metaclass=ABCMeta): + """Abstract class for the llm wrappers""" + + id: str + name: str + description: str + + def __init__(self, id: str, name: str, description: str): + self.id = id + self.name = name + self.description = description + + +class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): + """Abstract class for the llm completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "completion") and callable(subclass.completion) + + @abstractmethod + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + """Create a completion from the prompt""" + raise NotImplementedError + + +class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): + """Abstract class for the llm chat completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "chat_completion") and callable( + subclass.chat_completion + ) + + @abstractmethod + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> IrisMessage: + """Create a completion from the chat messages""" + raise NotImplementedError + + +class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta): + """Abstract class for the llm embedding wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "create_embedding") and callable( + subclass.create_embedding + ) + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py new file mode 100644 index 00000000..9ce8e94b --- /dev/null +++ b/app/llm/wrapper/ollama_wrapper.py @@ -0,0 +1,50 @@ +from ollama import Client, Message + +from domain import IrisMessage, IrisMessageRole +from llm import CompletionArguments +from llm.wrapper import ( + AbstractLlmChatCompletionWrapper, + AbstractLlmCompletionWrapper, + AbstractLlmEmbeddingWrapper, +) + + +def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: + return [ + Message(role=message.role.value, content=message.text) for message in messages + ] + + +def convert_to_iris_message(message: Message) -> IrisMessage: + return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"]) + + +class OllamaWrapper( + AbstractLlmCompletionWrapper, + AbstractLlmChatCompletionWrapper, + AbstractLlmEmbeddingWrapper, +): + + def __init__(self, model: str, host: str, **kwargs): + super().__init__(**kwargs) + self.client = Client(host=host) # TODO: Add authentication (httpx auth?) + self.model = model + + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + response = self.client.generate(model=self.model, prompt=prompt) + return response["response"] + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> any: + response = self.client.chat( + model=self.model, messages=convert_to_ollama_messages(messages) + ) + return convert_to_iris_message(response["message"]) + + def create_embedding(self, text: str) -> list[float]: + response = self.client.embeddings(model=self.model, prompt=text) + return list(response) + + def __str__(self): + return f"Ollama('{self.model}')" diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py new file mode 100644 index 00000000..c6b68e25 --- /dev/null +++ b/app/llm/wrapper/open_ai_chat_wrapper.py @@ -0,0 +1,77 @@ +from openai.lib.azure import AzureOpenAI +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage + +from domain import IrisMessage, IrisMessageRole +from llm import CompletionArguments +from llm.wrapper import AbstractLlmChatCompletionWrapper + + +def convert_to_open_ai_messages( + messages: list[IrisMessage], +) -> list[ChatCompletionMessageParam]: + return [ + {"role": message.role.value, "content": message.text} for message in messages + ] + + +def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage: + # Get IrisMessageRole from the string message.role + message_role = IrisMessageRole(message.role) + return IrisMessage(role=message_role, text=message.content) + + +class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper): + + def __init__(self, client, model: str, **kwargs): + super().__init__(**kwargs) + self.client = client + self.model = model + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> IrisMessage: + response = self.client.chat.completions.create( + model=self.model, + messages=convert_to_open_ai_messages(messages), + temperature=arguments.temperature, + max_tokens=arguments.max_tokens, + stop=arguments.stop, + ) + return convert_to_iris_message(response.choices[0].message) + + +class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): + + def __init__(self, model: str, api_key: str, **kwargs): + from openai import OpenAI + + client = OpenAI(api_key=api_key) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"OpenAIChat('{self.model}')" + + +class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + **kwargs, + ): + client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"AzureChat('{self.model}')" diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py new file mode 100644 index 00000000..daac194a --- /dev/null +++ b/app/llm/wrapper/open_ai_completion_wrapper.py @@ -0,0 +1,58 @@ +from openai import OpenAI +from openai.lib.azure import AzureOpenAI + +from llm import CompletionArguments +from llm.wrapper import AbstractLlmCompletionWrapper + + +class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper): + + def __init__(self, client, model: str, **kwargs): + super().__init__(**kwargs) + self.client = client + self.model = model + + def completion(self, prompt: str, arguments: CompletionArguments) -> any: + response = self.client.completions.create( + model=self.model, + prompt=prompt, + temperature=arguments.temperature, + max_tokens=arguments.max_tokens, + stop=arguments.stop, + ) + return response + + +class OpenAICompletionWrapper(BaseOpenAICompletionWrapper): + + def __init__(self, model: str, api_key: str, **kwargs): + client = OpenAI(api_key=api_key) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"OpenAICompletion('{self.model}')" + + +class AzureCompletionWrapper(BaseOpenAICompletionWrapper): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + **kwargs, + ): + client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"AzureCompletion('{self.model}')" diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py new file mode 100644 index 00000000..88b425bd --- /dev/null +++ b/app/llm/wrapper/open_ai_embedding_wrapper.py @@ -0,0 +1,57 @@ +from openai import OpenAI +from openai.lib.azure import AzureOpenAI + +from llm.wrapper import ( + AbstractLlmEmbeddingWrapper, +) + + +class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper): + + def __init__(self, client, model: str, **kwargs): + super().__init__(**kwargs) + self.client = client + self.model = model + + def create_embedding(self, text: str) -> list[float]: + response = self.client.embeddings.create( + model=self.model, + input=text, + encoding_format="float", + ) + return response.data[0].embedding + + +class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): + + def __init__(self, model: str, api_key: str, **kwargs): + client = OpenAI(api_key=api_key) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"OpenAIEmbedding('{self.model}')" + + +class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + **kwargs, + ): + client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + model = model + super().__init__(client, model, **kwargs) + + def __str__(self): + return f"AzureEmbedding('{self.model}')"