From 626a92c1f5492df53780e82ab8bebbb1403b0357 Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Wed, 14 Feb 2024 17:24:27 +0100 Subject: [PATCH 1/2] `LLM`: Add Langchain LLM wrapper classes (#57) --- app/llm/basic_request_handler.py | 6 +- app/llm/langchain/__init__.py | 3 + .../langchain/iris_langchain_chat_model.py | 60 +++++++++++++++++++ .../iris_langchain_completion_model.py | 37 ++++++++++++ app/llm/langchain/iris_langchain_embedding.py | 20 +++++++ app/llm/request_handler_interface.py | 2 +- requirements.txt | 1 + 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 app/llm/langchain/__init__.py create mode 100644 app/llm/langchain/iris_langchain_chat_model.py create mode 100644 app/llm/langchain/iris_langchain_completion_model.py create mode 100644 app/llm/langchain/iris_langchain_embedding.py diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index 001d2dbb..12079901 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel): self.llm_manager = LlmManager() def completion(self, prompt: str, arguments: CompletionArguments) -> str: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmCompletionWrapper): return llm.completion(prompt, arguments) else: @@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: def chat_completion( self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmChatCompletionWrapper): return llm.chat_completion(messages, arguments) else: @@ -39,7 +39,7 @@ def chat_completion( ) def create_embedding(self, text: str) -> list[float]: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmEmbeddingWrapper): return llm.create_embedding(text) else: diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py new file mode 100644 index 00000000..4deb1372 --- /dev/null +++ b/app/llm/langchain/__init__.py @@ -0,0 +1,3 @@ +from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel +from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel +from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py new file mode 100644 index 00000000..d0e558fa --- /dev/null +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -0,0 +1,60 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import ( + BaseChatModel, +) +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult +from langchain_core.outputs.chat_generation import ChatGeneration + +from domain import IrisMessage, IrisMessageRole +from llm import RequestHandlerInterface, CompletionArguments + + +def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: + role_map = { + IrisMessageRole.USER: "human", + IrisMessageRole.ASSISTANT: "ai", + IrisMessageRole.SYSTEM: "system", + } + return BaseMessage(content=iris_message.text, type=role_map[iris_message.role]) + + +def 
convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: + role_map = { + "human": IrisMessageRole.USER, + "ai": IrisMessageRole.ASSISTANT, + "system": IrisMessageRole.SYSTEM, + } + return IrisMessage( + text=base_message.content, role=IrisMessageRole(role_map[base_message.type]) + ) + + +class IrisLangchainChatModel(BaseChatModel): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> ChatResult: + iris_messages = [convert_base_message_to_iris_message(m) for m in messages] + iris_message = self.request_handler.chat_completion( + iris_messages, CompletionArguments(stop=stop) + ) + base_message = convert_iris_message_to_base_message(iris_message) + chat_generation = ChatGeneration(message=base_message) + return ChatResult(generations=[chat_generation]) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py new file mode 100644 index 00000000..8a8b2e95 --- /dev/null +++ b/app/llm/langchain/iris_langchain_completion_model.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import LLMResult +from langchain_core.outputs.generation import Generation + +from llm import RequestHandlerInterface, CompletionArguments + + +class IrisLangchainCompletionModel(BaseLLM): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> LLMResult: + generations = [] + args = CompletionArguments(stop=stop) + for prompt in prompts: + completion = self.request_handler.completion( + prompt=prompt, arguments=args, **kwargs + ) + generations.append([Generation(text=completion)]) + return LLMResult(generations=generations) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py new file mode 100644 index 00000000..01a840ff --- /dev/null +++ b/app/llm/langchain/iris_langchain_embedding.py @@ -0,0 +1,20 @@ +from typing import List, Any + +from langchain_core.embeddings import Embeddings + +from llm import RequestHandlerInterface + + +class IrisLangchainEmbeddingModel(Embeddings): + """Custom langchain embedding for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + return self.request_handler.create_embedding(text) diff --git 
a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py index 5c15df30..02a3f21c 100644 --- a/app/llm/request_handler_interface.py +++ b/app/llm/request_handler_interface.py @@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: @abstractmethod def chat_completion( self, messages: list[any], arguments: CompletionArguments - ) -> [IrisMessage]: + ) -> IrisMessage: """Create a completion from the chat messages""" raise NotImplementedError diff --git a/requirements.txt b/requirements.txt index 3b4afc16..f82b5836 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ black==24.1.1 flake8==7.0.0 pre-commit==3.6.1 pydantic==2.6.1 +langchain==0.1.6 From 5fc091bc00227d2e0dafe6e037a82bf5c52d77e6 Mon Sep 17 00:00:00 2001 From: Michael Dyer <59163924+MichaelOwenDyer@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:05:29 +0100 Subject: [PATCH 2/2] Refactoring suggestions (#59) --- app/llm/__init__.py | 6 +- app/llm/basic_request_handler.py | 50 +++++----------- ...n_arguments.py => completion_arguments.py} | 0 app/llm/external/__init__.py | 21 +++++++ app/llm/external/model.py | 60 +++++++++++++++++++ .../ollama_wrapper.py => external/ollama.py} | 24 ++++---- .../openai_chat.py} | 13 ++-- .../openai_completion.py} | 10 ++-- .../openai_embeddings.py} | 10 ++-- app/llm/langchain/__init__.py | 2 +- .../langchain/iris_langchain_chat_model.py | 8 +-- .../iris_langchain_completion_model.py | 10 ++-- ...g.py => iris_langchain_embedding_model.py} | 8 +-- app/llm/llm_manager.py | 6 +- app/llm/request_handler_interface.py | 26 ++++---- app/llm/wrapper/__init__.py | 24 -------- app/llm/wrapper/abstract_llm_wrapper.py | 58 ------------------ requirements.txt | 1 + 18 files changed, 154 insertions(+), 183 deletions(-) rename app/llm/{generation_arguments.py => completion_arguments.py} (100%) create mode 100644 app/llm/external/__init__.py create mode 100644 app/llm/external/model.py rename app/llm/{wrapper/ollama_wrapper.py => external/ollama.py} (69%) rename app/llm/{wrapper/open_ai_chat_wrapper.py => external/openai_chat.py} (82%) rename app/llm/{wrapper/open_ai_completion_wrapper.py => external/openai_completion.py} (77%) rename app/llm/{wrapper/open_ai_embedding_wrapper.py => external/openai_embeddings.py} (76%) rename app/llm/langchain/{iris_langchain_embedding.py => iris_langchain_embedding_model.py} (65%) delete mode 100644 app/llm/wrapper/__init__.py delete mode 100644 app/llm/wrapper/abstract_llm_wrapper.py diff --git a/app/llm/__init__.py b/app/llm/__init__.py index 33542f1c..aa06c47c 100644 --- a/app/llm/__init__.py +++ b/app/llm/__init__.py @@ -1,3 +1,3 @@ -from llm.request_handler_interface import RequestHandlerInterface -from llm.generation_arguments import * -from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel +from llm.request_handler_interface import RequestHandler +from llm.completion_arguments import * +from llm.basic_request_handler import BasicRequestHandler, DefaultModelId diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index 12079901..a5d2ca15 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -1,48 +1,26 @@ from domain import IrisMessage -from llm import RequestHandlerInterface, CompletionArguments +from llm import RequestHandler, CompletionArguments from llm.llm_manager import LlmManager -from llm.wrapper.abstract_llm_wrapper import ( - AbstractLlmCompletionWrapper, - AbstractLlmChatCompletionWrapper, - 
AbstractLlmEmbeddingWrapper, -) -type BasicRequestHandlerModel = str - -class BasicRequestHandler(RequestHandlerInterface): - model: BasicRequestHandlerModel +class BasicRequestHandler(RequestHandler): + model_id: str llm_manager: LlmManager - def __init__(self, model: BasicRequestHandlerModel): - self.model = model + def __init__(self, model_id: str): + self.model_id = model_id self.llm_manager = LlmManager() - def completion(self, prompt: str, arguments: CompletionArguments) -> str: - llm = self.llm_manager.get_llm_by_id(self.model) - if isinstance(llm, AbstractLlmCompletionWrapper): - return llm.completion(prompt, arguments) - else: - raise NotImplementedError( - f"The LLM {llm.__str__()} does not support completion" - ) + def complete(self, prompt: str, arguments: CompletionArguments) -> str: + llm = self.llm_manager.get_by_id(self.model_id) + return llm.complete(prompt, arguments) - def chat_completion( + def chat( self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: - llm = self.llm_manager.get_llm_by_id(self.model) - if isinstance(llm, AbstractLlmChatCompletionWrapper): - return llm.chat_completion(messages, arguments) - else: - raise NotImplementedError( - f"The LLM {llm.__str__()} does not support chat completion" - ) + llm = self.llm_manager.get_by_id(self.model_id) + return llm.chat(messages, arguments) - def create_embedding(self, text: str) -> list[float]: - llm = self.llm_manager.get_llm_by_id(self.model) - if isinstance(llm, AbstractLlmEmbeddingWrapper): - return llm.create_embedding(text) - else: - raise NotImplementedError( - f"The LLM {llm.__str__()} does not support embedding" - ) + def embed(self, text: str) -> list[float]: + llm = self.llm_manager.get_by_id(self.model_id) + return llm.embed(text) diff --git a/app/llm/generation_arguments.py b/app/llm/completion_arguments.py similarity index 100% rename from app/llm/generation_arguments.py rename to app/llm/completion_arguments.py diff --git a/app/llm/external/__init__.py b/app/llm/external/__init__.py new file mode 100644 index 00000000..62266b6f --- /dev/null +++ b/app/llm/external/__init__.py @@ -0,0 +1,21 @@ +from llm.external.model import LanguageModel +from llm.external.openai_completion import ( + DirectOpenAICompletionModel, + AzureOpenAICompletionModel, +) +from llm.external.openai_chat import DirectOpenAIChatModel, AzureOpenAIChatModel +from llm.external.openai_embeddings import ( + DirectOpenAIEmbeddingModel, + AzureOpenAIEmbeddingModel, +) +from llm.external.ollama import OllamaModel + +type AnyLLM = ( + DirectOpenAICompletionModel + | AzureOpenAICompletionModel + | DirectOpenAIChatModel + | AzureOpenAIChatModel + | DirectOpenAIEmbeddingModel + | AzureOpenAIEmbeddingModel + | OllamaModel +) diff --git a/app/llm/external/model.py b/app/llm/external/model.py new file mode 100644 index 00000000..c831009f --- /dev/null +++ b/app/llm/external/model.py @@ -0,0 +1,60 @@ +from abc import ABCMeta, abstractmethod +from pydantic import BaseModel + +from domain import IrisMessage +from llm import CompletionArguments + + +class LanguageModel(BaseModel, metaclass=ABCMeta): + """Abstract class for the llm wrappers""" + + id: str + name: str + description: str + + +class CompletionModel(LanguageModel, metaclass=ABCMeta): + """Abstract class for the llm completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass) -> bool: + return hasattr(subclass, "complete") and callable(subclass.complete) + + @abstractmethod + def complete(self, prompt: str, arguments: CompletionArguments) 
-> str: + """Create a completion from the prompt""" + raise NotImplementedError( + f"The LLM {self.__str__()} does not support completion" + ) + + +class ChatModel(LanguageModel, metaclass=ABCMeta): + """Abstract class for the llm chat completion wrappers""" + + @classmethod + def __subclasshook__(cls, subclass) -> bool: + return hasattr(subclass, "chat") and callable(subclass.chat) + + @abstractmethod + def chat( + self, messages: list[IrisMessage], arguments: CompletionArguments + ) -> IrisMessage: + """Create a completion from the chat messages""" + raise NotImplementedError( + f"The LLM {self.__str__()} does not support chat completion" + ) + + +class EmbeddingModel(LanguageModel, metaclass=ABCMeta): + """Abstract class for the llm embedding wrappers""" + + @classmethod + def __subclasshook__(cls, subclass) -> bool: + return hasattr(subclass, "embed") and callable(subclass.embed) + + @abstractmethod + def embed(self, text: str) -> list[float]: + """Create an embedding from the text""" + raise NotImplementedError( + f"The LLM {self.__str__()} does not support embeddings" + ) diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/external/ollama.py similarity index 69% rename from app/llm/wrapper/ollama_wrapper.py rename to app/llm/external/ollama.py index 4ea0e9b0..318a984d 100644 --- a/app/llm/wrapper/ollama_wrapper.py +++ b/app/llm/external/ollama.py @@ -4,11 +4,7 @@ from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper.abstract_llm_wrapper import ( - AbstractLlmChatCompletionWrapper, - AbstractLlmCompletionWrapper, - AbstractLlmEmbeddingWrapper, -) +from llm.external.model import ChatModel, CompletionModel, EmbeddingModel def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: @@ -21,10 +17,10 @@ def convert_to_iris_message(message: Message) -> IrisMessage: return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"]) -class OllamaWrapper( - AbstractLlmCompletionWrapper, - AbstractLlmChatCompletionWrapper, - AbstractLlmEmbeddingWrapper, +class OllamaModel( + CompletionModel, + ChatModel, + EmbeddingModel, ): type: Literal["ollama"] model: str @@ -34,19 +30,19 @@ class OllamaWrapper( def model_post_init(self, __context: Any) -> None: self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?) 
- def completion(self, prompt: str, arguments: CompletionArguments) -> str: + def complete(self, prompt: str, arguments: CompletionArguments) -> str: response = self._client.generate(model=self.model, prompt=prompt) return response["response"] - def chat_completion( - self, messages: list[any], arguments: CompletionArguments - ) -> any: + def chat( + self, messages: list[IrisMessage], arguments: CompletionArguments + ) -> IrisMessage: response = self._client.chat( model=self.model, messages=convert_to_ollama_messages(messages) ) return convert_to_iris_message(response["message"]) - def create_embedding(self, text: str) -> list[float]: + def embed(self, text: str) -> list[float]: response = self._client.embeddings(model=self.model, prompt=text) return list(response) diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/external/openai_chat.py similarity index 82% rename from app/llm/wrapper/open_ai_chat_wrapper.py rename to app/llm/external/openai_chat.py index 6a605ad5..652df527 100644 --- a/app/llm/wrapper/open_ai_chat_wrapper.py +++ b/app/llm/external/openai_chat.py @@ -1,11 +1,12 @@ from typing import Literal, Any + from openai import OpenAI from openai.lib.azure import AzureOpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage from domain import IrisMessage, IrisMessageRole from llm import CompletionArguments -from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper +from llm.external.model import ChatModel def convert_to_open_ai_messages( @@ -22,13 +23,13 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage: return IrisMessage(role=message_role, text=message.content) -class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper): +class OpenAIChatModel(ChatModel): model: str api_key: str _client: OpenAI - def chat_completion( - self, messages: list[any], arguments: CompletionArguments + def chat( + self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: response = self._client.chat.completions.create( model=self.model, @@ -40,7 +41,7 @@ def chat_completion( return convert_to_iris_message(response.choices[0].message) -class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): +class DirectOpenAIChatModel(OpenAIChatModel): type: Literal["openai_chat"] def model_post_init(self, __context: Any) -> None: @@ -50,7 +51,7 @@ def __str__(self): return f"OpenAIChat('{self.model}')" -class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper): +class AzureOpenAIChatModel(OpenAIChatModel): type: Literal["azure_chat"] endpoint: str azure_deployment: str diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/external/openai_completion.py similarity index 77% rename from app/llm/wrapper/open_ai_completion_wrapper.py rename to app/llm/external/openai_completion.py index 22fe4ed2..449d2c5b 100644 --- a/app/llm/wrapper/open_ai_completion_wrapper.py +++ b/app/llm/external/openai_completion.py @@ -3,15 +3,15 @@ from openai.lib.azure import AzureOpenAI from llm import CompletionArguments -from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper +from llm.external.model import CompletionModel -class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper): +class OpenAICompletionModel(CompletionModel): model: str api_key: str _client: OpenAI - def completion(self, prompt: str, arguments: CompletionArguments) -> any: + def complete(self, prompt: str, arguments: CompletionArguments) -> any: response = self._client.completions.create( 
model=self.model, prompt=prompt, @@ -22,7 +22,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any: return response -class OpenAICompletionWrapper(BaseOpenAICompletionWrapper): +class DirectOpenAICompletionModel(OpenAICompletionModel): type: Literal["openai_completion"] def model_post_init(self, __context: Any) -> None: @@ -32,7 +32,7 @@ def __str__(self): return f"OpenAICompletion('{self.model}')" -class AzureCompletionWrapper(BaseOpenAICompletionWrapper): +class AzureOpenAICompletionModel(OpenAICompletionModel): type: Literal["azure_completion"] endpoint: str azure_deployment: str diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/external/openai_embeddings.py similarity index 76% rename from app/llm/wrapper/open_ai_embedding_wrapper.py rename to app/llm/external/openai_embeddings.py index 99c397c9..66ceb0ba 100644 --- a/app/llm/wrapper/open_ai_embedding_wrapper.py +++ b/app/llm/external/openai_embeddings.py @@ -2,15 +2,15 @@ from openai import OpenAI from openai.lib.azure import AzureOpenAI -from llm.wrapper.abstract_llm_wrapper import AbstractLlmEmbeddingWrapper +from llm.external.model import EmbeddingModel -class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper): +class OpenAIEmbeddingModel(EmbeddingModel): model: str api_key: str _client: OpenAI - def create_embedding(self, text: str) -> list[float]: + def embed(self, text: str) -> list[float]: response = self._client.embeddings.create( model=self.model, input=text, @@ -19,7 +19,7 @@ def create_embedding(self, text: str) -> list[float]: return response.data[0].embedding -class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): +class DirectOpenAIEmbeddingModel(OpenAIEmbeddingModel): type: Literal["openai_embedding"] def model_post_init(self, __context: Any) -> None: @@ -29,7 +29,7 @@ def __str__(self): return f"OpenAIEmbedding('{self.model}')" -class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper): +class AzureOpenAIEmbeddingModel(OpenAIEmbeddingModel): type: Literal["azure_embedding"] endpoint: str azure_deployment: str diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py index 4deb1372..f887cf17 100644 --- a/app/llm/langchain/__init__.py +++ b/app/llm/langchain/__init__.py @@ -1,3 +1,3 @@ from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel -from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel +from llm.langchain.iris_langchain_embedding_model import IrisLangchainEmbeddingModel diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py index d0e558fa..ec78e963 100644 --- a/app/llm/langchain/iris_langchain_chat_model.py +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -9,7 +9,7 @@ from langchain_core.outputs.chat_generation import ChatGeneration from domain import IrisMessage, IrisMessageRole -from llm import RequestHandlerInterface, CompletionArguments +from llm import RequestHandler, CompletionArguments def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: @@ -35,9 +35,9 @@ def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessa class IrisLangchainChatModel(BaseChatModel): """Custom langchain chat model for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, 
request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def _generate( @@ -48,7 +48,7 @@ def _generate( **kwargs: Any ) -> ChatResult: iris_messages = [convert_base_message_to_iris_message(m) for m in messages] - iris_message = self.request_handler.chat_completion( + iris_message = self.request_handler.chat( iris_messages, CompletionArguments(stop=stop) ) base_message = convert_iris_message_to_base_message(iris_message) diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py index 8a8b2e95..b0d056e2 100644 --- a/app/llm/langchain/iris_langchain_completion_model.py +++ b/app/llm/langchain/iris_langchain_completion_model.py @@ -5,15 +5,15 @@ from langchain_core.outputs import LLMResult from langchain_core.outputs.generation import Generation -from llm import RequestHandlerInterface, CompletionArguments +from llm import RequestHandler, CompletionArguments class IrisLangchainCompletionModel(BaseLLM): """Custom langchain chat model for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def _generate( @@ -26,9 +26,7 @@ def _generate( generations = [] args = CompletionArguments(stop=stop) for prompt in prompts: - completion = self.request_handler.completion( - prompt=prompt, arguments=args, **kwargs - ) + completion = self.request_handler.complete(prompt=prompt, arguments=args) generations.append([Generation(text=completion)]) return LLMResult(generations=generations) diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding_model.py similarity index 65% rename from app/llm/langchain/iris_langchain_embedding.py rename to app/llm/langchain/iris_langchain_embedding_model.py index 01a840ff..504fe46f 100644 --- a/app/llm/langchain/iris_langchain_embedding.py +++ b/app/llm/langchain/iris_langchain_embedding_model.py @@ -2,19 +2,19 @@ from langchain_core.embeddings import Embeddings -from llm import RequestHandlerInterface +from llm import RequestHandler class IrisLangchainEmbeddingModel(Embeddings): """Custom langchain embedding for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: return [self.embed_query(text) for text in texts] def embed_query(self, text: str) -> List[float]: - return self.request_handler.create_embedding(text) + return self.request_handler.embed(text) diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py index af593d32..1abc02bd 100644 --- a/app/llm/llm_manager.py +++ b/app/llm/llm_manager.py @@ -5,16 +5,16 @@ import yaml from common import Singleton -from llm.wrapper import AbstractLlmWrapper, LlmWrapper +from llm.external import LanguageModel, AnyLLM # Small workaround to get pydantic discriminators working class LlmList(BaseModel): - llms: list[LlmWrapper] = Field(discriminator="type") + llms: list[AnyLLM] = Field(discriminator="type") class LlmManager(metaclass=Singleton): - entries: 
list[AbstractLlmWrapper] + entries: list[LanguageModel] def __init__(self): self.entries = [] diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py index 02a3f21c..16ac9646 100644 --- a/app/llm/request_handler_interface.py +++ b/app/llm/request_handler_interface.py @@ -1,36 +1,34 @@ from abc import ABCMeta, abstractmethod from domain import IrisMessage -from llm.generation_arguments import CompletionArguments +from llm.completion_arguments import CompletionArguments -class RequestHandlerInterface(metaclass=ABCMeta): +class RequestHandler(metaclass=ABCMeta): """Interface for the request handlers""" @classmethod - def __subclasshook__(cls, subclass): + def __subclasshook__(cls, subclass) -> bool: return ( - hasattr(subclass, "completion") - and callable(subclass.completion) - and hasattr(subclass, "chat_completion") - and callable(subclass.chat_completion) - and hasattr(subclass, "create_embedding") - and callable(subclass.create_embedding) + hasattr(subclass, "complete") + and callable(subclass.complete) + and hasattr(subclass, "chat") + and callable(subclass.chat) + and hasattr(subclass, "embed") + and callable(subclass.embed) ) @abstractmethod - def completion(self, prompt: str, arguments: CompletionArguments) -> str: + def complete(self, prompt: str, arguments: CompletionArguments) -> str: """Create a completion from the prompt""" raise NotImplementedError @abstractmethod - def chat_completion( - self, messages: list[any], arguments: CompletionArguments - ) -> IrisMessage: + def chat(self, messages: list[any], arguments: CompletionArguments) -> IrisMessage: """Create a completion from the chat messages""" raise NotImplementedError @abstractmethod - def create_embedding(self, text: str) -> list[float]: + def embed(self, text: str) -> list[float]: """Create an embedding from the text""" raise NotImplementedError diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py deleted file mode 100644 index c4807ec5..00000000 --- a/app/llm/wrapper/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper -from llm.wrapper.open_ai_completion_wrapper import ( - OpenAICompletionWrapper, - AzureCompletionWrapper, -) -from llm.wrapper.open_ai_chat_wrapper import ( - OpenAIChatCompletionWrapper, - AzureChatCompletionWrapper, -) -from llm.wrapper.open_ai_embedding_wrapper import ( - OpenAIEmbeddingWrapper, - AzureEmbeddingWrapper, -) -from llm.wrapper.ollama_wrapper import OllamaWrapper - -type LlmWrapper = ( - OpenAICompletionWrapper - | AzureCompletionWrapper - | OpenAIChatCompletionWrapper - | AzureChatCompletionWrapper - | OpenAIEmbeddingWrapper - | AzureEmbeddingWrapper - | OllamaWrapper -) diff --git a/app/llm/wrapper/abstract_llm_wrapper.py b/app/llm/wrapper/abstract_llm_wrapper.py deleted file mode 100644 index 057b3aca..00000000 --- a/app/llm/wrapper/abstract_llm_wrapper.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABCMeta, abstractmethod -from pydantic import BaseModel - -from domain import IrisMessage -from llm import CompletionArguments - - -class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta): - """Abstract class for the llm wrappers""" - - id: str - name: str - description: str - - -class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): - """Abstract class for the llm completion wrappers""" - - @classmethod - def __subclasshook__(cls, subclass): - return hasattr(subclass, "completion") and callable(subclass.completion) - - @abstractmethod - def completion(self, 
prompt: str, arguments: CompletionArguments) -> str: - """Create a completion from the prompt""" - raise NotImplementedError - - -class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta): - """Abstract class for the llm chat completion wrappers""" - - @classmethod - def __subclasshook__(cls, subclass): - return hasattr(subclass, "chat_completion") and callable( - subclass.chat_completion - ) - - @abstractmethod - def chat_completion( - self, messages: list[any], arguments: CompletionArguments - ) -> IrisMessage: - """Create a completion from the chat messages""" - raise NotImplementedError - - -class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta): - """Abstract class for the llm embedding wrappers""" - - @classmethod - def __subclasshook__(cls, subclass): - return hasattr(subclass, "create_embedding") and callable( - subclass.create_embedding - ) - - @abstractmethod - def create_embedding(self, text: str) -> list[float]: - """Create an embedding from the text""" - raise NotImplementedError diff --git a/requirements.txt b/requirements.txt index f82b5836..3e76f66c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ black==24.1.1 flake8==7.0.0 pre-commit==3.6.1 pydantic==2.6.1 +PyYAML==6.0.1 langchain==0.1.6
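
A minimal usage sketch of the wrappers this series introduces, written against the refactored API from PATCH 2/2 (RequestHandler, chat/complete/embed). The model id "gpt35" and the example question are illustrative placeholders, not values defined in this series — any id configured in the YAML that LlmManager loads would do; the wiring itself follows the classes added above.

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from llm import BasicRequestHandler
from llm.langchain import IrisLangchainChatModel

# Resolve requests through the LlmManager; "gpt35" is a placeholder id and
# must match an entry in the YAML config that LlmManager reads.
request_handler = BasicRequestHandler("gpt35")

# IrisLangchainChatModel implements BaseChatModel._generate by delegating to
# request_handler.chat(), so it composes with an ordinary LCEL pipeline:
chat_model = IrisLangchainChatModel(request_handler=request_handler)
prompt = ChatPromptTemplate.from_messages([("human", "{question}")])
chain = prompt | chat_model | StrOutputParser()

print(chain.invoke({"question": "What is an AVL tree?"}))

Because the LangChain wrapper only depends on a RequestHandler, the same pipeline runs unchanged against any backend in the AnyLLM union (OpenAI, Azure, or Ollama) — the choice is confined to the configuration entry the model id points at.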