From 14ae279976279fd5540323bf9fc7867ae3e4cfaa Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 10 Feb 2024 17:32:49 +0100
Subject: [PATCH] Small refinements

---
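A quick way to exercise the refactored request handler: it now dispatches on
the wrapper interfaces and names the offending wrapper (via __str__) when a
capability is missing. A minimal sketch, assuming BasicRequestHandler takes
its model id as a constructor argument and that CompletionArguments accepts
temperature, max_tokens and stop as keyword arguments (both are assumptions;
only their usage appears in the hunks below):

    from domain import IrisMessage
    from llm import BasicRequestHandler, CompletionArguments

    # "gpt35-chat" is a placeholder id; real ids come from the LlmManager config.
    handler = BasicRequestHandler(model="gpt35-chat")
    args = CompletionArguments(temperature=0.2, max_tokens=256, stop=None)
    messages = [IrisMessage(role="user", message_text="Hi!")]

    try:
        print(handler.chat_completion(messages, args))
    except NotImplementedError as e:
        # Raised when the configured wrapper does not implement
        # LlmChatCompletionWrapperInterface, e.g.
        # "The LLM OpenAIEmbedding('...') does not support chat completion"
        print(e)
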
 .flake8                                       |  3 ++
 app/common/__init__.py                        |  2 +-
 app/domain/__init__.py                        |  2 +-
 app/llm/__init__.py                           |  8 +--
 app/llm/basic_request_handler.py              | 12 +++--
 app/llm/wrapper/__init__.py                   |  6 +--
 app/llm/wrapper/azure_chat_wrapper.py         | 32 ------------
 app/llm/wrapper/ollama_wrapper.py             |  7 ++-
 app/llm/wrapper/open_ai_chat_wrapper.py       | 34 ++++++++++++
 app/llm/wrapper/open_ai_completion_wrapper.py | 52 +++++++++++++++++++
 app/llm/wrapper/open_ai_embedding_wrapper.py  | 51 ++++++++++++++++++
 11 files changed, 163 insertions(+), 46 deletions(-)
 delete mode 100644 app/llm/wrapper/azure_chat_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_completion_wrapper.py
 create mode 100644 app/llm/wrapper/open_ai_embedding_wrapper.py

diff --git a/.flake8 b/.flake8
index 18f3c3b8..a5c48807 100644
--- a/.flake8
+++ b/.flake8
@@ -7,4 +7,7 @@ exclude =
 
 per-file-ignores =
     # imported but unused
     __init__.py: F401, F403
+    open_ai_chat_wrapper.py: F811
+    open_ai_completion_wrapper.py: F811
+    open_ai_embedding_wrapper.py: F811
diff --git a/app/common/__init__.py b/app/common/__init__.py
index e190c8ba..97e30c68 100644
--- a/app/common/__init__.py
+++ b/app/common/__init__.py
@@ -1 +1 @@
-from singleton import Singleton
+from common.singleton import Singleton
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
index 5ead9e65..270c228a 100644
--- a/app/domain/__init__.py
+++ b/app/domain/__init__.py
@@ -1 +1 @@
-from message import IrisMessage
+from domain.message import IrisMessage
diff --git a/app/llm/__init__.py b/app/llm/__init__.py
index 73b39dc6..51227f24 100644
--- a/app/llm/__init__.py
+++ b/app/llm/__init__.py
@@ -1,4 +1,4 @@
-from generation_arguments import CompletionArguments
-from request_handler_interface import RequestHandlerInterface
-from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
-from llm_manager import LlmManager
+from llm.generation_arguments import CompletionArguments
+from llm.request_handler_interface import RequestHandlerInterface
+from llm.llm_manager import LlmManager
+from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py
index a572c151..fbeacb76 100644
--- a/app/llm/basic_request_handler.py
+++ b/app/llm/basic_request_handler.py
@@ -23,7 +23,9 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         if isinstance(llm, LlmCompletionWrapperInterface):
             return llm.completion(prompt, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support completion"
+            )
 
     def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
@@ -32,11 +34,15 @@ def chat_completion(
         if isinstance(llm, LlmChatCompletionWrapperInterface):
             return llm.chat_completion(messages, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support chat completion"
+            )
 
     def create_embedding(self, text: str) -> list[float]:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
         if isinstance(llm, LlmEmbeddingWrapperInterface):
             return llm.create_embedding(text)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support embedding"
+            )
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
index 6aef2481..6ddf4569 100644
--- a/app/llm/wrapper/__init__.py
+++ b/app/llm/wrapper/__init__.py
@@ -1,3 +1,3 @@
-from llm_wrapper_interface import *
-from open_ai_chat_wrapper import *
-from ollama_wrapper import OllamaWrapper
+from llm.wrapper.llm_wrapper_interface import *
+from llm.wrapper.open_ai_chat_wrapper import *
+from llm.wrapper.ollama_wrapper import OllamaWrapper
diff --git a/app/llm/wrapper/azure_chat_wrapper.py b/app/llm/wrapper/azure_chat_wrapper.py
deleted file mode 100644
index 9022e3ef..00000000
--- a/app/llm/wrapper/azure_chat_wrapper.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from openai.lib.azure import AzureOpenAI
-
-from llm import CompletionArguments
-from llm.wrapper import LlmChatCompletionWrapperInterface, convert_to_open_ai_messages
-
-
-class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface):
-
-    def __init__(
-        self,
-        model: str,
-        endpoint: str,
-        azure_deployment: str,
-        api_version: str,
-        api_key: str,
-    ):
-        self.client = AzureOpenAI(
-            azure_endpoint=endpoint,
-            azure_deployment=azure_deployment,
-            api_version=api_version,
-            api_key=api_key,
-        )
-        self.model = model
-
-    def chat_completion(
-        self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=convert_to_open_ai_messages(messages),
-        )
-        return response
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 5ca682b8..9dc8131a 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -16,7 +16,7 @@ def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
 
 
 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(role=message.role, message_text=message.content)
+    return IrisMessage(role=message["role"], message_text=message["content"])
 
 
 class OllamaWrapper(
@@ -43,4 +43,7 @@ def chat_completion(
 
     def create_embedding(self, text: str) -> list[float]:
         response = self.client.embeddings(model=self.model, prompt=text)
-        return response
+        return list(response)
+
+    def __str__(self):
+        return f"Ollama('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index db04de18..a21a9951 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,4 +1,5 @@
 from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
 from domain import IrisMessage
@@ -25,11 +26,44 @@ def __init__(self, model: str, api_key: str):
         self.client = OpenAI(api_key=api_key)
         self.model = model
 
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> any:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
         )
         return response
+
+    def __str__(self):
+        return f"OpenAIChat('{self.model}')"
+
+
+class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureChat('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
new file mode 100644
index 00000000..1d2eded3
--- /dev/null
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -0,0 +1,52 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm import CompletionArguments
+from llm.wrapper import LlmCompletionWrapperInterface
+
+
+class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> any:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
+        )
+        return response
+
+    def __str__(self):
+        return f"OpenAICompletion('{self.model}')"
+
+
+class AzureCompletionWrapper(OpenAICompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureCompletion('{self.model}')"
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
new file mode 100644
index 00000000..4983d3ee
--- /dev/null
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -0,0 +1,51 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm.wrapper import (
+    LlmEmbeddingWrapperInterface,
+)
+
+
+class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings.create(
+            model=self.model,
+            input=text,
+            encoding_format="float",
+        )
+        return response.data[0].embedding
+
+    def __str__(self):
+        return f"OpenAIEmbedding('{self.model}')"
+
+
+class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureEmbedding('{self.model}')"
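
Note on the new Azure subclasses: each builds an AzureOpenAI client and hands
it to the client-injecting __init__ of its OpenAI base class. Since Python
keeps only the last __init__ definition, the api_key overload above it is
shadowed, which is what the F811 ignores in .flake8 acknowledge. A minimal
construction sketch; endpoint, deployment, api_version and key are
illustrative placeholder values, not defaults:

    from llm.wrapper.open_ai_chat_wrapper import AzureChatCompletionWrapper

    chat = AzureChatCompletionWrapper(
        model="gpt-4",                                 # placeholder model id
        endpoint="https://example.openai.azure.com/",  # placeholder endpoint
        azure_deployment="example-deployment",
        api_version="2024-02-01",
        api_key="<your-azure-openai-key>",
    )
    print(chat)  # -> AzureChat('gpt-4')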