diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py new file mode 100644 index 00000000..1f75540b --- /dev/null +++ b/app/llm/langchain/__init__.py @@ -0,0 +1 @@ +from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py new file mode 100644 index 00000000..41f0df4e --- /dev/null +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -0,0 +1,46 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import ( + BaseChatModel, +) +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult + +from domain import IrisMessage +from llm import RequestHandlerInterface, CompletionArguments + + +def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: + return BaseMessage(content=iris_message.text, role=iris_message.role) + + +def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: + return IrisMessage(text=base_message.content, role=base_message.role) + + +class IrisLangchainChatModel(BaseChatModel): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.request_handler = request_handler + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> ChatResult: + iris_message = self.request_handler.chat_completion( + messages, CompletionArguments(stop=stop) + ) + base_message = convert_iris_message_to_base_message(iris_message) + return ChatResult(generations=[base_message]) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py index 5c15df30..02a3f21c 100644 --- a/app/llm/request_handler_interface.py +++ b/app/llm/request_handler_interface.py @@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: @abstractmethod def chat_completion( self, messages: list[any], arguments: CompletionArguments - ) -> [IrisMessage]: + ) -> IrisMessage: """Create a completion from the chat messages""" raise NotImplementedError diff --git a/requirements.txt b/requirements.txt index 71c9e37e..2796179f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ black==24.1.1 flake8==7.0.0 pre-commit==3.6.0 pydantic==2.6.1 +langchain==0.1.6