Skip to content

Commit

Permalink
WIP: Add langchain chat model wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus committed Feb 11, 2024
1 parent b06d20f commit 4af8a5f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions app/llm/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
46 changes: 46 additions & 0 deletions app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
    """Convert a domain ``IrisMessage`` into a LangChain message.

    Returns a ``ChatMessage`` (a ``BaseMessage`` subclass) rather than
    instantiating ``BaseMessage`` directly: ``BaseMessage`` has a required
    ``type`` field and no ``role`` field, so ``BaseMessage(role=...)`` fails
    pydantic validation. ``ChatMessage`` is LangChain's message type for
    arbitrary roles and carries ``role`` explicitly.
    """
    # NOTE(review): assumes iris_message.role is a plain string (or coerces to
    # one) acceptable to LangChain — confirm against the domain model.
    return ChatMessage(content=iris_message.text, role=iris_message.role)


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
    """Convert a LangChain message back into a domain ``IrisMessage``.

    Only ``ChatMessage`` instances carry a ``role`` attribute; the common
    subclasses (``HumanMessage``, ``AIMessage``, ``SystemMessage``) expose the
    role concept via ``type`` instead. Fall back to ``type`` so this does not
    raise ``AttributeError`` for those messages.
    """
    role = getattr(base_message, "role", base_message.type)
    return IrisMessage(text=base_message.content, role=role)


class IrisLangchainChatModel(BaseChatModel):
    """Custom LangChain chat model that delegates to our own request handler.

    Adapts ``RequestHandlerInterface.chat_completion`` to the LangChain
    ``BaseChatModel`` contract so the handler can be used in LangChain chains.
    """

    # The backend that actually performs the chat completion.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        # BaseChatModel is a pydantic model: required fields must be passed to
        # super().__init__ for validation; assigning self.request_handler
        # afterwards would fail because the field is required at init time.
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run a chat completion through the request handler.

        :param messages: conversation so far, as LangChain messages
        :param stop: optional stop sequences forwarded to the backend
        :param run_manager: LangChain callback manager (unused here)
        :return: a ``ChatResult`` with a single generation
        """
        # NOTE(review): the handler interface types messages as list[any];
        # converting to IrisMessage here matches the domain converters defined
        # in this module — confirm the concrete handlers expect IrisMessage.
        iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
        iris_message = self.request_handler.chat_completion(
            iris_messages, CompletionArguments(stop=stop)
        )
        base_message = convert_iris_message_to_base_message(iris_message)
        # ChatResult.generations must hold ChatGeneration wrappers, not raw
        # BaseMessage objects — passing a bare message fails validation.
        return ChatResult(generations=[ChatGeneration(message=base_message)])

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization/telemetry."""
        return "Iris"
2 changes: 1 addition & 1 deletion app/llm/request_handler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
@abstractmethod
def chat_completion(
self, messages: list[any], arguments: CompletionArguments
) -> [IrisMessage]:
) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ black==24.1.1
flake8==7.0.0
pre-commit==3.6.0
pydantic==2.6.1
langchain==0.1.6

0 comments on commit 4af8a5f

Please sign in to comment.