From fbb403a83d044779cdb2d2a047e0e48bfb98d62c Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Mon, 12 Feb 2024 11:46:50 +0100 Subject: [PATCH] `LLM`: Use pydantic for config parsing (#52) --- app/llm/basic_request_handler.py | 2 +- app/llm/llm_manager.py | 80 ++----------------- app/llm/wrapper/__init__.py | 27 ++++++- app/llm/wrapper/abstract_llm_wrapper.py | 8 +- app/llm/wrapper/ollama_wrapper.py | 20 +++-- app/llm/wrapper/open_ai_chat_wrapper.py | 51 +++++------- app/llm/wrapper/open_ai_completion_wrapper.py | 48 +++++------ app/llm/wrapper/open_ai_embedding_wrapper.py | 50 +++++------- requirements.txt | 1 + 9 files changed, 107 insertions(+), 180 deletions(-) diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index f348da1f..001d2dbb 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -1,7 +1,7 @@ from domain import IrisMessage from llm import RequestHandlerInterface, CompletionArguments from llm.llm_manager import LlmManager -from llm.wrapper import ( +from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmCompletionWrapper, AbstractLlmChatCompletionWrapper, AbstractLlmEmbeddingWrapper, diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py index 49a56f30..af593d32 100644 --- a/app/llm/llm_manager.py +++ b/app/llm/llm_manager.py @@ -1,82 +1,16 @@ import os +from pydantic import BaseModel, Field + import yaml from common import Singleton -from llm.wrapper import AbstractLlmWrapper - - -# TODO: Replace with pydantic in a future PR -def create_llm_wrapper(config: dict) -> AbstractLlmWrapper: - if config["type"] == "openai": - from llm.wrapper import OpenAICompletionWrapper - - return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"]) - elif config["type"] == "azure": - from llm.wrapper import AzureCompletionWrapper - - return AzureCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "openai_chat": - from llm.wrapper import OpenAIChatCompletionWrapper - - return OpenAIChatCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - api_key=config["api_key"], - ) - elif config["type"] == "azure_chat": - from llm.wrapper import AzureChatCompletionWrapper - - return AzureChatCompletionWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "openai_embedding": - from llm.wrapper import OpenAIEmbeddingWrapper - - return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"]) - elif config["type"] == "azure_embedding": - from llm.wrapper import AzureEmbeddingWrapper +from llm.wrapper import AbstractLlmWrapper, LlmWrapper - return AzureEmbeddingWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - endpoint=config["endpoint"], - azure_deployment=config["azure_deployment"], - api_version=config["api_version"], - api_key=config["api_key"], - ) - elif config["type"] == "ollama": - from llm.wrapper import OllamaWrapper - return OllamaWrapper( - id=config["id"], - name=config["name"], - description=config["description"], - model=config["model"], - host=config["host"], - ) - else: - raise Exception(f"Unknown LLM type: {config['type']}") +# Small workaround to get pydantic discriminators working +class LlmList(BaseModel): + llms: list[LlmWrapper] = Field(discriminator="type") class LlmManager(metaclass=Singleton): @@ -99,4 +33,4 @@ def load_llms(self): with open(path, "r") as file: loaded_llms = yaml.safe_load(file) - self.entries = [create_llm_wrapper(llm) for llm in loaded_llms] + self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py index 7e0dabff..c4807ec5 100644 --- a/app/llm/wrapper/__init__.py +++ b/app/llm/wrapper/__init__.py @@ -1,5 +1,24 @@ -from llm.wrapper.abstract_llm_wrapper import * -from llm.wrapper.open_ai_completion_wrapper import * -from llm.wrapper.open_ai_chat_wrapper import * -from llm.wrapper.open_ai_embedding_wrapper import * +from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper +from llm.wrapper.open_ai_completion_wrapper import ( + OpenAICompletionWrapper, + AzureCompletionWrapper, +) +from llm.wrapper.open_ai_chat_wrapper import ( + OpenAIChatCompletionWrapper, + AzureChatCompletionWrapper, +) +from llm.wrapper.open_ai_embedding_wrapper import ( + OpenAIEmbeddingWrapper, + AzureEmbeddingWrapper, +) from llm.wrapper.ollama_wrapper import OllamaWrapper + +type LlmWrapper = ( + OpenAICompletionWrapper + | AzureCompletionWrapper + | OpenAIChatCompletionWrapper + | AzureChatCompletionWrapper + | OpenAIEmbeddingWrapper + | AzureEmbeddingWrapper + | OllamaWrapper +) diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py index 6d5e353e..057b3aca 100644 --- a/app/llm/wrapper/abstract_llm_wrapper.py +++ b/app/llm/wrapper/abstract_llm_wrapper.py @@ -1,21 +1,17 @@ from abc import ABCMeta, abstractmethod +from pydantic import BaseModel from domain import IrisMessage from llm import CompletionArguments -class AbstractLlmWrapper(metaclass=ABCMeta): +class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta): """Abstract class for the llm wrappers""" id: str name: str description: str - def __init__(self, id: str, name: str, description: str): - self.id = id - self.name = name - self.description = description - class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): """Abstract class for the llm completion wrappers""" diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py index 9ce8e94b..4ea0e9b0 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/wrapper/ollama_wrapper.py @@ -1,8 +1,10 @@ +from typing import Literal, Any + from ollama import Client, Message from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper import ( +from llm.wrapper.abstract_llm_wrapper import ( AbstractLlmChatCompletionWrapper, AbstractLlmCompletionWrapper, AbstractLlmEmbeddingWrapper, @@ -24,26 +26,28 @@ class OllamaWrapper( AbstractLlmChatCompletionWrapper, AbstractLlmEmbeddingWrapper, ): + type: Literal["ollama"] + model: str + host: str + _client: Client - def __init__(self, model: str, host: str, **kwargs): - super().__init__(**kwargs) - self.client = Client(host=host) # TODO: Add authentication (httpx auth?) - self.model = model + def model_post_init(self, __context: Any) -> None: + self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?) def completion(self, prompt: str, arguments: CompletionArguments) -> str: - response = self.client.generate(model=self.model, prompt=prompt) + response = self._client.generate(model=self.model, prompt=prompt) return response["response"] def chat_completion( self, messages: list[any], arguments: CompletionArguments ) -> any: - response = self.client.chat( + response = self._client.chat( model=self.model, messages=convert_to_ollama_messages(messages) ) return convert_to_iris_message(response["message"]) def create_embedding(self, text: str) -> list[float]: - response = self.client.embeddings(model=self.model, prompt=text) + response = self._client.embeddings(model=self.model, prompt=text) return list(response) def __str__(self): diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py index c6b68e25..6a605ad5 100644 --- a/app/llm/wrapper/open_ai_chat_wrapper.py +++ b/app/llm/wrapper/open_ai_chat_wrapper.py @@ -1,9 +1,11 @@ +from typing import Literal, Any +from openai import OpenAI from openai.lib.azure import AzureOpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper import AbstractLlmChatCompletionWrapper +from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper def convert_to_open_ai_messages( @@ -21,16 +23,14 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage: class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper): - - def __init__(self, client, model: str, **kwargs): - super().__init__(**kwargs) - self.client = client - self.model = model + model: str + api_key: str + _client: OpenAI def chat_completion( self, messages: list[any], arguments: CompletionArguments ) -> IrisMessage: - response = self.client.chat.completions.create( + response = self._client.chat.completions.create( model=self.model, messages=convert_to_open_ai_messages(messages), temperature=arguments.temperature, @@ -41,37 +41,28 @@ def chat_completion( class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): + type: Literal["openai_chat"] - def __init__(self, model: str, api_key: str, **kwargs): - from openai import OpenAI - - client = OpenAI(api_key=api_key) - model = model - super().__init__(client, model, **kwargs) + def model_post_init(self, __context: Any) -> None: + self._client = OpenAI(api_key=self.api_key) def __str__(self): return f"OpenAIChat('{self.model}')" class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): - - def __init__( - self, - model: str, - endpoint: str, - azure_deployment: str, - api_version: str, - api_key: str, - **kwargs, - ): - client = AzureOpenAI( - azure_endpoint=endpoint, - azure_deployment=azure_deployment, - api_version=api_version, - api_key=api_key, + type: Literal["azure_chat"] + endpoint: str + azure_deployment: str + api_version: str + + def model_post_init(self, __context: Any) -> None: + self._client = AzureOpenAI( + azure_endpoint=self.endpoint, + azure_deployment=self.azure_deployment, + api_version=self.api_version, + api_key=self.api_key, ) - model = model - super().__init__(client, model, **kwargs) def __str__(self): return f"AzureChat('{self.model}')" diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py index daac194a..22fe4ed2 100644 --- a/app/llm/wrapper/open_ai_completion_wrapper.py +++ b/app/llm/wrapper/open_ai_completion_wrapper.py @@ -1,19 +1,18 @@ +from typing import Literal, Any from openai import OpenAI from openai.lib.azure import AzureOpenAI from llm import CompletionArguments -from llm.wrapper import AbstractLlmCompletionWrapper +from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper): - - def __init__(self, client, model: str, **kwargs): - super().__init__(**kwargs) - self.client = client - self.model = model + model: str + api_key: str + _client: OpenAI def completion(self, prompt: str, arguments: CompletionArguments) -> any: - response = self.client.completions.create( + response = self._client.completions.create( model=self.model, prompt=prompt, temperature=arguments.temperature, @@ -24,35 +23,28 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any: class OpenAICompletionWrapper(BaseOpenAICompletionWrapper): + type: Literal["openai_completion"] - def __init__(self, model: str, api_key: str, **kwargs): - client = OpenAI(api_key=api_key) - model = model - super().__init__(client, model, **kwargs) + def model_post_init(self, __context: Any) -> None: + self._client = OpenAI(api_key=self.api_key) def __str__(self): return f"OpenAICompletion('{self.model}')" class AzureCompletionWrapper(BaseOpenAICompletionWrapper): - - def __init__( - self, - model: str, - endpoint: str, - azure_deployment: str, - api_version: str, - api_key: str, - **kwargs, - ): - client = AzureOpenAI( - azure_endpoint=endpoint, - azure_deployment=azure_deployment, - api_version=api_version, - api_key=api_key, + type: Literal["azure_completion"] + endpoint: str + azure_deployment: str + api_version: str + + def model_post_init(self, __context: Any) -> None: + self._client = AzureOpenAI( + azure_endpoint=self.endpoint, + azure_deployment=self.azure_deployment, + api_version=self.api_version, + api_key=self.api_key, ) - model = model - super().__init__(client, model, **kwargs) def __str__(self): return f"AzureCompletion('{self.model}')" diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py index 88b425bd..99c397c9 100644 --- a/app/llm/wrapper/open_ai_embedding_wrapper.py +++ b/app/llm/wrapper/open_ai_embedding_wrapper.py @@ -1,20 +1,17 @@ +from typing import Literal, Any from openai import OpenAI from openai.lib.azure import AzureOpenAI -from llm.wrapper import ( - AbstractLlmEmbeddingWrapper, -) +from llm.wrapper.abstract_llm_wrapper import AbstractLlmEmbeddingWrapper class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper): - - def __init__(self, client, model: str, **kwargs): - super().__init__(**kwargs) - self.client = client - self.model = model + model: str + api_key: str + _client: OpenAI def create_embedding(self, text: str) -> list[float]: - response = self.client.embeddings.create( + response = self._client.embeddings.create( model=self.model, input=text, encoding_format="float", @@ -23,35 +20,28 @@ def create_embedding(self, text: str) -> list[float]: class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): + type: Literal["openai_embedding"] - def __init__(self, model: str, api_key: str, **kwargs): - client = OpenAI(api_key=api_key) - model = model - super().__init__(client, model, **kwargs) + def model_post_init(self, __context: Any) -> None: + self._client = OpenAI(api_key=self.api_key) def __str__(self): return f"OpenAIEmbedding('{self.model}')" class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): - - def __init__( - self, - model: str, - endpoint: str, - azure_deployment: str, - api_version: str, - api_key: str, - **kwargs, - ): - client = AzureOpenAI( - azure_endpoint=endpoint, - azure_deployment=azure_deployment, - api_version=api_version, - api_key=api_key, + type: Literal["azure_embedding"] + endpoint: str + azure_deployment: str + api_version: str + + def model_post_init(self, __context: Any) -> None: + self._client = AzureOpenAI( + azure_endpoint=self.endpoint, + azure_deployment=self.azure_deployment, + api_version=self.api_version, + api_key=self.api_key, ) - model = model - super().__init__(client, model, **kwargs) def __str__(self): return f"AzureEmbedding('{self.model}')" diff --git a/requirements.txt b/requirements.txt index 0b8fabbb..71c9e37e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ uvicorn==0.23.1 black==24.1.1 flake8==7.0.0 pre-commit==3.6.0 +pydantic==2.6.1