From ea6a39741b9d71fcd6f08080caf310c22c431ea1 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sun, 11 Feb 2024 14:11:39 +0100
Subject: [PATCH] Improve LlmManager

---
 app/llm/basic_request_handler.py              | 12 +++---
 app/llm/llm_manager.py                        | 42 ++++++++++---------
 app/llm/wrapper/__init__.py                   |  2 +-
 ...r_interface.py => abstract_llm_wrapper.py} | 29 ++++++++-----
 app/llm/wrapper/ollama_wrapper.py             | 15 +++----
 app/llm/wrapper/open_ai_chat_wrapper.py       | 18 ++++----
 app/llm/wrapper/open_ai_completion_wrapper.py | 14 ++++---
 app/llm/wrapper/open_ai_embedding_wrapper.py  | 14 ++++---
 8 files changed, 82 insertions(+), 64 deletions(-)
 rename app/llm/wrapper/{llm_wrapper_interface.py => abstract_llm_wrapper.py} (62%)

diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index fbeacb76..761d0e21 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -2,9 +2,9 @@
 from llm import LlmManager
 from llm import RequestHandlerInterface, CompletionArguments
 from llm.wrapper import (
-    LlmCompletionWrapperInterface,
-    LlmChatCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 )
 
 type BasicRequestHandlerModel = str
@@ -20,7 +20,7 @@ def __init__(self, model: BasicRequestHandlerModel):
 
     def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmCompletionWrapperInterface):
+        if isinstance(llm, AbstractLlmCompletionWrapper):
             return llm.completion(prompt, arguments)
         else:
             raise NotImplementedError(
@@ -31,7 +31,7 @@ def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmChatCompletionWrapperInterface):
+        if isinstance(llm, AbstractLlmChatCompletionWrapper):
             return llm.chat_completion(messages, arguments)
         else:
             raise NotImplementedError(
@@ -40,7 +40,7 @@ def chat_completion(
 
     def create_embedding(self, text: str) -> list[float]:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, LlmEmbeddingWrapperInterface):
+        if isinstance(llm, AbstractLlmEmbeddingWrapper):
             return llm.create_embedding(text)
         else:
             raise NotImplementedError(
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index bc6de680..f0711033 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -3,10 +3,10 @@
 import yaml
 
 from common import Singleton
-from llm.wrapper import LlmWrapperInterface
+from llm.wrapper import AbstractLlmWrapper
 
 
-def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
+def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
     if config["type"] == "openai":
         from llm.wrapper import OpenAICompletionWrapper
 
@@ -15,6 +15,9 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
         from llm.wrapper import AzureCompletionWrapper
 
         return AzureCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -25,12 +28,19 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
         from llm.wrapper import OpenAIChatCompletionWrapper
 
         return OpenAIChatCompletionWrapper(
-            model=config["model"], api_key=config["api_key"]
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
+            model=config["model"],
+            api_key=config["api_key"],
         )
     elif config["type"] == "azure_chat":
         from llm.wrapper import AzureChatCompletionWrapper
 
         return AzureChatCompletionWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -45,6 +55,9 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
         from llm.wrapper import AzureEmbeddingWrapper
 
         return AzureEmbeddingWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             endpoint=config["endpoint"],
             azure_deployment=config["azure_deployment"],
@@ -55,6 +68,9 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
         from llm.wrapper import OllamaWrapper
 
         return OllamaWrapper(
+            id=config["id"],
+            name=config["name"],
+            description=config["description"],
             model=config["model"],
             host=config["host"],
         )
@@ -62,27 +78,15 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
         raise Exception(f"Unknown LLM type: {config['type']}")
 
 
-class LlmManagerEntry:
-    id: str
-    llm: LlmWrapperInterface
-
-    def __init__(self, config: dict):
-        self.id = config["id"]
-        self.llm = create_llm_wrapper(config)
-
-    def __str__(self):
-        return f"{self.id}: {self.llm}"
-
-
 class LlmManager(metaclass=Singleton):
-    llms: list[LlmManagerEntry]
+    entries: list[AbstractLlmWrapper]
 
     def __init__(self):
-        self.llms = []
+        self.entries = []
         self.load_llms()
 
     def get_llm_by_id(self, llm_id):
-        for llm in self.llms:
+        for llm in self.entries:
             if llm.id == llm_id:
                 return llm
 
@@ -94,4 +98,4 @@ def load_llms(self):
 
         with open(path, "r") as file:
             loaded_llms = yaml.safe_load(file)
-            self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
+            self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
index 4364afa0..7e0dabff 100644
--- a/app/llm/wrapper/__init__.py
+++ b/app/llm/wrapper/__init__.py
@@ -1,4 +1,4 @@
-from llm.wrapper.llm_wrapper_interface import *
+from llm.wrapper.abstract_llm_wrapper import *
 from llm.wrapper.open_ai_completion_wrapper import *
 from llm.wrapper.open_ai_chat_wrapper import *
 from llm.wrapper.open_ai_embedding_wrapper import *
diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/abstract_llm_wrapper.py
similarity index 62%
rename from app/llm/wrapper/llm_wrapper_interface.py
rename to app/llm/wrapper/abstract_llm_wrapper.py
index b1e79acb..6d5e353e 100644
--- a/app/llm/wrapper/llm_wrapper_interface.py
+++ b/app/llm/wrapper/abstract_llm_wrapper.py
@@ -3,15 +3,22 @@
 from domain import IrisMessage
 from llm import CompletionArguments
 
-type LlmWrapperInterface = (
-    LlmCompletionWrapperInterface
-    | LlmChatCompletionWrapperInterface
-    | LlmEmbeddingWrapperInterface
-)
+class AbstractLlmWrapper(metaclass=ABCMeta):
+    """Abstract class for the llm wrappers"""
 
 
-class LlmCompletionWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm completion wrappers"""
+    id: str
+    name: str
+    description: str
+
+    def __init__(self, id: str, name: str, description: str):
+        self.id = id
+        self.name = name
+        self.description = description
+
+
+class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm completion wrappers"""
 
     @classmethod
     def __subclasshook__(cls, subclass):
@@ -23,8 +30,8 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         raise NotImplementedError
 
 
-class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm chat completion wrappers"""
+class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm chat completion wrappers"""
 
     @classmethod
     def __subclasshook__(cls, subclass):
@@ -40,8 +47,8 @@ def chat_completion(
         raise NotImplementedError
 
 
-class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):
-    """Interface for the llm embedding wrappers"""
+class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
+    """Abstract class for the llm embedding wrappers"""
 
     @classmethod
     def __subclasshook__(cls, subclass):
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 1d5fd3bf..9ce8e94b 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -3,9 +3,9 @@
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import (
-    LlmChatCompletionWrapperInterface,
-    LlmCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 )
 
 
@@ -20,12 +20,13 @@ def convert_to_iris_message(message: Message) -> IrisMessage:
 
 
 class OllamaWrapper(
-    LlmCompletionWrapperInterface,
-    LlmChatCompletionWrapperInterface,
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmCompletionWrapper,
+    AbstractLlmChatCompletionWrapper,
+    AbstractLlmEmbeddingWrapper,
 ):
 
-    def __init__(self, model: str, host: str):
+    def __init__(self, model: str, host: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
         self.model = model
 
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index 2d1d45cf..c6b68e25 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,9 +1,9 @@
 from openai.lib.azure import AzureOpenAI
-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
 
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper import LlmChatCompletionWrapperInterface
+from llm.wrapper import AbstractLlmChatCompletionWrapper
 
 
 def convert_to_open_ai_messages(
@@ -14,15 +14,16 @@ def convert_to_open_ai_messages(
     ]
 
 
-def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
+def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
     # Get IrisMessageRole from the string message.role
     message_role = IrisMessageRole(message.role)
     return IrisMessage(role=message_role, text=message.content)
 
 
-class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
+class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
 
-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model
 
@@ -41,12 +42,12 @@ def chat_completion(
 
 class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
 
-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         from openai import OpenAI
 
         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"OpenAIChat('{self.model}')"
@@ -61,6 +62,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -69,7 +71,7 @@ def __init__(
             api_key=api_key,
         )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"AzureChat('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index 94ba1ee2..daac194a 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -2,12 +2,13 @@
 from openai.lib.azure import AzureOpenAI
 
 from llm import CompletionArguments
-from llm.wrapper import LlmCompletionWrapperInterface
+from llm.wrapper import AbstractLlmCompletionWrapper
 
 
-class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):
+class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
 
-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model
 
@@ -24,10 +25,10 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
 
 class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
 
-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"OpenAICompletion('{self.model}')"
@@ -42,6 +43,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -50,7 +52,7 @@ def __init__(
             api_key=api_key,
         )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"AzureCompletion('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
index 726fb272..88b425bd 100644
--- a/app/llm/wrapper/open_ai_embedding_wrapper.py
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -2,13 +2,14 @@
 from openai.lib.azure import AzureOpenAI
 
 from llm.wrapper import (
-    LlmEmbeddingWrapperInterface,
+    AbstractLlmEmbeddingWrapper,
 )
 
 
-class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
+class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):
 
-    def __init__(self, client, model: str):
+    def __init__(self, client, model: str, **kwargs):
+        super().__init__(**kwargs)
         self.client = client
         self.model = model
 
@@ -23,10 +24,10 @@ def create_embedding(self, text: str) -> list[float]:
 
 class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
 
-    def __init__(self, model: str, api_key: str):
+    def __init__(self, model: str, api_key: str, **kwargs):
         client = OpenAI(api_key=api_key)
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"
@@ -41,6 +42,7 @@ def __init__(
         azure_deployment: str,
         api_version: str,
         api_key: str,
+        **kwargs,
     ):
         client = AzureOpenAI(
             azure_endpoint=endpoint,
@@ -49,7 +51,7 @@ def __init__(
             api_key=api_key,
        )
         model = model
-        super().__init__(client, model)
+        super().__init__(client, model, **kwargs)
 
     def __str__(self):
         return f"AzureEmbedding('{self.model}')"
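
Notes: with this patch, id/name/description live on AbstractLlmWrapper itself, so LlmManager can hold the wrapper objects directly and the LlmManagerEntry indirection disappears. A minimal sketch of how one loaded config entry flows through the new factory; the concrete values are invented for illustration, and the import path for create_llm_wrapper is an assumption based on the file layout above:

    # Sketch only: the config values below are hypothetical.
    from llm.llm_manager import create_llm_wrapper  # assumed import path

    config = {
        "type": "ollama",                  # selects the OllamaWrapper branch
        "id": "ollama-llama2",             # hypothetical entry id
        "name": "Llama 2 (local)",
        "description": "Llama 2 served by a local Ollama instance",
        "model": "llama2",
        "host": "http://localhost:11434",
    }

    llm = create_llm_wrapper(config)
    # id/name/description are passed through **kwargs into
    # AbstractLlmWrapper.__init__, so the id now lives on the wrapper:
    assert llm.id == "ollama-llama2"

The same pattern applies to every branch of create_llm_wrapper: each concrete wrapper forwards **kwargs to its base class, which is what lets load_llms build the entries list with a plain list comprehension.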