From 94f786a13057fc81109157ef0ca62e7a308ba671 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Mon, 12 Feb 2024 11:44:47 +0100
Subject: [PATCH 1/5] `LLM`: Add llm subsystem (#51)

---
 .flake8                                       |  13 +++
 .github/labeler.yml                           |   3 +
 .github/workflows/lint.yml                    |   2 +-
 .github/workflows/pullrequest-labeler.yml     |  10 ++
 .pre-commit-config.yaml                       |   5 +-
 app/common/__init__.py                        |   1 +
 app/common/singleton.py                       |   7 ++
 app/domain/__init__.py                        |   1 +
 app/domain/message.py                         |  19 ++++
 app/llm/__init__.py                           |   3 +
 app/llm/basic_request_handler.py              |  48 +++++++++
 app/llm/generation_arguments.py               |   9 ++
 app/llm/llm_manager.py                        | 102 ++++++++++++++
 app/llm/request_handler_interface.py          |  36 +++++++
 app/llm/wrapper/__init__.py                   |   5 +
 app/llm/wrapper/abstract_llm_wrapper.py       |  62 +++++++++++
 app/llm/wrapper/ollama_wrapper.py             |  50 +++++++++
 app/llm/wrapper/open_ai_chat_wrapper.py       |  77 +++++++++++++
 app/llm/wrapper/open_ai_completion_wrapper.py |  58 ++++++++++
 app/llm/wrapper/open_ai_embedding_wrapper.py  |  57 ++++++++++
 20 files changed, 565 insertions(+), 3 deletions(-)
 create mode 100644 .flake8
 create mode 100644 .github/labeler.yml
 create mode 100644 .github/workflows/pullrequest-labeler.yml
 create mode 100644 app/common/__init__.py
 create mode 100644 app/common/singleton.py
 create mode 100644 app/domain/__init__.py
 create mode 100644 app/domain/message.py
 create mode 100644 app/llm/__init__.py
 create mode 100644 app/llm/basic_request_handler.py
 create mode 100644 app/llm/generation_arguments.py
 create mode 100644 app/llm/llm_manager.py
 create mode 100644 app/llm/request_handler_interface.py
 create mode 100644 app/llm/wrapper/__init__.py
 create mode 100644 app/llm/wrapper/abstract_llm_wrapper.py
 create mode 100644 app/llm/wrapper/ollama_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_chat_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_completion_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_embedding_wrapper.py

diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..a5c48807
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,13 @@
+[flake8]
+max-line-length = 120
+exclude =
+    .git,
+    __pycache__,
+    .idea
+per-file-ignores =
+    # imported but unused
+    __init__.py: F401, F403
+    open_ai_chat_wrapper.py: F811
+    open_ai_completion_wrapper.py: F811
+    open_ai_embedding_wrapper.py: F811
+
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 00000000..62d31784
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,3 @@
+"component:LLM":
+  - changed-files:
+      - any-glob-to-any-file: app/llm/**
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 17c81c08..af785312 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
           cache: 'pip'
 
       - name: Install Dependencies from requirements.txt
diff --git a/.github/workflows/pullrequest-labeler.yml b/.github/workflows/pullrequest-labeler.yml
new file mode 100644
index 00000000..f7739956
--- /dev/null
+++ b/.github/workflows/pullrequest-labeler.yml
@@ -0,0 +1,10 @@
+name: Pull Request Labeler
+on: pull_request_target
+
+jobs:
+  label:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/labeler@v5
+        with:
+          repo-token: "${{ secrets.GITHUB_TOKEN }}"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 42fd914f..b47d541b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,8 +5,9 @@
     rev: stable
     hooks:
       - id: black
-        language_version: python3.11
+        language_version: python3.12
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v2.0.0
   hooks:
-   - id: flake8
\ No newline at end of file
+   - id: flake8
+     language_version: python3.12
\ No newline at end of file
diff --git a/app/common/__init__.py b/app/common/__init__.py
new file mode 100644
index 00000000..97e30c68
--- /dev/null
+++ b/app/common/__init__.py
@@ -0,0 +1 @@
+from common.singleton import Singleton
diff --git a/app/common/singleton.py b/app/common/singleton.py
new file mode 100644
index 00000000..3776cb92
--- /dev/null
+++ b/app/common/singleton.py
@@ -0,0 +1,7 @@
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
new file mode 100644
index 00000000..b73080e7
--- /dev/null
+++ b/app/domain/__init__.py
@@ -0,0 +1 @@
+from domain.message import IrisMessage, IrisMessageRole
diff --git a/app/domain/message.py b/app/domain/message.py
new file mode 100644
index 00000000..b1f521cc
--- /dev/null
+++ b/app/domain/message.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+
+class IrisMessageRole(Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class IrisMessage:
+    role: IrisMessageRole
+    text: str
+
+    def __init__(self, role: IrisMessageRole, text: str):
+        self.role = role
+        self.text = text
+
+    def __str__(self):
+        return f"IrisMessage(role={self.role.value}, text='{self.text}')"
diff --git a/app/llm/__init__.py b/app/llm/__init__.py
new file mode 100644
index 00000000..33542f1c
--- /dev/null
+++ b/app/llm/__init__.py
@@ -0,0 +1,3 @@
+from llm.request_handler_interface import RequestHandlerInterface
+from llm.generation_arguments import *
+from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
new file mode 100644
index 00000000..f348da1f
--- /dev/null
+++ b/app/llm/basic_request_handler.py
@@ -0,0 +1,48 @@
+from domain import IrisMessage
+from llm import RequestHandlerInterface, CompletionArguments
+from llm.llm_manager import LlmManager
+from llm.wrapper import (
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
+)
+
+type BasicRequestHandlerModel = str
+
+
+class BasicRequestHandler(RequestHandlerInterface):
+    model: BasicRequestHandlerModel
+    llm_manager: LlmManager
+
+    def __init__(self, model: BasicRequestHandlerModel):
+        self.model = model
+        self.llm_manager = LlmManager()
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        llm = self.llm_manager.get_llm_by_id(self.model)
+        if isinstance(llm, AbstractLlmCompletionWrapper):
+            return llm.completion(prompt, arguments)
+        else:
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support completion"
+            )
+
+    def chat_completion(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
+        llm = self.llm_manager.get_llm_by_id(self.model)
+        if isinstance(llm, AbstractLlmChatCompletionWrapper):
+            return llm.chat_completion(messages, arguments)
+        else:
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support chat completion"
+            )
+
+    def create_embedding(self, text: str) -> list[float]:
+        llm = self.llm_manager.get_llm_by_id(self.model)
+        if isinstance(llm, AbstractLlmEmbeddingWrapper):
+            return llm.create_embedding(text)
+        else:
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support embedding"
+            )
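For orientation, a minimal usage sketch of the request handler introduced above. This is illustrative only and not part of the patch; the config path and the `my-llm` id are placeholders that would have to match an entry in the file referenced by `LLM_CONFIG_PATH` (see `llm_manager.py` below), and the snippet assumes it is run from the `app` directory.

```python
# Illustrative usage sketch (not part of the patch). Assumes LLM_CONFIG_PATH
# points at a config file defining an LLM with the placeholder id "my-llm".
import os

os.environ["LLM_CONFIG_PATH"] = "llm_config.yml"  # placeholder path

from domain import IrisMessage, IrisMessageRole
from llm import BasicRequestHandler, CompletionArguments

handler = BasicRequestHandler(model="my-llm")
answer = handler.chat_completion(
    messages=[IrisMessage(role=IrisMessageRole.USER, text="What is an AVL tree?")],
    arguments=CompletionArguments(temperature=0.2, max_tokens=256),
)
print(answer.text)
```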
diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py
new file mode 100644
index 00000000..a540e144
--- /dev/null
+++ b/app/llm/generation_arguments.py
@@ -0,0 +1,9 @@
+class CompletionArguments:
+    """Arguments for the completion request"""
+
+    def __init__(
+        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
+    ):
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.stop = stop
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
new file mode 100644
index 00000000..49a56f30
--- /dev/null
+++ b/app/llm/llm_manager.py
@@ -0,0 +1,102 @@
+import os
+
+import yaml
+
+from common import Singleton
+from llm.wrapper import AbstractLlmWrapper
+
+
+# TODO: Replace with pydantic in a future PR
+def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
+    if config["type"] == "openai":
+        from llm.wrapper import OpenAICompletionWrapper
+
+        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure":
+        from llm.wrapper import AzureCompletionWrapper
+
+        return AzureCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_chat":
+        from llm.wrapper import OpenAIChatCompletionWrapper
+
+        return OpenAIChatCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "azure_chat":
+        from llm.wrapper import AzureChatCompletionWrapper
+
+        return AzureChatCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_embedding":
+        from llm.wrapper import OpenAIEmbeddingWrapper
+
+        return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure_embedding":
+        from llm.wrapper import AzureEmbeddingWrapper
+
+        return AzureEmbeddingWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "ollama":
+        from llm.wrapper import OllamaWrapper
+
+        return OllamaWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            host=config["host"],
+        )
+    else:
+        raise Exception(f"Unknown LLM type: {config['type']}")
+
+
+class LlmManager(metaclass=Singleton):
+    entries: list[AbstractLlmWrapper]
+
+    def __init__(self):
+        self.entries = []
+        self.load_llms()
+
+    def get_llm_by_id(self, llm_id):
+        for llm in self.entries:
+            if llm.id == llm_id:
+                return llm
+
+    def load_llms(self):
+        path = os.environ.get("LLM_CONFIG_PATH")
+        if not path:
+            raise Exception("LLM_CONFIG_PATH not set")
+
+        with open(path, "r") as file:
+            loaded_llms = yaml.safe_load(file)
+
+        self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
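To make the expected config shape concrete, here is an illustrative entry in the form `create_llm_wrapper` consumes. Every value below is a placeholder, not a working credential or endpoint.

```python
# Illustrative only: one config entry shaped the way create_llm_wrapper
# expects it. All values are placeholders.
import yaml

example_config = yaml.safe_load(
    """
- type: azure_chat
  id: azure-gpt35-chat
  name: GPT 3.5 Chat
  description: Azure-hosted chat deployment
  model: gpt-35-turbo
  endpoint: https://example.openai.azure.com/
  azure_deployment: example-deployment
  api_version: 2023-07-01-preview
  api_key: <secret>
"""
)
# LlmManager.load_llms() feeds each such entry to create_llm_wrapper.
```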
diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py
new file mode 100644
index 00000000..5c15df30
--- /dev/null
+++ b/app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
+from abc import ABCMeta, abstractmethod
+
+from domain import IrisMessage
+from llm.generation_arguments import CompletionArguments
+
+
+class RequestHandlerInterface(metaclass=ABCMeta):
+    """Interface for the request handlers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return (
+            hasattr(subclass, "completion")
+            and callable(subclass.completion)
+            and hasattr(subclass, "chat_completion")
+            and callable(subclass.chat_completion)
+            and hasattr(subclass, "create_embedding")
+            and callable(subclass.create_embedding)
+        )
+
+    @abstractmethod
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        """Create a completion from the prompt"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def chat_completion(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
+        """Create a completion from the chat messages"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_embedding(self, text: str) -> list[float]:
+        """Create an embedding from the text"""
+        raise NotImplementedError
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
new file mode 100644
index 00000000..7e0dabff
--- /dev/null
+++ b/app/llm/wrapper/__init__.py
@@ -0,0 +1,5 @@
+from llm.wrapper.abstract_llm_wrapper import *
+from llm.wrapper.open_ai_completion_wrapper import *
+from llm.wrapper.open_ai_chat_wrapper import *
+from llm.wrapper.open_ai_embedding_wrapper import *
+from llm.wrapper.ollama_wrapper import OllamaWrapper
diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py
new file mode 100644
index 00000000..6d5e353e
--- /dev/null
+++ b/app/llm/wrapper/abstract_llm_wrapper.py
@@ -0,0 +1,62 @@
+from abc import ABCMeta, abstractmethod
+
+from domain import IrisMessage
+from llm import CompletionArguments
+
+
+class AbstractLlmWrapper(metaclass=ABCMeta):
+    """Abstract class for the llm wrappers"""
+
+    id: str
+    name: str
+    description: str
+
+    def __init__(self, id: str, name: str, description: str):
+        self.id = id
+        self.name = name
+        self.description = description
+
+
+class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return hasattr(subclass, "completion") and callable(subclass.completion)
+
+    @abstractmethod
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        """Create a completion from the prompt"""
+        raise NotImplementedError
+
+
+class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm chat completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return hasattr(subclass, "chat_completion") and callable(
+            subclass.chat_completion
+        )
+
+    @abstractmethod
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> IrisMessage:
+        """Create a completion from the chat messages"""
+        raise NotImplementedError
+
+
+class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm embedding wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return hasattr(subclass, "create_embedding") and callable(
+            subclass.create_embedding
+        )
+
+    @abstractmethod
+    def create_embedding(self, text: str) -> list[float]:
+        """Create an embedding from the text"""
+        raise NotImplementedError
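As an illustration of the contract these base classes define (not part of the patch), a toy wrapper only needs to subclass one of them and implement its abstract method:

```python
# Illustrative only: a toy completion wrapper that echoes the prompt instead
# of calling a real model, showing the AbstractLlmCompletionWrapper contract.
from llm import CompletionArguments
from llm.wrapper import AbstractLlmCompletionWrapper


class EchoCompletionWrapper(AbstractLlmCompletionWrapper):
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        return prompt


echo = EchoCompletionWrapper(id="echo", name="Echo", description="Echoes prompts")
assert echo.completion("hello", CompletionArguments()) == "hello"
```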
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
new file mode 100644
index 00000000..9ce8e94b
--- /dev/null
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -0,0 +1,50 @@
+from ollama import Client, Message
+
+from domain import IrisMessage, IrisMessageRole
+from llm import CompletionArguments
+from llm.wrapper import (
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
+)
+
+
+def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
+    return [
+        Message(role=message.role.value, content=message.text) for message in messages
+    ]
+
+
+def convert_to_iris_message(message: Message) -> IrisMessage:
+    return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])
+
+
+class OllamaWrapper(
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
+):
+
+    def __init__(self, model: str, host: str, **kwargs):
+        super().__init__(**kwargs)
+        self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        response = self.client.generate(model=self.model, prompt=prompt)
+        return response["response"]
+
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> any:
+        response = self.client.chat(
+            model=self.model, messages=convert_to_ollama_messages(messages)
+        )
+        return convert_to_iris_message(response["message"])
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings(model=self.model, prompt=text)
+        return list(response["embedding"])
+
+    def __str__(self):
+        return f"Ollama('{self.model}')"
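A direct exercise of the Ollama wrapper might look like this. It is illustrative only; it assumes a local Ollama server on its default port and an already pulled model, both of which are placeholders.

```python
# Illustrative only: talking to a local Ollama server through the wrapper.
from llm import CompletionArguments
from llm.wrapper import OllamaWrapper

llm = OllamaWrapper(
    id="local-llama",
    name="Local Llama",
    description="Llama 2 served by a local Ollama instance",
    model="llama2",
    host="http://localhost:11434",  # Ollama's default port
)
print(llm.completion("Say hello.", CompletionArguments()))
```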
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
new file mode 100644
index 00000000..c6b68e25
--- /dev/null
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -0,0 +1,77 @@
+from openai.lib.azure import AzureOpenAI
+from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
+
+from domain import IrisMessage, IrisMessageRole
+from llm import CompletionArguments
+from llm.wrapper import AbstractLlmChatCompletionWrapper
+
+
+def convert_to_open_ai_messages(
+    messages: list[IrisMessage],
+) -> list[ChatCompletionMessageParam]:
+    return [
+        {"role": message.role.value, "content": message.text} for message in messages
+    ]
+
+
+def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
+    # Get IrisMessageRole from the string message.role
+    message_role = IrisMessageRole(message.role)
+    return IrisMessage(role=message_role, text=message.content)
+
+
+class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
+
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
+        self.client = client
+        self.model = model
+
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> IrisMessage:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=convert_to_open_ai_messages(messages),
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
+        )
+        return convert_to_iris_message(response.choices[0].message)
+
+
+class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+
+    def __init__(self, model: str, api_key: str, **kwargs):
+        from openai import OpenAI
+
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"OpenAIChat('{self.model}')"
+
+
+class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+        **kwargs,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"AzureChat('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
new file mode 100644
index 00000000..daac194a
--- /dev/null
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -0,0 +1,58 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm import CompletionArguments
+from llm.wrapper import AbstractLlmCompletionWrapper
+
+
+class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
+
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
+        self.client = client
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> any:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
+        )
+        return response
+
+
+class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+
+    def __init__(self, model: str, api_key: str, **kwargs):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"OpenAICompletion('{self.model}')"
+
+
+class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+        **kwargs,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"AzureCompletion('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
new file mode 100644
index 00000000..88b425bd
--- /dev/null
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -0,0 +1,57 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm.wrapper import (
+    AbstractLlmEmbeddingWrapper,
+)
+
+
+class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):
+
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
+        self.client = client
+        self.model = model
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings.create(
+            model=self.model,
+            input=text,
+            encoding_format="float",
+        )
+        return response.data[0].embedding
+
+
+class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+
+    def __init__(self, model: str, api_key: str, **kwargs):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"OpenAIEmbedding('{self.model}')"
+
+
+class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+        **kwargs,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model, **kwargs)
+
+    def __str__(self):
+        return f"AzureEmbedding('{self.model}')"
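Rounding out patch 1, embeddings flow through the same request handler. An illustrative call, where `ada-embedding` is a placeholder id for an `openai_embedding` or `azure_embedding` config entry:

```python
# Illustrative only: requesting an embedding through the request handler.
from llm import BasicRequestHandler

handler = BasicRequestHandler(model="ada-embedding")  # placeholder id
vector = handler.create_embedding("self-balancing binary search tree")
print(len(vector))  # dimensionality of the returned embedding
```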
"ollama": - from llm.wrapper import OllamaWrapper - return OllamaWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - host=config["host"], - ) - else: - raise Exception(f"Unknown LLM type: {config['type']}") +# Small workaround to get pydantic discriminators working +class LlmList(BaseModel): + llms: list[LlmWrapper] = Field(discriminator="type") class LlmManager(metaclass=Singleton): @@ -99,4 +33,4 @@ def load_llms(self): with open(path, "r") as file: loaded_llms = yaml.safe_load(file) - self.entries = [create_llm_wrapper(llm) for llm in loaded_llms] + self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index 7e0dabff..c4807ec5 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,5 +1,24 @@ -from llm.wrapper.abstract_llm_wrapper import * -from llm.wrapper.open_ai_completion_wrapper import * -from llm.wrapper.open_ai_chat_wrapper import * -from llm.wrapper.open_ai_embedding_wrapper import * +from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper +from llm.wrapper.open_ai_completion_wrapper import ( + OpenAICompletionWrapper, + AzureCompletionWrapper, +) +from llm.wrapper.open_ai_chat_wrapper import ( + OpenAIChatCompletionWrapper, + AzureChatCompletionWrapper, +) +from llm.wrapper.open_ai_embedding_wrapper import ( + OpenAIEmbeddingWrapper, + AzureEmbeddingWrapper, +) from llm.wrapper.ollama_wrapper import OllamaWrapper + +type LlmWrapper = ( + OpenAICompletionWrapper + | AzureCompletionWrapper + | OpenAIChatCompletionWrapper + | AzureChatCompletionWrapper + | OpenAIEmbeddingWrapper + | AzureEmbeddingWrapper + | OllamaWrapper +) diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py index 6d5e353e..057b3aca 100644 --- a/app/llm/wrapper/abstract_llm_wrapper.py +++ b/app/llm/wrapper/abstract_llm_wrapper.py @@ -1,21 +1,17 @@ from abc import ABCMeta, abstractmethod +from pydantic import BaseModel from domain import IrisMessage from llm import CompletionArguments -class AbstractLlmWrapper(metaclass=ABCMeta): +class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta): """Abstract class for the llm wrappers""" id: str name: str description: str - def __init__(self, id: str, name: str, description: str): - self.id = id - self.name = name - self.description = description - class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): """Abstract class for the llm completion wrappers""" diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 9ce8e94b..4ea0e9b0 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -1,8 +1,10 @@ +from typing import Literal, Any + from ollama import Client, Message from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper import ( +from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmChatCompletionWrapper, AbstractLlmCompletionWrapper, AbstractLlmEmbeddingWrapper, @@ -24,26 +26,28 @@ class OllamaWrapper( AbstractLlmChatCompletionWrapper, AbstractLlmEmbeddingWrapper, ): + type: Literal["ollama"] + model: str + host: str + _client: Client - def __init__(self, model: str, host: str, **kwargs): - super().__init__(**kwargs) - self.client = Client(host=host) # TODO: Add authentication (httpx auth?) 
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
index 7e0dabff..c4807ec5 100644
--- a/app/llm/wrapper/__init__.py
+++ b/app/llm/wrapper/__init__.py
@@ -1,5 +1,24 @@
-from llm.wrapper.abstract_llm_wrapper import *
-from llm.wrapper.open_ai_completion_wrapper import *
-from llm.wrapper.open_ai_chat_wrapper import *
-from llm.wrapper.open_ai_embedding_wrapper import *
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
+from llm.wrapper.open_ai_completion_wrapper import (
+    OpenAICompletionWrapper,
+    AzureCompletionWrapper,
+)
+from llm.wrapper.open_ai_chat_wrapper import (
+    OpenAIChatCompletionWrapper,
+    AzureChatCompletionWrapper,
+)
+from llm.wrapper.open_ai_embedding_wrapper import (
+    OpenAIEmbeddingWrapper,
+    AzureEmbeddingWrapper,
+)
 from llm.wrapper.ollama_wrapper import OllamaWrapper
+
+type LlmWrapper = (
+    OpenAICompletionWrapper
+    | AzureCompletionWrapper
+    | OpenAIChatCompletionWrapper
+    | AzureChatCompletionWrapper
+    | OpenAIEmbeddingWrapper
+    | AzureEmbeddingWrapper
+    | OllamaWrapper
+)
diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py
index 6d5e353e..057b3aca 100644
--- a/app/llm/wrapper/abstract_llm_wrapper.py
+++ b/app/llm/wrapper/abstract_llm_wrapper.py
@@ -1,21 +1,17 @@
 from abc import ABCMeta, abstractmethod
+from pydantic import BaseModel
 
 from domain import IrisMessage
 from llm import CompletionArguments
 
 
-class AbstractLlmWrapper(metaclass=ABCMeta):
+class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta):
     """Abstract class for the llm wrappers"""
 
     id: str
     name: str
     description: str
 
-    def __init__(self, id: str, name: str, description: str):
-        self.id = id
-        self.name = name
-        self.description = description
-
 
 class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
     """Abstract class for the llm completion wrappers"""
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 9ce8e94b..4ea0e9b0 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -1,8 +1,10 @@
+from typing import Literal, Any
+
 from ollama import Client, Message
 
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper import (
+from llm.wrapper.abstract_llm_wrapper import (
     AbstractLlmChatCompletionWrapper,
     AbstractLlmCompletionWrapper,
     AbstractLlmEmbeddingWrapper,
@@ -24,26 +26,28 @@ class OllamaWrapper(
     AbstractLlmChatCompletionWrapper,
     AbstractLlmEmbeddingWrapper,
 ):
+    type: Literal["ollama"]
+    model: str
+    host: str
+    _client: Client
 
-    def __init__(self, model: str, host: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
-        self.model = model
+    def model_post_init(self, __context: Any) -> None:
+        self._client = Client(host=self.host)  # TODO: Add authentication (httpx auth?)
 
     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        response = self.client.generate(model=self.model, prompt=prompt)
+        response = self._client.generate(model=self.model, prompt=prompt)
         return response["response"]
 
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> any:
-        response = self.client.chat(
+        response = self._client.chat(
             model=self.model, messages=convert_to_ollama_messages(messages)
         )
         return convert_to_iris_message(response["message"])
 
     def create_embedding(self, text: str) -> list[float]:
-        response = self.client.embeddings(model=self.model, prompt=text)
+        response = self._client.embeddings(model=self.model, prompt=text)
         return list(response["embedding"])
 
     def __str__(self):
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index c6b68e25..6a605ad5 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,9 +1,11 @@
+from typing import Literal, Any
+from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
 
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper import AbstractLlmChatCompletionWrapper
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper
 
 
 def convert_to_open_ai_messages(
@@ -21,16 +23,14 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
 
 
 class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
-
-    def __init__(self, client, model: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = client
-        self.model = model
+    model: str
+    api_key: str
+    _client: OpenAI
 
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> IrisMessage:
-        response = self.client.chat.completions.create(
+        response = self._client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
             temperature=arguments.temperature,
@@ -41,37 +41,28 @@ def chat_completion(
 
 
 class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+    type: Literal["openai_chat"]
 
-    def __init__(self, model: str, api_key: str, **kwargs):
-        from openai import OpenAI
-
-        client = OpenAI(api_key=api_key)
-        model = model
-        super().__init__(client, model, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)
 
     def __str__(self):
        return f"OpenAIChat('{self.model}')"
 
 
 class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
-
-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-        **kwargs,
-    ):
-        client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
+    type: Literal["azure_chat"]
+    endpoint: str
+    azure_deployment: str
+    api_version: str
+
+    def model_post_init(self, __context: Any) -> None:
+        self._client = AzureOpenAI(
+            azure_endpoint=self.endpoint,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            api_key=self.api_key,
         )
-        model = model
-        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"AzureChat('{self.model}')"
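The pattern used in every rewritten wrapper above: pydantic validates the declared config fields, and `model_post_init` builds the underscore-prefixed private client afterwards. A stripped-down sketch of just that pattern follows; the class and values are hypothetical.

```python
# Illustrative only: the pydantic field + private client pattern from patch 2.
from typing import Any, Literal

from pydantic import BaseModel


class ExampleWrapper(BaseModel):
    type: Literal["example"]
    model: str
    api_key: str
    _client: str  # private attribute, excluded from validation and serialization

    def model_post_init(self, __context: Any) -> None:
        # Stands in for building a real client, e.g. OpenAI(api_key=self.api_key).
        self._client = f"client-for-{self.model}"


wrapper = ExampleWrapper(type="example", model="some-model", api_key="<secret>")
print(wrapper._client)  # client-for-some-model
```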
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index daac194a..22fe4ed2 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -1,19 +1,18 @@
+from typing import Literal, Any
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 
 from llm import CompletionArguments
-from llm.wrapper import AbstractLlmCompletionWrapper
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper
 
 
 class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
-
-    def __init__(self, client, model: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = client
-        self.model = model
+    model: str
+    api_key: str
+    _client: OpenAI
 
     def completion(self, prompt: str, arguments: CompletionArguments) -> any:
-        response = self.client.completions.create(
+        response = self._client.completions.create(
             model=self.model,
             prompt=prompt,
             temperature=arguments.temperature,
@@ -24,35 +23,28 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
 
 
 class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+    type: Literal["openai_completion"]
 
-    def __init__(self, model: str, api_key: str, **kwargs):
-        client = OpenAI(api_key=api_key)
-        model = model
-        super().__init__(client, model, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)
 
     def __str__(self):
         return f"OpenAICompletion('{self.model}')"
 
 
 class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
-
-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-        **kwargs,
-    ):
-        client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
+    type: Literal["azure_completion"]
+    endpoint: str
+    azure_deployment: str
+    api_version: str
+
+    def model_post_init(self, __context: Any) -> None:
+        self._client = AzureOpenAI(
+            azure_endpoint=self.endpoint,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            api_key=self.api_key,
         )
-        model = model
-        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"AzureCompletion('{self.model}')"
f"OpenAIEmbedding('{self.model}')" class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): - - def __init__( - self, - model: str, - endpoint: str, - azure_deployment: str, - api_version: str, - api_key: str, - **kwargs, - ): - client = AzureOpenAI( - azure_endpoint=endpoint, - azure_deployment=azure_deployment, - api_version=api_version, - api_key=api_key, + type: Literal["azure_embedding"] + endpoint: str + azure_deployment: str + api_version: str + + def model_post_init(self, __context: Any) -> None: + self._client = AzureOpenAI( + azure_endpoint=self.endpoint, + azure_deployment=self.azure_deployment, + api_version=self.api_version, + api_key=self.api_key, ) - model = model - super().__init__(client, model, **kwargs) def __str__(self): return f"AzureEmbedding('{self.model}')" diff --git a/requirements.txt b/requirements.txt index 0b8fabbb..71c9e37e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ uvicorn==0.23.1 black==24.1.1 flake8==7.0.0 pre-commit==3.6.0 +pydantic==2.6.1 From f967d82cce15e5ba78919c51220eeec70facee51 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:48:42 +0100 Subject: [PATCH 3/5] Bump uvicorn from 0.23.1 to 0.27.1 (#56) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 71c9e37e..251c821d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ openai==1.11.1 ollama==0.1.6 fastapi==0.109.2 -uvicorn==0.23.1 +uvicorn==0.27.1 black==24.1.1 flake8==7.0.0 pre-commit==3.6.0 From 8a223702f9b2a4e022277ee728be52981377651c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:49:09 +0100 Subject: [PATCH 4/5] Bump pre-commit from 3.6.0 to 3.6.1 (#55) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 251c821d..f35cd167 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ fastapi==0.109.2 uvicorn==0.27.1 black==24.1.1 flake8==7.0.0 -pre-commit==3.6.0 +pre-commit==3.6.1 pydantic==2.6.1 From b6d5ff7190bf0398bec490b27d1e67208a4ad22e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:49:26 +0100 Subject: [PATCH 5/5] Bump openai from 1.11.1 to 1.12.0 (#54) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f35cd167..3b4afc16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -openai==1.11.1 +openai==1.12.0 ollama==0.1.6 fastapi==0.109.2 uvicorn==0.27.1