From 429a4a005e6cdc4ed0d9197c25017a206632df9e Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 13:56:22 +0100
Subject: [PATCH 01/12] First draft of LLM subsystem

---
 .github/workflows/lint.yml               |  2 +-
 app/common/__init__.py                   |  1 +
 app/common/singleton.py                  |  7 ++++
 app/domain/__init__.py                   |  1 +
 app/domain/message.py                    |  4 ++
 app/llm/__init__.py                      |  4 ++
 app/llm/basic_request_handler.py         | 36 ++++++++++++++++++
 app/llm/generation_arguments.py          |  7 ++++
 app/llm/llm_manager.py                   | 26 +++++++++++++
 app/llm/request_handler_interface.py     | 36 ++++++++++++++++++
 app/llm/wrapper/__init__.py              |  3 ++
 app/llm/wrapper/azure_chat_wrapper.py    | 32 ++++++++++++++++
 app/llm/wrapper/llm_wrapper_interface.py | 47 +++++++++++++++++++++++
 app/llm/wrapper/ollama_wrapper.py        | 48 ++++++++++++++++++++++++
 app/llm/wrapper/open_ai_chat_wrapper.py  | 35 +++++++++++++++++
 15 files changed, 288 insertions(+), 1 deletion(-)
 create mode 100644 app/common/__init__.py
 create mode 100644 app/common/singleton.py
 create mode 100644 app/domain/__init__.py
 create mode 100644 app/domain/message.py
 create mode 100644 app/llm/__init__.py
 create mode 100644 app/llm/basic_request_handler.py
 create mode 100644 app/llm/generation_arguments.py
 create mode 100644 app/llm/llm_manager.py
 create mode 100644 app/llm/request_handler_interface.py
 create mode 100644 app/llm/wrapper/__init__.py
 create mode 100644 app/llm/wrapper/azure_chat_wrapper.py
 create mode 100644 app/llm/wrapper/llm_wrapper_interface.py
 create mode 100644 app/llm/wrapper/ollama_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_chat_wrapper.py

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 17c81c08..af785312 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
           cache: 'pip'

       - name: Install Dependencies from requirements.txt
diff --git a/app/common/__init__.py b/app/common/__init__.py
new file mode 100644
index 00000000..e190c8ba
--- /dev/null
+++ b/app/common/__init__.py
@@ -0,0 +1 @@
+from singleton import Singleton
diff --git a/app/common/singleton.py b/app/common/singleton.py
new file mode 100644
index 00000000..3776cb92
--- /dev/null
+++ b/app/common/singleton.py
@@ -0,0 +1,7 @@
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
new file mode 100644
index 00000000..5ead9e65
--- /dev/null
+++ b/app/domain/__init__.py
@@ -0,0 +1 @@
+from message import IrisMessage
diff --git a/app/domain/message.py b/app/domain/message.py
new file mode 100644
index 00000000..b5fe26f7
--- /dev/null
+++ b/app/domain/message.py
@@ -0,0 +1,4 @@
+class IrisMessage:
+    def __init__(self, role, message_text):
+        self.role = role
+        self.message_text = message_text
diff --git a/app/llm/__init__.py b/app/llm/__init__.py
new file mode 100644
index 00000000..73b39dc6
--- /dev/null
+++ b/app/llm/__init__.py
@@ -0,0 +1,4 @@
+from generation_arguments import CompletionArguments
+from request_handler_interface import RequestHandlerInterface
+from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
+from llm_manager import LlmManager
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
new file mode 100644
index 00000000..0c284002
--- /dev/null
+++ b/app/llm/basic_request_handler.py
@@ -0,0 +1,36 @@
+from domain import IrisMessage
+from llm import LlmManager
+from llm import RequestHandlerInterface, CompletionArguments
+from llm.wrapper import LlmCompletionWrapperInterface, LlmChatCompletionWrapperInterface, LlmEmbeddingWrapperInterface
+
+type BasicRequestHandlerModel = str
+
+
+class BasicRequestHandler(RequestHandlerInterface):
+    model: BasicRequestHandlerModel
+    llm_manager: LlmManager
+
+    def __init__(self, model: BasicRequestHandlerModel):
+        self.model = model
+        self.llm_manager = LlmManager()
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        if isinstance(llm, LlmCompletionWrapperInterface):
+            return llm.completion(prompt, arguments)
+        else:
+            raise NotImplementedError
+
+    def chat_completion(self, messages: list[IrisMessage], arguments: CompletionArguments) -> IrisMessage:
+        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        if isinstance(llm, LlmChatCompletionWrapperInterface):
+            return llm.chat_completion(messages, arguments)
+        else:
+            raise NotImplementedError
+
+    def create_embedding(self, text: str) -> list[float]:
+        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        if isinstance(llm, LlmEmbeddingWrapperInterface):
+            return llm.create_embedding(text)
+        else:
+            raise NotImplementedError
diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py
new file mode 100644
index 00000000..37a4af19
--- /dev/null
+++ b/app/llm/generation_arguments.py
@@ -0,0 +1,7 @@
+class CompletionArguments:
+    """Arguments for the completion request"""
+
+    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.stop = stop
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
new file mode 100644
index 00000000..8dd649d8
--- /dev/null
+++ b/app/llm/llm_manager.py
@@ -0,0 +1,26 @@
+from common import Singleton
+from llm.wrapper import LlmWrapperInterface
+
+
+class LlmManagerEntry:
+    id: str
+    llm: LlmWrapperInterface
+
+    def __init__(self, id: str, llm: LlmWrapperInterface):
+        self.id = id
+        self.llm = llm
+
+
+class LlmManager(metaclass=Singleton):
+    llms: list[LlmManagerEntry]
+
+    def __init__(self):
+        self.llms = []
+
+    def get_llm_by_id(self, llm_id):
+        for llm in self.llms:
+            if llm.id == llm_id:
+                return llm
+
+    def add_llm(self, id: str, llm: LlmWrapperInterface):
+        self.llms.append(LlmManagerEntry(id, llm))
diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py
new file mode 100644
index 00000000..5c15df30
--- /dev/null
+++ b/app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
+from abc import ABCMeta, abstractmethod
+
+from domain import IrisMessage
+from llm.generation_arguments import CompletionArguments
+
+
+class RequestHandlerInterface(metaclass=ABCMeta):
+    """Interface for the request handlers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass):
+        return (
+            hasattr(subclass, "completion")
+            and callable(subclass.completion)
+            and hasattr(subclass, "chat_completion")
+            and callable(subclass.chat_completion)
+            and hasattr(subclass, "create_embedding")
+            and callable(subclass.create_embedding)
+        )
+
+    @abstractmethod
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        """Create a completion from the prompt"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> [IrisMessage]:
"""Create a completion from the chat messages""" + raise NotImplementedError + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py new file mode 100644 index 00000000..6aef2481 --- /dev/null +++ b/app/llm/wrapper/__init__.py @@ -0,0 +1,3 @@ +from llm_wrapper_interface import * +from open_ai_chat_wrapper import * +from ollama_wrapper import OllamaWrapper diff --git a/app/llm/wrapper/azure_chat_wrapper.py b/app/llm/wrapper/azure_chat_wrapper.py new file mode 100644 index 00000000..9022e3ef --- /dev/null +++ b/app/llm/wrapper/azure_chat_wrapper.py @@ -0,0 +1,32 @@ +from openai.lib.azure import AzureOpenAI + +from llm import CompletionArguments +from llm.wrapper import LlmChatCompletionWrapperInterface, convert_to_open_ai_messages + + +class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + ): + self.client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + self.model = model + + def chat_completion( + self, messages: list[any], arguments: CompletionArguments + ) -> any: + response = self.client.chat.completions.create( + model=self.model, + messages=convert_to_open_ai_messages(messages), + ) + return response diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/llm_wrapper_interface.py new file mode 100644 index 00000000..e94f3e3b --- /dev/null +++ b/app/llm/wrapper/llm_wrapper_interface.py @@ -0,0 +1,47 @@ +from abc import ABCMeta, abstractmethod + +from llm import CompletionArguments + +type LlmWrapperInterface = LlmCompletionWrapperInterface | LlmChatCompletionWrapperInterface | LlmEmbeddingWrapperInterface + + +class LlmCompletionWrapperInterface(metaclass=ABCMeta): + """Interface for the llm completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'completion') and + callable(subclass.completion)) + + @abstractmethod + def completion(self, prompt: str, arguments: CompletionArguments) -> str: + """Create a completion from the prompt""" + raise NotImplementedError + + +class LlmChatCompletionWrapperInterface(metaclass=ABCMeta): + """Interface for the llm chat completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'chat_completion') and + callable(subclass.chat_completion)) + + @abstractmethod + def chat_completion(self, messages: list[any], arguments: CompletionArguments) -> any: + """Create a completion from the chat messages""" + raise NotImplementedError + + +class LlmEmbeddingWrapperInterface(metaclass=ABCMeta): + """Interface for the llm embedding wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return (hasattr(subclass, 'create_embedding') and + callable(subclass.create_embedding)) + + @abstractmethod + def create_embedding(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py new file mode 100644 index 00000000..d4ca69bb --- /dev/null +++ b/app/llm/wrapper/ollama_wrapper.py @@ -0,0 +1,48 @@ +from ollama import Client, Message +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from domain import IrisMessage 
+from llm import CompletionArguments
+from llm.wrapper import (
+    LlmChatCompletionWrapperInterface,
+    LlmCompletionWrapperInterface,
+    LlmEmbeddingWrapperInterface,
+)
+
+
+def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
+    return [
+        Message(role=message.role, content=message.message_text) for message in messages
+    ]
+
+
+def convert_to_iris_message(message: Message) -> IrisMessage:
+    return IrisMessage(role=message.role, message_text=message.content)
+
+
+class OllamaWrapper(
+    LlmCompletionWrapperInterface,
+    LlmChatCompletionWrapperInterface,
+    LlmEmbeddingWrapperInterface,
+):
+
+    def __init__(self, model: str, host: str):
+        self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+        response = self.client.generate(model=self.model, prompt=prompt)
+        return response["response"]
+
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> any:
+        response = self.client.chat(
+            model=self.model, messages=convert_to_ollama_messages(messages)
+        )
+        return convert_to_iris_message(response["message"])
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings(model=self.model, prompt=text)
+        return response
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
new file mode 100644
index 00000000..db04de18
--- /dev/null
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -0,0 +1,35 @@
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam
+
+from domain import IrisMessage
+from llm import CompletionArguments
+from llm.wrapper import LlmChatCompletionWrapperInterface
+
+
+def convert_to_open_ai_messages(
+    messages: list[IrisMessage],
+) -> list[ChatCompletionMessageParam]:
+    return [
+        ChatCompletionMessageParam(role=message.role, content=message.message_text)
+        for message in messages
+    ]
+
+
+def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
+    return IrisMessage(role=message.role, message_text=message.content)
+
+
+class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> any:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=convert_to_open_ai_messages(messages),
+        )
+        return response
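[PATCH 01 builds everything on the Singleton metaclass, so LlmManager() always returns one shared instance. A minimal illustration of that behavior — the Config class below is invented for the example and is not part of the patch:

    # Same metaclass as in app/common/singleton.py above.
    class Singleton(type):
        _instances = {}

        def __call__(cls, *args, **kwargs):
            # Construct the instance only on the first call, then reuse it.
            if cls not in cls._instances:
                cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
            return cls._instances[cls]


    class Config(metaclass=Singleton):  # hypothetical consumer, for illustration
        def __init__(self):
            self.values = {}


    a = Config()
    b = Config()
    assert a is b  # both names point to the single shared instance

This is why BasicRequestHandler can freely call LlmManager() in its constructor: every handler ends up talking to the same manager.]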
From 898d1fc9a3c7ee278301bd49ecda1d612d4ebc82 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 15:04:46 +0100
Subject: [PATCH 02/12] Fix style issues

---
 .flake8                                  | 10 ++++++++++
 .pre-commit-config.yaml                  |  5 +++--
 app/llm/basic_request_handler.py         | 10 ++++++++--
 app/llm/wrapper/llm_wrapper_interface.py | 23 +++++++++++++++--------
 app/llm/wrapper/ollama_wrapper.py        |  2 --
 5 files changed, 36 insertions(+), 14 deletions(-)
 create mode 100644 .flake8

diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..18f3c3b8
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,10 @@
+[flake8]
+max-line-length = 120
+exclude =
+    .git,
+    __pycache__,
+    .idea
+per-file-ignores =
+    # imported but unused
+    __init__.py: F401, F403
+
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 42fd914f..b47d541b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,8 +5,9 @@
     rev: stable
     hooks:
       - id: black
-        language_version: python3.11
+        language_version: python3.12
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v2.0.0
    hooks:
-      - id: flake8
\ No newline at end of file
+      - id: flake8
+        language_version: python3.12
\ No newline at end of file
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index 0c284002..a572c151 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -1,7 +1,11 @@
 from domain import IrisMessage
 from llm import LlmManager
 from llm import RequestHandlerInterface, CompletionArguments
-from llm.wrapper import LlmCompletionWrapperInterface, LlmChatCompletionWrapperInterface, LlmEmbeddingWrapperInterface
+from llm.wrapper import (
+    LlmCompletionWrapperInterface,
+    LlmChatCompletionWrapperInterface,
+    LlmEmbeddingWrapperInterface,
+)

 type BasicRequestHandlerModel = str

@@ -21,7 +25,9 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         else:
             raise NotImplementedError

-    def chat_completion(self, messages: list[IrisMessage], arguments: CompletionArguments) -> IrisMessage:
+    def chat_completion(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
         if isinstance(llm, LlmChatCompletionWrapperInterface):
             return llm.chat_completion(messages, arguments)
diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/llm_wrapper_interface.py
index e94f3e3b..73aade5d 100644
--- a/app/llm/wrapper/llm_wrapper_interface.py
+++ b/app/llm/wrapper/llm_wrapper_interface.py
@@ -2,7 +2,11 @@

 from llm import CompletionArguments

-type LlmWrapperInterface = LlmCompletionWrapperInterface | LlmChatCompletionWrapperInterface | LlmEmbeddingWrapperInterface
+type LlmWrapperInterface = (
+    LlmCompletionWrapperInterface
+    | LlmChatCompletionWrapperInterface
+    | LlmEmbeddingWrapperInterface
+)


 class LlmCompletionWrapperInterface(metaclass=ABCMeta):
@@ -10,8 +14,7 @@ class LlmCompletionWrapperInterface(metaclass=ABCMeta):

     @classmethod
     def __subclasshook__(cls, subclass):
-        return (hasattr(subclass, 'completion') and
-                callable(subclass.completion))
+        return hasattr(subclass, "completion") and callable(subclass.completion)

     @abstractmethod
     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
@@ -24,11 +27,14 @@ class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):

     @classmethod
     def __subclasshook__(cls, subclass):
-        return (hasattr(subclass, 'chat_completion') and
-                callable(subclass.chat_completion))
+        return hasattr(subclass, "chat_completion") and callable(
+            subclass.chat_completion
+        )

     @abstractmethod
-    def chat_completion(self, messages: list[any], arguments: CompletionArguments) -> any:
+    def chat_completion(
+        self, messages: list[any], arguments: CompletionArguments
+    ) -> any:
         """Create a completion from the chat messages"""
         raise NotImplementedError

@@ -38,8 +44,9 @@ class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):

     @classmethod
     def __subclasshook__(cls, subclass):
-        return (hasattr(subclass, 'create_embedding') and
-                callable(subclass.create_embedding))
+        return hasattr(subclass, "create_embedding") and callable(
+            subclass.create_embedding
+        )

     @abstractmethod
     def create_embedding(self, text: str) -> list[float]:
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index d4ca69bb..5ca682b8 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -1,6 +1,4 @@
 from ollama import Client, Message
-from openai import OpenAI
-from openai.types.chat import ChatCompletionMessageParam

 from domain import IrisMessage
 from llm import CompletionArguments
From 14ae279976279fd5540323bf9fc7867ae3e4cfaa Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 17:32:49 +0100
Subject: [PATCH 03/12] Small refinements

---
 .flake8                                       |  3 ++
 app/common/__init__.py                        |  2 +-
 app/domain/__init__.py                        |  2 +-
 app/llm/__init__.py                           |  8 +--
 app/llm/basic_request_handler.py              | 12 +++--
 app/llm/wrapper/__init__.py                   |  6 +--
 app/llm/wrapper/azure_chat_wrapper.py         | 32 ------------
 app/llm/wrapper/ollama_wrapper.py             |  7 ++-
 app/llm/wrapper/open_ai_chat_wrapper.py       | 34 ++++++++++++
 app/llm/wrapper/open_ai_completion_wrapper.py | 52 +++++++++++++++++++
 app/llm/wrapper/open_ai_embedding_wrapper.py  | 51 ++++++++++++++++++
 11 files changed, 163 insertions(+), 46 deletions(-)
 delete mode 100644 app/llm/wrapper/azure_chat_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_completion_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_embedding_wrapper.py

diff --git a/.flake8 b/.flake8
index 18f3c3b8..a5c48807 100644
--- a/.flake8
+++ b/.flake8
@@ -7,4 +7,7 @@ exclude =
 per-file-ignores =
     # imported but unused
     __init__.py: F401, F403
+    open_ai_chat_wrapper.py: F811
+    open_ai_completion_wrapper.py: F811
+    open_ai_embedding_wrapper.py: F811

diff --git a/app/common/__init__.py b/app/common/__init__.py
index e190c8ba..97e30c68 100644
--- a/app/common/__init__.py
+++ b/app/common/__init__.py
@@ -1 +1 @@
-from singleton import Singleton
+from common.singleton import Singleton
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
index 5ead9e65..270c228a 100644
--- a/app/domain/__init__.py
+++ b/app/domain/__init__.py
@@ -1 +1 @@
-from message import IrisMessage
+from domain.message import IrisMessage
diff --git a/app/llm/__init__.py b/app/llm/__init__.py
index 73b39dc6..51227f24 100644
--- a/app/llm/__init__.py
+++ b/app/llm/__init__.py
@@ -1,4 +1,4 @@
-from generation_arguments import CompletionArguments
-from request_handler_interface import RequestHandlerInterface
-from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
-from llm_manager import LlmManager
+from llm.generation_arguments import CompletionArguments
+from llm.request_handler_interface import RequestHandlerInterface
+from llm.llm_manager import LlmManager
+from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index a572c151..fbeacb76 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -23,7 +23,9 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         if isinstance(llm, LlmCompletionWrapperInterface):
             return llm.completion(prompt, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support completion"
+            )

     def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
@@ -32,11 +34,15 @@ def chat_completion(
         if isinstance(llm, LlmChatCompletionWrapperInterface):
             return llm.chat_completion(messages, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support chat completion"
+            )

     def create_embedding(self, text: str) -> list[float]:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
         if isinstance(llm, LlmEmbeddingWrapperInterface):
             return llm.create_embedding(text)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
embedding" + ) diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index 6aef2481..6ddf4569 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,3 +1,3 @@ -from llm_wrapper_interface import * -from open_ai_chat_wrapper import * -from ollama_wrapper import OllamaWrapper +from llm.wrapper.llm_wrapper_interface import * +from llm.wrapper.open_ai_chat_wrapper import * +from llm.wrapper.ollama_wrapper import OllamaWrapper diff --git a/app/llm/wrapper/azure_chat_wrapper.py b/app/llm/wrapper/azure_chat_wrapper.py deleted file mode 100644 index 9022e3ef..00000000 --- a/app/llm/wrapper/azure_chat_wrapper.py +++ /dev/null @@ -1,32 +0,0 @@ -from openai.lib.azure import AzureOpenAI - -from llm import CompletionArguments -from llm.wrapper import LlmChatCompletionWrapperInterface, convert_to_open_ai_messages - - -class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface): - - def __init__( - self, - model: str, - endpoint: str, - azure_deployment: str, - api_version: str, - api_key: str, - ): - self.client = AzureOpenAI( - azure_endpoint=endpoint, - azure_deployment=azure_deployment, - api_version=api_version, - api_key=api_key, - ) - self.model = model - - def chat_completion( - self, messages: list[any], arguments: CompletionArguments - ) -> any: - response = self.client.chat.completions.create( - model=self.model, - messages=convert_to_open_ai_messages(messages), - ) - return response diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 5ca682b8..9dc8131a 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -16,7 +16,7 @@ def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: def convert_to_iris_message(message: Message) -> IrisMessage: - return IrisMessage(role=message.role, message_text=message.content) + return IrisMessage(role=message["role"], message_text=message["content"]) class OllamaWrapper( @@ -43,4 +43,7 @@ def chat_completion( def create_embedding(self, text: str) -> list[float]: response = self.client.embeddings(model=self.model, prompt=text) - return response + return list(response) + + def __str__(self): + return f"Ollama('{self.model}')" diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py index db04de18..a21a9951 100644 --- a/app/llm/wrapper/open_ai_chat_wrapper.py +++ b/app/llm/wrapper/open_ai_chat_wrapper.py @@ -1,4 +1,5 @@ from openai import OpenAI +from openai.lib.azure import AzureOpenAI from openai.types.chat import ChatCompletionMessageParam from domain import IrisMessage @@ -25,11 +26,44 @@ def __init__(self, model: str, api_key: str): self.client = OpenAI(api_key=api_key) self.model = model + def __init__(self, client, model: str): + self.client = client + self.model = model + def chat_completion( self, messages: list[any], arguments: CompletionArguments ) -> any: response = self.client.chat.completions.create( model=self.model, messages=convert_to_open_ai_messages(messages), + temperature=arguments.temperature, + max_tokens=arguments.max_tokens, + stop=arguments.stop, ) return response + + def __str__(self): + return f"OpenAIChat('{self.model}')" + + +class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper): + + def __init__( + self, + model: str, + endpoint: str, + azure_deployment: str, + api_version: str, + api_key: str, + ): + client = AzureOpenAI( + azure_endpoint=endpoint, + azure_deployment=azure_deployment, + api_version=api_version, + api_key=api_key, + ) + model = 
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureChat('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
new file mode 100644
index 00000000..1d2eded3
--- /dev/null
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -0,0 +1,52 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm import CompletionArguments
+from llm.wrapper import LlmCompletionWrapperInterface
+
+
+class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> any:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
+        )
+        return response
+
+    def __str__(self):
+        return f"OpenAICompletion('{self.model}')"
+
+
+class AzureCompletionWrapper(OpenAICompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureCompletion('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
new file mode 100644
index 00000000..4983d3ee
--- /dev/null
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -0,0 +1,51 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm.wrapper import (
+    LlmEmbeddingWrapperInterface,
+)
+
+
+class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings.create(
+            model=self.model,
+            input=text,
+            encoding_format="float",
+        )
+        return response.data[0].embedding
+
+    def __str__(self):
+        return f"OpenAIEmbedding('{self.model}')"
+
+
+class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureEmbedding('{self.model}')"
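[The __subclasshook__ definitions in llm_wrapper_interface.py make the wrapper interfaces behave like structural protocols: any object exposing the right method passes the isinstance checks in BasicRequestHandler, with no explicit inheritance required. A small sketch of that mechanism — the EchoLlm class is invented for illustration:

    from abc import ABCMeta, abstractmethod


    class LlmCompletionWrapperInterface(metaclass=ABCMeta):
        """Mirrors the interface from the patches above."""

        @classmethod
        def __subclasshook__(cls, subclass):
            # Any class with a callable `completion` counts as a subclass.
            return hasattr(subclass, "completion") and callable(subclass.completion)

        @abstractmethod
        def completion(self, prompt: str, arguments) -> str:
            raise NotImplementedError


    class EchoLlm:  # note: does not inherit from the interface
        def completion(self, prompt: str, arguments) -> str:
            return prompt


    print(issubclass(EchoLlm, LlmCompletionWrapperInterface))   # True
    print(isinstance(EchoLlm(), LlmCompletionWrapperInterface)) # True

ABCMeta consults __subclasshook__ during issubclass/isinstance, so the hook turns duck typing into something the handler can dispatch on.]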
From 2f60e1ab18cea0a8cbe71a16eb135841866f330a Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 17:37:50 +0100
Subject: [PATCH 04/12] Add PR labeler

---
 .github/labeler.yml                       |  3 +++
 .github/workflows/pullrequest-labeler.yml | 10 ++++++++++
 2 files changed, 13 insertions(+)
 create mode 100644 .github/labeler.yml
 create mode 100644 .github/workflows/pullrequest-labeler.yml

diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 00000000..62d31784
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,3 @@
+"component:LLM":
+  - changed-files:
+      - any-glob-to-any-file: app/llm/**
diff --git a/.github/workflows/pullrequest-labeler.yml b/.github/workflows/pullrequest-labeler.yml
new file mode 100644
index 00000000..90c20ceb
--- /dev/null
+++ b/.github/workflows/pullrequest-labeler.yml
@@ -0,0 +1,10 @@
+name: Pull Request Labeler
+on: [pull_request_target]
+
+jobs:
+  label:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/labeler@v5
+        with:
+          repo-token: "${{ secrets.GITHUB_TOKEN }}"

From 4f5c8be768594b14f9a4c9d5ad6931517a0fbfca Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 19:34:14 +0100
Subject: [PATCH 05/12] Add IrisMessageRole and improve OpenAI wrappers

---
 .github/workflows/pullrequest-labeler.yml     |  2 +-
 app/domain/__init__.py                        |  2 +-
 app/domain/message.py                         | 19 ++++-
 app/llm/generation_arguments.py               |  4 +-
 app/llm/llm_manager.py                        | 81 +++++++++++++++++--
 app/llm/wrapper/__init__.py                   |  2 +
 app/llm/wrapper/llm_wrapper_interface.py      |  3 +-
 app/llm/wrapper/ollama_wrapper.py             |  9 ++-
 app/llm/wrapper/open_ai_chat_wrapper.py       | 31 ++++---
 app/llm/wrapper/open_ai_completion_wrapper.py | 16 ++--
 app/llm/wrapper/open_ai_embedding_wrapper.py  | 16 ++--
 11 files changed, 148 insertions(+), 37 deletions(-)

diff --git a/.github/workflows/pullrequest-labeler.yml b/.github/workflows/pullrequest-labeler.yml
index 90c20ceb..f7739956 100644
--- a/.github/workflows/pullrequest-labeler.yml
+++ b/.github/workflows/pullrequest-labeler.yml
@@ -1,5 +1,5 @@
 name: Pull Request Labeler
-on: [pull_request_target]
+on: pull_request_target

 jobs:
   label:
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
index 270c228a..b73080e7 100644
--- a/app/domain/__init__.py
+++ b/app/domain/__init__.py
@@ -1 +1 @@
-from domain.message import IrisMessage
+from domain.message import IrisMessage, IrisMessageRole
diff --git a/app/domain/message.py b/app/domain/message.py
index b5fe26f7..960750a9 100644
--- a/app/domain/message.py
+++ b/app/domain/message.py
@@ -1,4 +1,21 @@
+from enum import Enum
+
+
+class IrisMessageRole(Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
 class IrisMessage:
-    def __init__(self, role, message_text):
+    role: IrisMessageRole
+    message_text: str
+
+    def __init__(self, role: IrisMessageRole, message_text: str):
         self.role = role
         self.message_text = message_text
+
+    def __str__(self):
+        return (
+            f"IrisMessage(role={self.role.value}, message_text='{self.message_text}')"
+        )
diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py
index 37a4af19..a540e144 100644
--- a/app/llm/generation_arguments.py
+++ b/app/llm/generation_arguments.py
@@ -1,7 +1,9 @@
 class CompletionArguments:
     """Arguments for the completion request"""

-    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
+    def __init__(
+        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
+    ):
         self.max_tokens = max_tokens
         self.temperature = temperature
         self.stop = stop
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index 8dd649d8..bc6de680 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -1,14 +1,77 @@
+import os
+
+import yaml
+
 from common import Singleton
 from llm.wrapper import LlmWrapperInterface


+def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
+    if config["type"] == "openai":
+        from llm.wrapper import OpenAICompletionWrapper
+
+        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure":
+        from llm.wrapper import AzureCompletionWrapper
+
+        return AzureCompletionWrapper(
+            model=config["model"],
endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "openai_chat": + from llm.wrapper import OpenAIChatCompletionWrapper + + return OpenAIChatCompletionWrapper( + model=config["model"], api_key=config["api_key"] + ) + elif config["type"] == "azure_chat": + from llm.wrapper import AzureChatCompletionWrapper + + return AzureChatCompletionWrapper( + model=config["model"], + endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "openai_embedding": + from llm.wrapper import OpenAIEmbeddingWrapper + + return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"]) + elif config["type"] == "azure_embedding": + from llm.wrapper import AzureEmbeddingWrapper + + return AzureEmbeddingWrapper( + model=config["model"], + endpoint=config["endpoint"], + azure_deployment=config["azure_deployment"], + api_version=config["api_version"], + api_key=config["api_key"], + ) + elif config["type"] == "ollama": + from llm.wrapper import OllamaWrapper + + return OllamaWrapper( + model=config["model"], + host=config["host"], + ) + else: + raise Exception(f"Unknown LLM type: {config['type']}") + + class LlmManagerEntry: id: str llm: LlmWrapperInterface - def __init__(self, id: str, llm: LlmWrapperInterface): - self.id = id - self.llm = llm + def __init__(self, config: dict): + self.id = config["id"] + self.llm = create_llm_wrapper(config) + + def __str__(self): + return f"{self.id}: {self.llm}" class LlmManager(metaclass=Singleton): @@ -16,11 +79,19 @@ class LlmManager(metaclass=Singleton): def __init__(self): self.llms = [] + self.load_llms() def get_llm_by_id(self, llm_id): for llm in self.llms: if llm.id == llm_id: return llm - def add_llm(self, id: str, llm: LlmWrapperInterface): - self.llms.append(LlmManagerEntry(id, llm)) + def load_llms(self): + path = os.environ.get("LLM_CONFIG_PATH") + if not path: + raise Exception("LLM_CONFIG_PATH not set") + + with open(path, "r") as file: + loaded_llms = yaml.safe_load(file) + + self.llms = [LlmManagerEntry(llm) for llm in loaded_llms] diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index 6ddf4569..4364afa0 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,3 +1,5 @@ from llm.wrapper.llm_wrapper_interface import * +from llm.wrapper.open_ai_completion_wrapper import * from llm.wrapper.open_ai_chat_wrapper import * +from llm.wrapper.open_ai_embedding_wrapper import * from llm.wrapper.ollama_wrapper import OllamaWrapper diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/llm_wrapper_interface.py index 73aade5d..b1e79acb 100644 --- a/app/llm/wrapper/llm_wrapper_interface.py +++ b/app/llm/wrapper/llm_wrapper_interface.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod +from domain import IrisMessage from llm import CompletionArguments type LlmWrapperInterface = ( @@ -34,7 +35,7 @@ def __subclasshook__(cls, subclass): @abstractmethod def chat_completion( self, messages: list[any], arguments: CompletionArguments - ) -> any: + ) -> IrisMessage: """Create a completion from the chat messages""" raise NotImplementedError diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 9dc8131a..8d0cd397 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -1,6 +1,6 @@ from ollama import Client, 
 from ollama import Client, Message

-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import (
     LlmChatCompletionWrapperInterface,
@@ -11,12 +11,15 @@

 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
     return [
-        Message(role=message.role, content=message.message_text) for message in messages
+        Message(role=message.role.value, content=message.message_text)
+        for message in messages
     ]


 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(role=message["role"], message_text=message["content"])
+    return IrisMessage(
+        role=IrisMessageRole(message["role"]), message_text=message["content"]
+    )
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index a21a9951..be1016a0 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,8 +1,7 @@
-from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam

-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import LlmChatCompletionWrapperInterface

@@ -11,20 +10,18 @@ def convert_to_open_ai_messages(
     messages: list[IrisMessage],
 ) -> list[ChatCompletionMessageParam]:
     return [
-        ChatCompletionMessageParam(role=message.role, content=message.message_text)
+        {"role": message.role.value, "content": message.message_text}
         for message in messages
     ]


 def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
-    return IrisMessage(role=message.role, message_text=message.content)
+    # Get IrisMessageRole from the string message.role
+    message_role = IrisMessageRole(message.role)
+    return IrisMessage(role=message_role, message_text=message.content)


-class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):

     def __init__(self, client, model: str):
         self.client = client
@@ -32,7 +29,7 @@ def __init__(self, client, model: str):

     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    ) -> IrisMessage:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
@@ -40,13 +37,23 @@ def chat_completion(
             max_tokens=arguments.max_tokens,
             stop=arguments.stop,
         )
-        return response
+        return convert_to_iris_message(response.choices[0].message)
+
+
+class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        from openai import OpenAI
+
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)

     def __str__(self):
         return f"OpenAIChat('{self.model}')"


-class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper):
+class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):

     def __init__(
         self,
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index 1d2eded3..94ba1ee2 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -5,11 +5,7 @@
 from llm.wrapper import LlmCompletionWrapperInterface


-class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):

     def __init__(self, client, model: str):
         self.client = client
@@ -25,11 +21,19 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
         )
         return response

+
+class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAICompletion('{self.model}')"


-class AzureCompletionWrapper(OpenAICompletionWrapper):
+class AzureCompletionWrapper(BaseOpenAICompletionWrapper):

     def __init__(
         self,
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
index 4983d3ee..726fb272 100644
--- a/app/llm/wrapper/open_ai_embedding_wrapper.py
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -6,11 +6,7 @@
 )


-class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):

     def __init__(self, client, model: str):
         self.client = client
@@ -24,10 +20,18 @@ def create_embedding(self, text: str) -> list[float]:
         )
         return response.data[0].embedding

+
+class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"


-class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):

     def __init__(
         self,
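[With PATCH 05, LlmManager no longer registers wrappers programmatically; it loads them from a YAML file named by the LLM_CONFIG_PATH environment variable. A sketch of a config that would satisfy create_llm_wrapper — the ids, file name, and all values here are made-up placeholders inferred from the keys the function reads, not part of the patch:

    import os
    import yaml

    # Hypothetical llms.yml covering the "ollama" and "azure_chat" branches.
    EXAMPLE_CONFIG = """
    - id: ollama-local
      type: ollama
      model: llama2
      host: http://localhost:11434
    - id: azure-gpt-chat
      type: azure_chat
      model: gpt-35-turbo
      endpoint: https://example.openai.azure.com
      azure_deployment: my-deployment
      api_version: 2023-07-01-preview
      api_key: placeholder-key
    """

    with open("llms.yml", "w") as f:
        f.write(EXAMPLE_CONFIG)
    os.environ["LLM_CONFIG_PATH"] = "llms.yml"  # read by LlmManager.load_llms()

    # Each list entry becomes one LlmManagerEntry, keyed by its "id".
    print([entry["id"] for entry in yaml.safe_load(EXAMPLE_CONFIG)])

Because LlmManager is a Singleton, the environment variable has to be set before the manager is first instantiated.]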
From 1f924db8a6a48520ca69999bddd69141136c9148 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 19:53:54 +0100
Subject: [PATCH 06/12] message_text -> text

---
 app/domain/message.py                   | 10 ++++------
 app/llm/wrapper/ollama_wrapper.py       |  7 ++-----
 app/llm/wrapper/open_ai_chat_wrapper.py |  5 ++---
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/app/domain/message.py b/app/domain/message.py
index 960750a9..b1f521cc 100644
--- a/app/domain/message.py
+++ b/app/domain/message.py
@@ -9,13 +9,11 @@ class IrisMessageRole(Enum):
 class IrisMessage:
     role: IrisMessageRole
-    message_text: str
+    text: str

-    def __init__(self, role: IrisMessageRole, message_text: str):
+    def __init__(self, role: IrisMessageRole, text: str):
         self.role = role
-        self.message_text = message_text
+        self.text = text

     def __str__(self):
-        return (
-            f"IrisMessage(role={self.role.value}, message_text='{self.message_text}')"
-        )
+        return f"IrisMessage(role={self.role.value}, text='{self.text}')"
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 8d0cd397..1d5fd3bf 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -11,15 +11,12 @@

 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
     return [
-        Message(role=message.role.value, content=message.message_text)
-        for message in messages
+        Message(role=message.role.value, content=message.text) for message in messages
     ]


 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(
-        role=IrisMessageRole(message["role"]), message_text=message["content"]
-    )
+    return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index be1016a0..2d1d45cf 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -10,15 +10,14 @@ def convert_to_open_ai_messages(
     messages: list[IrisMessage],
 ) -> list[ChatCompletionMessageParam]:
     return [
-        {"role": message.role.value, "content": message.message_text}
-        for message in messages
+        {"role": message.role.value, "content": message.text} for message in messages
     ]


 def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
     # Get IrisMessageRole from the string message.role
     message_role = IrisMessageRole(message.role)
-    return IrisMessage(role=message_role, message_text=message.content)
+    return IrisMessage(role=message_role, text=message.content)

From ea6a39741b9d71fcd6f08080caf310c22c431ea1 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sun, 11 Feb 2024 14:11:39 +0100
Subject: [PATCH 07/12] Improve LlmManager

---
 app/llm/basic_request_handler.py              | 12 +++---
 app/llm/llm_manager.py                        | 42 ++++++++++---------
 app/llm/wrapper/__init__.py                   |  2 +-
 ...r_interface.py => abstract_llm_wrapper.py} | 29 ++++++++-----
 app/llm/wrapper/ollama_wrapper.py             | 15 +++----
 app/llm/wrapper/open_ai_chat_wrapper.py       | 18 ++++----
 app/llm/wrapper/open_ai_completion_wrapper.py | 14 ++++---
 app/llm/wrapper/open_ai_embedding_wrapper.py  | 14 ++++---
 8 files changed, 82 insertions(+), 64 deletions(-)
 rename app/llm/wrapper/{llm_wrapper_interface.py => abstract_llm_wrapper.py} (62%)

diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index fbeacb76..761d0e21 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -2,9 +2,9 @@
 from llm import LlmManager
 from llm import RequestHandlerInterface, CompletionArguments
 from llm.wrapper import (
-    LlmCompletionWrapperInterface,
-    LlmChatCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 )

 type BasicRequestHandlerModel = str
@@ -20,7 +20,7 @@ def __init__(self, model: BasicRequestHandlerModel):

     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmCompletionWrapperInterface):
+        if isinstance(llm, AbstractLlmCompletionWrapper):
             return llm.completion(prompt, arguments)
         else:
             raise NotImplementedError(
@@ -31,7 +31,7 @@ def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmChatCompletionWrapperInterface):
+        if isinstance(llm, AbstractLlmChatCompletionWrapper):
             return llm.chat_completion(messages, arguments)
         else:
             raise NotImplementedError(
@@ -40,7 +40,7 @@ def create_embedding(self, text: str) -> list[float]:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmEmbeddingWrapperInterface):
+        if isinstance(llm, AbstractLlmEmbeddingWrapper):
             return llm.create_embedding(text)
         else:
             raise NotImplementedError(
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index bc6de680..f0711033 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -3,10 +3,10 @@
 import yaml

 from common import Singleton
-from llm.wrapper import LlmWrapperInterface
+from llm.wrapper import AbstractLlmWrapper


-def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
+def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
     if config["type"] == "openai":
         from llm.wrapper import OpenAICompletionWrapper

@@ -15,6 +15,9 @@ def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
         from llm.wrapper import AzureCompletionWrapper

         return AzureCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -25,12 +28,19 @@ def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
         from llm.wrapper import OpenAIChatCompletionWrapper

         return OpenAIChatCompletionWrapper(
-            model=config["model"], api_key=config["api_key"]
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            api_key=config["api_key"],
         )
     elif config["type"] == "azure_chat":
         from llm.wrapper import AzureChatCompletionWrapper

         return AzureChatCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -45,6 +55,9 @@ def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
         from llm.wrapper import AzureEmbeddingWrapper

         return AzureEmbeddingWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -55,6 +68,9 @@ def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
         from llm.wrapper import OllamaWrapper

         return OllamaWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             host=config["host"],
         )
@@ -62,27 +78,15 @@ def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
         raise Exception(f"Unknown LLM type: {config['type']}")


-class LlmManagerEntry:
-    id: str
-    llm: LlmWrapperInterface
-
-    def __init__(self, config: dict):
-        self.id = config["id"]
-        self.llm = create_llm_wrapper(config)
-
-    def __str__(self):
-        return f"{self.id}: {self.llm}"
-
-
 class LlmManager(metaclass=Singleton):
-    llms: list[LlmManagerEntry]
+    entries: list[AbstractLlmWrapper]

     def __init__(self):
-        self.llms = []
+        self.entries = []
         self.load_llms()

     def get_llm_by_id(self, llm_id):
-        for llm in self.llms:
+        for llm in self.entries:
             if llm.id == llm_id:
                 return llm

@@ -94,4 +98,4 @@ def load_llms(self):
         with open(path, "r") as file:
             loaded_llms = yaml.safe_load(file)

-        self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
+        self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
index 4364afa0..7e0dabff 100644
--- a/app/llm/wrapper/__init__.py
+++ b/app/llm/wrapper/__init__.py
@@ -1,4 +1,4 @@
-from llm.wrapper.llm_wrapper_interface import *
+from llm.wrapper.abstract_llm_wrapper import *
 from llm.wrapper.open_ai_completion_wrapper import *
 from llm.wrapper.open_ai_chat_wrapper import *
 from llm.wrapper.open_ai_embedding_wrapper import *
 from llm.wrapper.ollama_wrapper import OllamaWrapper
diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/abstract_llm_wrapper.py
similarity index 62%
rename from app/llm/wrapper/llm_wrapper_interface.py
rename to app/llm/wrapper/abstract_llm_wrapper.py
index b1e79acb..6d5e353e 100644
--- a/app/llm/wrapper/llm_wrapper_interface.py
+++ b/app/llm/wrapper/abstract_llm_wrapper.py
@@ -3,15 +3,22 @@
 from domain import IrisMessage
 from llm import CompletionArguments

-type LlmWrapperInterface = (
-    LlmCompletionWrapperInterface
-    | LlmChatCompletionWrapperInterface
-    | LlmEmbeddingWrapperInterface
-)

+class AbstractLlmWrapper(metaclass=ABCMeta):
+    """Abstract class for the llm wrappers"""

-class LlmCompletionWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm completion wrappers"""
+    id: str
+    name: str
+    description: str
+
+    def __init__(self, id: str, name: str, description: str):
+        self.id = id
+        self.name = name
+        self.description = description
+
+
+class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm completion wrappers"""

     @classmethod
     def __subclasshook__(cls, subclass):
@@ -23,8 +30,8 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         raise NotImplementedError


-class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm chat completion wrappers"""
+class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm chat completion wrappers"""

     @classmethod
     def __subclasshook__(cls, subclass):
@@ -40,8 +47,8 @@ def chat_completion(
         raise NotImplementedError


-class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm embedding wrappers"""
+class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm embedding wrappers"""

     @classmethod
     def __subclasshook__(cls, subclass):
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 1d5fd3bf..9ce8e94b 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -3,9 +3,9 @@
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import (
-    LlmChatCompletionWrapperInterface,
-    LlmCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 )

@@ -20,12 +20,13 @@ def convert_to_iris_message(message: Message) -> IrisMessage:

 class OllamaWrapper(
-    LlmCompletionWrapperInterface,
-    LlmChatCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 ):

-    def __init__(self, model: str, host: str):
+    def __init__(self, model: str, host: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
         self.model = model

diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index 2d1d45cf..c6b68e25 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,9 +1,9 @@
 from openai.lib.azure import AzureOpenAI
-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper import LlmChatCompletionWrapperInterface
+from llm.wrapper import AbstractLlmChatCompletionWrapper


 def convert_to_open_ai_messages(
@@ -14,15 +14,16 @@ def convert_to_open_ai_messages(
     ]


-def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
+def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
     # Get IrisMessageRole from the string message.role
     message_role = IrisMessageRole(message.role)
     return IrisMessage(role=message_role, text=message.content)


-class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
+class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):

-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model

@@ -41,12 +42,12 @@ def chat_completion(

 class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):

-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         from openai import OpenAI

         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"OpenAIChat('{self.model}')"
@@ -61,6 +62,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -69,7 +71,7 @@ def __init__(
             api_key=api_key,
         )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureChat('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index 94ba1ee2..daac194a 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -2,12 +2,13 @@
 from openai.lib.azure import AzureOpenAI

 from llm import CompletionArguments
-from llm.wrapper import LlmCompletionWrapperInterface
+from llm.wrapper import AbstractLlmCompletionWrapper


-class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):
+class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):

-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model

@@ -25,10 +25,10 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:

 class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):

-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"OpenAICompletion('{self.model}')"
@@ -43,6 +43,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -50,7 +52,7 @@ def __init__(
         )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureCompletion('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
index 726fb272..88b425bd 100644
--- a/app/llm/wrapper/open_ai_embedding_wrapper.py
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -2,13 +2,14 @@
 from openai.lib.azure import AzureOpenAI

 from llm.wrapper import (
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmEmbeddingWrapper,
 )


-class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
+class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):

-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model

@@ -24,10 +25,10 @@ def create_embedding(self, text: str) -> list[float]:

 class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):

-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"
@@ -42,6 +43,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -49,7 +51,7 @@ def __init__(
         )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureEmbedding('{self.model}')"

From a8e5860559845105a739150294a3865d58200b12 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sun, 11 Feb 2024 14:18:01 +0100
Subject: [PATCH 08/12] Small package structure improvement

---
 app/llm/__init__.py              | 3 +--
 app/llm/basic_request_handler.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/app/llm/__init__.py b/app/llm/__init__.py
index 51227f24..33542f1c 100644
--- a/app/llm/__init__.py
+++ b/app/llm/__init__.py
@@ -1,4 +1,3 @@
-from llm.generation_arguments import CompletionArguments
 from llm.request_handler_interface import RequestHandlerInterface
-from llm.llm_manager import LlmManager
+from llm.generation_arguments import *
 from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index 761d0e21..f348da1f 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -1,6 +1,6 @@
 from domain import IrisMessage
-from llm import LlmManager
 from llm import RequestHandlerInterface, CompletionArguments
+from llm.llm_manager import LlmManager
 from llm.wrapper import (
     AbstractLlmCompletionWrapper,
     AbstractLlmChatCompletionWrapper,

From 02d01b45b73bf963c8f5b2b0917a4c70ef9a92db Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sun, 11 Feb 2024 16:36:08 +0100
Subject: [PATCH 09/12] Add TODO note

---
 app/llm/llm_manager.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index f0711033..49a56f30 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -6,6 +6,7 @@
 from llm.wrapper import AbstractLlmWrapper


+# TODO: Replace with pydantic in a future PR
 def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
     if config["type"] == "openai":
         from llm.wrapper import OpenAICompletionWrapper
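[Taken together, PATCHES 01-09 give the subsystem the call surface sketched below. This is an illustrative usage sketch, not code from the PR: it assumes the app directory is on sys.path, that LLM_CONFIG_PATH points at a config file like the earlier example, and that "azure-gpt-chat" is a hypothetical entry id from that file; note the series is still in flux at this point, and later patches reconcile remaining internals of the manager/handler wiring.

    import os

    # Must be set before LlmManager is first instantiated (Singleton).
    os.environ["LLM_CONFIG_PATH"] = "llms.yml"

    from domain import IrisMessage, IrisMessageRole
    from llm import BasicRequestHandler, CompletionArguments

    handler = BasicRequestHandler(model="azure-gpt-chat")  # id from the config file
    answer = handler.chat_completion(
        [IrisMessage(role=IrisMessageRole.USER, text="Hello!")],
        CompletionArguments(max_tokens=100, temperature=0.7),
    )
    print(answer.text)

The handler only dispatches: it looks the wrapper up by id and rejects operations the configured model does not support.]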
Add pydantic for LLM config --- app/llm/basic_request_handler.py | 2 +- app/llm/llm_manager.py | 80 ++----------------- app/llm/wrapper/__init__.py | 27 ++++++- app/llm/wrapper/abstract_llm_wrapper.py | 8 +- app/llm/wrapper/ollama_wrapper.py | 20 +++-- app/llm/wrapper/open_ai_chat_wrapper.py | 51 +++++------- app/llm/wrapper/open_ai_completion_wrapper.py | 48 +++++------ app/llm/wrapper/open_ai_embedding_wrapper.py | 50 +++++------- requirements.txt | 1 + 9 files changed, 107 insertions(+), 180 deletions(-) diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index f348da1f..001d2dbb 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -1,7 +1,7 @@ from domain import IrisMessage from llm import RequestHandlerInterface, CompletionArguments from llm.llm_manager import LlmManager -from llm.wrapper import ( +from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmCompletionWrapper, AbstractLlmChatCompletionWrapper, AbstractLlmEmbeddingWrapper, diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py index 49a56f30..af593d32 100644 --- a/app/llm/llm_manager.py +++ b/app/llm/llm_manager.py @@ -1,82 +1,16 @@ import os +from pydantic import BaseModel, Field + import yaml from common import Singleton -from llm.wrapper import AbstractLlmWrapper - - -# TODO: Replace with pydantic in a future PR -def create_llm_wrapper(config: dict) -> AbstractLlmWrapper: - if config["type"] == "openai": - from llm.wrapper import OpenAICompletionWrapper - - return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"]) - elif config["type"] == "azure": - from llm.wrapper import AzureCompletionWrapper - - return AzureCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "openai_chat": - from llm.wrapper import OpenAIChatCompletionWrapper - - return OpenAIChatCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - api_key=config["api_key"], - ) - elif config["type"] == "azure_chat": - from llm.wrapper import AzureChatCompletionWrapper - - return AzureChatCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "openai_embedding": - from llm.wrapper import OpenAIEmbeddingWrapper - - return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"]) - elif config["type"] == "azure_embedding": - from llm.wrapper import AzureEmbeddingWrapper +from llm.wrapper import AbstractLlmWrapper, LlmWrapper - return AzureEmbeddingWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "ollama": - from llm.wrapper import OllamaWrapper - return OllamaWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - host=config["host"], - ) - else: - raise Exception(f"Unknown LLM type: {config['type']}") +# Small workaround to 
get pydantic discriminators working +class LlmList(BaseModel): + llms: list[LlmWrapper] = Field(discriminator="type") class LlmManager(metaclass=Singleton): @@ -99,4 +33,4 @@ def load_llms(self): with open(path, "r") as file: loaded_llms = yaml.safe_load(file) - self.entries = [create_llm_wrapper(llm) for llm in loaded_llms] + self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index 7e0dabff..c4807ec5 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,5 +1,24 @@ -from llm.wrapper.abstract_llm_wrapper import * -from llm.wrapper.open_ai_completion_wrapper import * -from llm.wrapper.open_ai_chat_wrapper import * -from llm.wrapper.open_ai_embedding_wrapper import * +from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper +from llm.wrapper.open_ai_completion_wrapper import ( + OpenAICompletionWrapper, + AzureCompletionWrapper, +) +from llm.wrapper.open_ai_chat_wrapper import ( + OpenAIChatCompletionWrapper, + AzureChatCompletionWrapper, +) +from llm.wrapper.open_ai_embedding_wrapper import ( + OpenAIEmbeddingWrapper, + AzureEmbeddingWrapper, +) from llm.wrapper.ollama_wrapper import OllamaWrapper + +type LlmWrapper = ( + OpenAICompletionWrapper + | AzureCompletionWrapper + | OpenAIChatCompletionWrapper + | AzureChatCompletionWrapper + | OpenAIEmbeddingWrapper + | AzureEmbeddingWrapper + | OllamaWrapper +) diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py index 6d5e353e..057b3aca 100644 --- a/app/llm/wrapper/abstract_llm_wrapper.py +++ b/app/llm/wrapper/abstract_llm_wrapper.py @@ -1,21 +1,17 @@ from abc import ABCMeta, abstractmethod +from pydantic import BaseModel from domain import IrisMessage from llm import CompletionArguments -class AbstractLlmWrapper(metaclass=ABCMeta): +class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta): """Abstract class for the llm wrappers""" id: str name: str description: str - def __init__(self, id: str, name: str, description: str): - self.id = id - self.name = name - self.description = description - class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): """Abstract class for the llm completion wrappers""" diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 9ce8e94b..4ea0e9b0 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -1,8 +1,10 @@ +from typing import Literal, Any + from ollama import Client, Message from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper import ( +from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmChatCompletionWrapper, AbstractLlmCompletionWrapper, AbstractLlmEmbeddingWrapper, @@ -24,26 +26,28 @@ class OllamaWrapper( AbstractLlmChatCompletionWrapper, AbstractLlmEmbeddingWrapper, ): + type: Literal["ollama"] + model: str + host: str + _client: Client - def __init__(self, model: str, host: str, **kwargs): - super().__init__(**kwargs) - self.client = Client(host=host) # TODO: Add authentication (httpx auth?) - self.model = model + def model_post_init(self, __context: Any) -> None: + self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?) 

     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        response = self.client.generate(model=self.model, prompt=prompt)
+        response = self._client.generate(model=self.model, prompt=prompt)
         return response["response"]

     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> any:
-        response = self.client.chat(
+        response = self._client.chat(
             model=self.model, messages=convert_to_ollama_messages(messages)
         )
         return convert_to_iris_message(response["message"])

     def create_embedding(self, text: str) -> list[float]:
-        response = self.client.embeddings(model=self.model, prompt=text)
-        return list(response)
+        response = self._client.embeddings(model=self.model, prompt=text)
+        return list(response["embedding"])  # the client returns {"embedding": [...]}, not the vector itself

     def __str__(self):

diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index c6b68e25..6a605ad5 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,9 +1,11 @@
+from typing import Literal, Any
+from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper import AbstractLlmChatCompletionWrapper
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper


 def convert_to_open_ai_messages(
@@ -21,16 +23,14 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:


 class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
-
-    def __init__(self, client, model: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = client
-        self.model = model
+    model: str
+    api_key: str
+    _client: OpenAI

     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> IrisMessage:
-        response = self.client.chat.completions.create(
+        response = self._client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
             temperature=arguments.temperature,
@@ -41,37 +41,28 @@ def chat_completion(


 class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+    type: Literal["openai_chat"]

-    def __init__(self, model: str, api_key: str, **kwargs):
-        from openai import OpenAI
-
-        client = OpenAI(api_key=api_key)
-        model = model
-        super().__init__(client, model, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)

     def __str__(self):
         return f"OpenAIChat('{self.model}')"


 class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+    type: Literal["azure_chat"]
+    endpoint: str
+    azure_deployment: str
+    api_version: str

-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-        **kwargs,
-    ):
-        client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
+    def model_post_init(self, __context: Any) -> None:
+        self._client = AzureOpenAI(
+            azure_endpoint=self.endpoint,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            api_key=self.api_key,
         )
-        model = model
-        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureChat('{self.model}')"

diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index daac194a..22fe4ed2 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -1,19 +1,18 @@
+from typing import Literal, Any
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI

 from llm import CompletionArguments
-from llm.wrapper import AbstractLlmCompletionWrapper
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper


 class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
-
-    def __init__(self, client, model: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = client
-        self.model = model
+    model: str
+    api_key: str
+    _client: OpenAI

     def completion(self, prompt: str, arguments: CompletionArguments) -> any:
-        response = self.client.completions.create(
+        response = self._client.completions.create(
             model=self.model,
             prompt=prompt,
             temperature=arguments.temperature,
@@ -24,35 +23,28 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:


 class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+    type: Literal["openai_completion"]

-    def __init__(self, model: str, api_key: str, **kwargs):
-        client = OpenAI(api_key=api_key)
-        model = model
-        super().__init__(client, model, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)

     def __str__(self):
         return f"OpenAICompletion('{self.model}')"


 class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
+    type: Literal["azure_completion"]
+    endpoint: str
+    azure_deployment: str
+    api_version: str

-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-        **kwargs,
-    ):
-        client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
+    def model_post_init(self, __context: Any) -> None:
+        self._client = AzureOpenAI(
+            azure_endpoint=self.endpoint,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            api_key=self.api_key,
         )
-        model = model
-        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureCompletion('{self.model}')"

diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
index 88b425bd..99c397c9 100644
--- a/app/llm/wrapper/open_ai_embedding_wrapper.py
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -1,20 +1,17 @@
+from typing import Literal, Any
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI

-from llm.wrapper import (
-    AbstractLlmEmbeddingWrapper,
-)
+from llm.wrapper.abstract_llm_wrapper import AbstractLlmEmbeddingWrapper


 class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):
-
-    def __init__(self, client, model: str, **kwargs):
-        super().__init__(**kwargs)
-        self.client = client
-        self.model = model
+    model: str
+    api_key: str
+    _client: OpenAI

     def create_embedding(self, text: str) -> list[float]:
-        response = self.client.embeddings.create(
+        response = self._client.embeddings.create(
             model=self.model,
             input=text,
             encoding_format="float",
@@ -23,35 +20,28 @@ def create_embedding(self, text: str) -> list[float]:


 class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+    type: Literal["openai_embedding"]

-    def __init__(self, model: str, api_key: str, **kwargs):
-        client = OpenAI(api_key=api_key)
-        model = model
-        super().__init__(client, model, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        self._client = OpenAI(api_key=self.api_key)

     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"


 class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+    type: Literal["azure_embedding"]
+    endpoint: str
+    azure_deployment: str
+    api_version: str

-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-        **kwargs,
-    ):
-        client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
+    def model_post_init(self, __context: Any) -> None:
+        self._client = AzureOpenAI(
+            azure_endpoint=self.endpoint,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            api_key=self.api_key,
         )
-        model = model
-        super().__init__(client, model, **kwargs)

     def __str__(self):
         return f"AzureEmbedding('{self.model}')"

diff --git a/requirements.txt b/requirements.txt
index 0b8fabbb..71c9e37e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ uvicorn==0.23.1
 black==24.1.1
 flake8==7.0.0
 pre-commit==3.6.0
+pydantic==2.6.1

From 4af8a5f7b9b6fea33570b27e3eaddb00be080f0b Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sun, 11 Feb 2024 22:18:16 +0100
Subject: [PATCH 11/12] WIP: Add langchain chat model wrapper

---
 app/llm/langchain/__init__.py                 |  1 +
 .../langchain/iris_langchain_chat_model.py    | 46 +++++++++++++++++++
 app/llm/request_handler_interface.py          |  2 +-
 requirements.txt                              |  1 +
 4 files changed, 49 insertions(+), 1 deletion(-)
 create mode 100644 app/llm/langchain/__init__.py
 create mode 100644 app/llm/langchain/iris_langchain_chat_model.py

diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py
new file mode 100644
index 00000000..1f75540b
--- /dev/null
+++ b/app/llm/langchain/__init__.py
@@ -0,0 +1 @@
+from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel

diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py
new file mode 100644
index 00000000..41f0df4e
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_chat_model.py
@@ -0,0 +1,46 @@
+from typing import List, Optional, Any
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+)
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
+
+from domain import IrisMessage
+from llm import RequestHandlerInterface, CompletionArguments
+
+
+def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
+    return BaseMessage(content=iris_message.text, role=iris_message.role)
+
+
+def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
+    return IrisMessage(text=base_message.content, role=base_message.role)
+
+
+class IrisLangchainChatModel(BaseChatModel):
+    """Custom langchain chat model for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.request_handler = request_handler
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ) -> ChatResult:
+        iris_message = self.request_handler.chat_completion(
+            messages, CompletionArguments(stop=stop)
+        )
+        base_message = convert_iris_message_to_base_message(iris_message)
+        return ChatResult(generations=[base_message])
+
+    @property
+    def _llm_type(self) -> str:
+        return "Iris"

diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py
index 5c15df30..02a3f21c 100644
--- a/app/llm/request_handler_interface.py
+++ b/app/llm/request_handler_interface.py
@@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:

     @abstractmethod
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> [IrisMessage]:
+    ) -> IrisMessage:
         """Create a completion from the chat messages"""
         raise NotImplementedError

diff --git a/requirements.txt b/requirements.txt
index 71c9e37e..2796179f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ black==24.1.1
 flake8==7.0.0
 pre-commit==3.6.0
 pydantic==2.6.1
+langchain==0.1.6

From 16b72b016f83cf46595b47ac044b0ab8ceb0dc75 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Wed, 14 Feb 2024 15:07:47 +0100
Subject: [PATCH 12/12] Add all required Langchain LLM wrappers

---
 app/llm/basic_request_handler.py              |  6 +--
 app/llm/langchain/__init__.py                 |  2 +
 .../langchain/iris_langchain_chat_model.py    | 28 ++++++++----
 .../iris_langchain_completion_model.py        | 37 +++++++++++++++
 app/llm/langchain/iris_langchain_embedding.py | 20 ++++++++
 5 files changed, 83 insertions(+), 10 deletions(-)
 create mode 100644 app/llm/langchain/iris_langchain_completion_model.py
 create mode 100644 app/llm/langchain/iris_langchain_embedding.py

diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index 001d2dbb..12079901 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel):
         self.llm_manager = LlmManager()

     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
         if isinstance(llm, AbstractLlmCompletionWrapper):
             return llm.completion(prompt, arguments)
         else:
@@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
     def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
         if isinstance(llm, AbstractLlmChatCompletionWrapper):
             return llm.chat_completion(messages, arguments)
         else:
@@ -39,7 +39,7 @@ def chat_completion(
             )

     def create_embedding(self, text: str) -> list[float]:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
         if isinstance(llm, AbstractLlmEmbeddingWrapper):
             return llm.create_embedding(text)
         else:

diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py
index 1f75540b..4deb1372 100644
--- a/app/llm/langchain/__init__.py
+++ b/app/llm/langchain/__init__.py
@@ -1 +1,3 @@
+from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
 from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
+from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel

diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py
index 41f0df4e..d0e558fa 100644
--- a/app/llm/langchain/iris_langchain_chat_model.py
+++ b/app/llm/langchain/iris_langchain_chat_model.py
@@ -6,17 +6,30 @@
 )
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
+from langchain_core.outputs.chat_generation import ChatGeneration

-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import RequestHandlerInterface, CompletionArguments


 def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
-    return BaseMessage(content=iris_message.text, role=iris_message.role)
+    role_map = {
+        IrisMessageRole.USER: "human",
+        IrisMessageRole.ASSISTANT: "ai",
+        IrisMessageRole.SYSTEM: "system",
+    }
+    return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])


 def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
-    return IrisMessage(text=base_message.content, role=base_message.role)
+    role_map = {
+        "human": IrisMessageRole.USER,
+        "ai": IrisMessageRole.ASSISTANT,
+        "system": IrisMessageRole.SYSTEM,
+    }
+    return IrisMessage(
+        text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
+    )


 class IrisLangchainChatModel(BaseChatModel):
@@ -25,8 +38,7 @@ class IrisLangchainChatModel(BaseChatModel):
     request_handler: RequestHandlerInterface

     def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.request_handler = request_handler
+        super().__init__(request_handler=request_handler, **kwargs)

     def _generate(
         self,
@@ -35,11 +47,13 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any
     ) -> ChatResult:
+        iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
         iris_message = self.request_handler.chat_completion(
-            messages, CompletionArguments(stop=stop)
+            iris_messages, CompletionArguments(stop=stop)
         )
         base_message = convert_iris_message_to_base_message(iris_message)
-        return ChatResult(generations=[base_message])
+        chat_generation = ChatGeneration(message=base_message)
+        return ChatResult(generations=[chat_generation])

     @property
     def _llm_type(self) -> str:

diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py
new file mode 100644
index 00000000..8a8b2e95
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_completion_model.py
@@ -0,0 +1,37 @@
+from typing import List, Optional, Any
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import LLMResult
+from langchain_core.outputs.generation import Generation
+
+from llm import RequestHandlerInterface, CompletionArguments
+
+
+class IrisLangchainCompletionModel(BaseLLM):
+    """Custom langchain completion model for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        super().__init__(request_handler=request_handler, **kwargs)
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ) -> LLMResult:
+        generations = []
+        args = CompletionArguments(stop=stop)
+        for prompt in prompts:
+            completion = self.request_handler.completion(
+                prompt=prompt, arguments=args, **kwargs
+            )
+            generations.append([Generation(text=completion)])
+        return LLMResult(generations=generations)
+
+    @property
+    def _llm_type(self) -> str:
+        return "Iris"

diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py
new file mode 100644
index 00000000..01a840ff
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_embedding.py
@@ -0,0 +1,20 @@
+from typing import List, Any
+
+from langchain_core.embeddings import Embeddings
+
+from llm import RequestHandlerInterface
+
+
+class IrisLangchainEmbeddingModel(Embeddings):
+    """Custom langchain embedding for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        self.request_handler = request_handler  # Embeddings is a plain ABC, not a pydantic model
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.request_handler.create_embedding(text)
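
Notes on the series:

With patch 10, the LLM list is no longer assembled by hand in create_llm_wrapper; pydantic validates each entry and dispatches on the `type` discriminator of the LlmWrapper union. A config entry might look like this (a sketch; the file name, its location, and the exact key set are assumptions derived from the wrapper fields above):

    # llm_config.yml (hypothetical name) -- one list entry per LLM
    - id: "azure-gpt35-chat"
      name: "Azure GPT-3.5 Chat"
      description: "Azure-hosted chat model"
      type: "azure_chat"            # pydantic discriminator, see the LlmWrapper union
      model: "gpt-35-turbo"
      endpoint: "https://example.openai.azure.com"
      azure_deployment: "gpt-35-deployment"
      api_version: "2023-05-15"
      api_key: "<secret>"
    - id: "local-llama"
      name: "Local Llama"
      description: "Local model served by Ollama"
      type: "ollama"
      model: "llama2"
      host: "http://localhost:11434"

LlmManager.load_llms validates the whole list in one parse_obj call, so an unknown `type` or a missing field fails fast at startup rather than at request time.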
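On top of such a config, BasicRequestHandler routes each call to the wrapper whose id matches. A minimal usage sketch (the model id is illustrative; the `text` keyword on IrisMessage follows the later patches, and the CompletionArguments call uses the signature from the first patch):

    from domain import IrisMessage, IrisMessageRole
    from llm import BasicRequestHandler, CompletionArguments

    # "azure-gpt35-chat" must match an id from the loaded config
    request_handler = BasicRequestHandler("azure-gpt35-chat")
    answer = request_handler.chat_completion(
        [IrisMessage(role=IrisMessageRole.USER, text="Hello Iris!")],
        CompletionArguments(max_tokens=256, temperature=0.2, stop=None),
    )
    print(answer.text)

If the configured wrapper does not implement the requested capability (for example create_embedding on a chat-only model), the handler raises NotImplementedError instead of calling the wrong API.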
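The langchain wrappers from patches 11 and 12 exist so a request handler can be dropped into an LCEL pipeline. A wiring sketch (prompt and model id are illustrative):

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    from llm import BasicRequestHandler
    from llm.langchain import IrisLangchainChatModel

    # the chat model delegates every _generate call to the request handler
    request_handler = BasicRequestHandler("azure-gpt35-chat")
    chat_model = IrisLangchainChatModel(request_handler=request_handler)

    prompt = ChatPromptTemplate.from_messages([("human", "Summarize: {text}")])
    chain = prompt | chat_model | StrOutputParser()
    print(chain.invoke({"text": "Iris routes all LLM calls through one subsystem."}))

Because IrisLangchainChatModel only implements _generate and _llm_type, batching, callbacks, and the streaming fallback are inherited from BaseChatModel.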