diff --git a/app/common/__init__.py b/app/common/__init__.py index 97e30c68..3f77d2e2 100644 --- a/app/common/__init__.py +++ b/app/common/__init__.py @@ -1 +1,5 @@ from common.singleton import Singleton +from common.message_converters import ( + convert_iris_message_to_langchain_message, + convert_langchain_message_to_iris_message, +) diff --git a/app/common/message_converters.py b/app/common/message_converters.py new file mode 100644 index 00000000..6835ec11 --- /dev/null +++ b/app/common/message_converters.py @@ -0,0 +1,29 @@ +from langchain_core.messages import BaseMessage + +from domain import IrisMessage, IrisMessageRole + + +def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> BaseMessage: + match iris_message.role: + case IrisMessageRole.USER: + role = "human" + case IrisMessageRole.ASSISTANT: + role = "ai" + case IrisMessageRole.SYSTEM: + role = "system" + case _: + raise ValueError(f"Unknown message role: {iris_message.role}") + return BaseMessage(content=iris_message.text, type=role) + + +def convert_langchain_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: + match base_message.type: + case "human": + role = IrisMessageRole.USER + case "ai": + role = IrisMessageRole.ASSISTANT + case "system": + role = IrisMessageRole.SYSTEM + case _: + raise ValueError(f"Unknown message type: {base_message.type}") + return IrisMessage(text=base_message.content, role=role) diff --git a/app/llm/__init__.py b/app/llm/__init__.py index aa60d467..ed54099b 100644 --- a/app/llm/__init__.py +++ b/app/llm/__init__.py @@ -1,3 +1,3 @@ -from llm.request_handler_interface import RequestHandler from llm.completion_arguments import * -from llm.basic_request_handler import BasicRequestHandler +from llm.external import * +from llm.request_handler import * diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py index ec78e963..b824c64a 100644 --- a/app/llm/langchain/iris_langchain_chat_model.py +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -8,30 +8,13 @@ from langchain_core.outputs import ChatResult from langchain_core.outputs.chat_generation import ChatGeneration -from domain import IrisMessage, IrisMessageRole +from common import ( + convert_iris_message_to_langchain_message, + convert_langchain_message_to_iris_message, +) from llm import RequestHandler, CompletionArguments -def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: - role_map = { - IrisMessageRole.USER: "human", - IrisMessageRole.ASSISTANT: "ai", - IrisMessageRole.SYSTEM: "system", - } - return BaseMessage(content=iris_message.text, type=role_map[iris_message.role]) - - -def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: - role_map = { - "human": IrisMessageRole.USER, - "ai": IrisMessageRole.ASSISTANT, - "system": IrisMessageRole.SYSTEM, - } - return IrisMessage( - text=base_message.content, role=IrisMessageRole(role_map[base_message.type]) - ) - - class IrisLangchainChatModel(BaseChatModel): """Custom langchain chat model for our own request handler""" @@ -47,11 +30,11 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> ChatResult: - iris_messages = [convert_base_message_to_iris_message(m) for m in messages] + iris_messages = [convert_langchain_message_to_iris_message(m) for m in messages] iris_message = self.request_handler.chat( iris_messages, CompletionArguments(stop=stop) ) - base_message = convert_iris_message_to_base_message(iris_message) + base_message = convert_iris_message_to_langchain_message(iris_message) chat_generation = ChatGeneration(message=base_message) return ChatResult(generations=[chat_generation]) diff --git a/app/llm/request_handler/__init__.py b/app/llm/request_handler/__init__.py new file mode 100644 index 00000000..8ad46295 --- /dev/null +++ b/app/llm/request_handler/__init__.py @@ -0,0 +1,2 @@ +from basic_request_handler import BasicRequestHandler +from request_handler_interface import RequestHandler diff --git a/app/llm/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py similarity index 100% rename from app/llm/basic_request_handler.py rename to app/llm/request_handler/basic_request_handler.py diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler/request_handler_interface.py similarity index 100% rename from app/llm/request_handler_interface.py rename to app/llm/request_handler/request_handler_interface.py