From 626a92c1f5492df53780e82ab8bebbb1403b0357 Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Wed, 14 Feb 2024 17:24:27 +0100 Subject: [PATCH] `LLM`: Add Langchain LLM wrapper classes (#57) --- app/llm/basic_request_handler.py | 6 +- app/llm/langchain/__init__.py | 3 + .../langchain/iris_langchain_chat_model.py | 60 +++++++++++++++++++ .../iris_langchain_completion_model.py | 37 ++++++++++++ app/llm/langchain/iris_langchain_embedding.py | 20 +++++++ app/llm/request_handler_interface.py | 2 +- requirements.txt | 1 + 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 app/llm/langchain/__init__.py create mode 100644 app/llm/langchain/iris_langchain_chat_model.py create mode 100644 app/llm/langchain/iris_langchain_completion_model.py create mode 100644 app/llm/langchain/iris_langchain_embedding.py diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index 001d2dbb..12079901 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel): self.llm_manager = LlmManager() def completion(self, prompt: str, arguments: CompletionArguments) -> str: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmCompletionWrapper): return llm.completion(prompt, arguments) else: @@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: def chat_completion( self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmChatCompletionWrapper): return llm.chat_completion(messages, arguments) else: @@ -39,7 +39,7 @@ def chat_completion( ) def create_embedding(self, text: str) -> list[float]: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmEmbeddingWrapper): return llm.create_embedding(text) else: diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py new file mode 100644 index 00000000..4deb1372 --- /dev/null +++ b/app/llm/langchain/__init__.py @@ -0,0 +1,3 @@ +from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel +from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel +from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py new file mode 100644 index 00000000..d0e558fa --- /dev/null +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -0,0 +1,60 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import ( + BaseChatModel, +) +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult +from langchain_core.outputs.chat_generation import ChatGeneration + +from domain import IrisMessage, IrisMessageRole +from llm import RequestHandlerInterface, CompletionArguments + + +def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: + role_map = { + IrisMessageRole.USER: "human", + IrisMessageRole.ASSISTANT: "ai", + IrisMessageRole.SYSTEM: "system", + } + return BaseMessage(content=iris_message.text, type=role_map[iris_message.role]) + + +def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: + role_map = { + "human": IrisMessageRole.USER, + "ai": IrisMessageRole.ASSISTANT, + "system": IrisMessageRole.SYSTEM, + } + return IrisMessage( + text=base_message.content, role=IrisMessageRole(role_map[base_message.type]) + ) + + +class IrisLangchainChatModel(BaseChatModel): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> ChatResult: + iris_messages = [convert_base_message_to_iris_message(m) for m in messages] + iris_message = self.request_handler.chat_completion( + iris_messages, CompletionArguments(stop=stop) + ) + base_message = convert_iris_message_to_base_message(iris_message) + chat_generation = ChatGeneration(message=base_message) + return ChatResult(generations=[chat_generation]) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py new file mode 100644 index 00000000..8a8b2e95 --- /dev/null +++ b/app/llm/langchain/iris_langchain_completion_model.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import LLMResult +from langchain_core.outputs.generation import Generation + +from llm import RequestHandlerInterface, CompletionArguments + + +class IrisLangchainCompletionModel(BaseLLM): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> LLMResult: + generations = [] + args = CompletionArguments(stop=stop) + for prompt in prompts: + completion = self.request_handler.completion( + prompt=prompt, arguments=args, **kwargs + ) + generations.append([Generation(text=completion)]) + return LLMResult(generations=generations) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py new file mode 100644 index 00000000..01a840ff --- /dev/null +++ b/app/llm/langchain/iris_langchain_embedding.py @@ -0,0 +1,20 @@ +from typing import List, Any + +from langchain_core.embeddings import Embeddings + +from llm import RequestHandlerInterface + + +class IrisLangchainEmbeddingModel(Embeddings): + """Custom langchain embedding for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + return self.request_handler.create_embedding(text) diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py index 5c15df30..02a3f21c 100644 --- a/app/llm/request_handler_interface.py +++ b/app/llm/request_handler_interface.py @@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: @abstractmethod def chat_completion( self, messages: list[any], arguments: CompletionArguments - ) -> [IrisMessage]: + ) -> IrisMessage: """Create a completion from the chat messages""" raise NotImplementedError diff --git a/requirements.txt b/requirements.txt index 3b4afc16..f82b5836 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ black==24.1.1 flake8==7.0.0 pre-commit==3.6.1 pydantic==2.6.1 +langchain==0.1.6