diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py
new file mode 100644
index 00000000..4deb1372
--- /dev/null
+++ b/app/llm/langchain/__init__.py
@@ -0,0 +1,3 @@
+from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
+from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
+from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py
new file mode 100644
index 00000000..d0e558fa
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_chat_model.py
@@ -0,0 +1,60 @@
+from typing import List, Optional, Any
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+)
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
+from langchain_core.outputs.chat_generation import ChatGeneration
+
+from domain import IrisMessage, IrisMessageRole
+from llm import RequestHandlerInterface, CompletionArguments
+
+
+def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
+    role_map = {
+        IrisMessageRole.USER: "human",
+        IrisMessageRole.ASSISTANT: "ai",
+        IrisMessageRole.SYSTEM: "system",
+    }
+    return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])
+
+
+def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
+    role_map = {
+        "human": IrisMessageRole.USER,
+        "ai": IrisMessageRole.ASSISTANT,
+        "system": IrisMessageRole.SYSTEM,
+    }
+    return IrisMessage(
+        text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
+    )
+
+
+class IrisLangchainChatModel(BaseChatModel):
+    """Custom langchain chat model for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        super().__init__(request_handler=request_handler, **kwargs)
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ) -> ChatResult:
+        iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
+        iris_message = self.request_handler.chat_completion(
+            iris_messages, CompletionArguments(stop=stop)
+        )
+        base_message = convert_iris_message_to_base_message(iris_message)
+        chat_generation = ChatGeneration(message=base_message)
+        return ChatResult(generations=[chat_generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "Iris"
diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py
new file mode 100644
index 00000000..8a8b2e95
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_completion_model.py
@@ -0,0 +1,37 @@
+from typing import List, Optional, Any
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import LLMResult
+from langchain_core.outputs.generation import Generation
+
+from llm import RequestHandlerInterface, CompletionArguments
+
+
+class IrisLangchainCompletionModel(BaseLLM):
+    """Custom langchain completion model for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        super().__init__(request_handler=request_handler, **kwargs)
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ) -> LLMResult:
+        generations = []
+        args = CompletionArguments(stop=stop)
+        for prompt in prompts:
+            completion = self.request_handler.completion(
+                prompt=prompt, arguments=args, **kwargs
+            )
+            generations.append([Generation(text=completion)])
+        return LLMResult(generations=generations)
+
+    @property
+    def _llm_type(self) -> str:
+        return "Iris"
diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py
new file mode 100644
index 00000000..01a840ff
--- /dev/null
+++ b/app/llm/langchain/iris_langchain_embedding.py
@@ -0,0 +1,22 @@
+from typing import List, Any
+
+from langchain_core.embeddings import Embeddings
+
+from llm import RequestHandlerInterface
+
+
+class IrisLangchainEmbeddingModel(Embeddings):
+    """Custom langchain embedding for our own request handler"""
+
+    request_handler: RequestHandlerInterface
+
+    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+        # Embeddings is a plain ABC (not a pydantic model); store the handler directly
+        super().__init__(**kwargs)
+        self.request_handler = request_handler
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.request_handler.create_embedding(text)
diff --git a/app/llm/request_handler_interface.py b/app/llm/request_handler_interface.py
index a9421dd2..6f2ad666 100644
--- a/app/llm/request_handler_interface.py
+++ b/app/llm/request_handler_interface.py
@@ -21,7 +21,7 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str:
         raise NotImplementedError
 
     @abstractmethod
-    def chat(self, messages: list[any], arguments: CompletionArguments) -> [IrisMessage]:
+    def chat(self, messages: list[any], arguments: CompletionArguments) -> IrisMessage:
         """Create a completion from the chat messages"""
         raise NotImplementedError
 
diff --git a/requirements.txt b/requirements.txt
index 6c00ab79..8bc38ea7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ flake8~=7.0.0
 pre-commit~=3.6.1
 pydantic~=2.6.1
 PyYAML~=6.0.1
+langchain~=0.1.6
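Usage sketch (not part of the diff): a minimal example of how the three wrappers could plug into LangChain's LCEL interface. `BasicRequestHandler` and its `"gpt35"` argument are hypothetical stand-ins for a concrete `RequestHandlerInterface` implementation; everything else uses only the classes added above and public `langchain_core` APIs.

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from llm.langchain import (
    IrisLangchainChatModel,
    IrisLangchainCompletionModel,
    IrisLangchainEmbeddingModel,
)

# Hypothetical concrete RequestHandlerInterface implementation; the class
# name and its constructor argument are illustrative, not part of this PR.
request_handler = BasicRequestHandler("gpt35")

# Chat: LCEL pipeline prompt -> Iris chat model -> plain string
chat_model = IrisLangchainChatModel(request_handler=request_handler)
prompt = ChatPromptTemplate.from_messages(
    [("system", "You are a helpful tutor."), ("human", "{question}")]
)
chain = prompt | chat_model | StrOutputParser()
answer = chain.invoke({"question": "What is a linked list?"})

# Completion: BaseLLM exposes invoke() over a raw prompt string
completion_model = IrisLangchainCompletionModel(request_handler=request_handler)
text = completion_model.invoke("Summarize merge sort in one sentence.")

# Embeddings: queries and documents route through the same handler
embeddings = IrisLangchainEmbeddingModel(request_handler=request_handler)
vector = embeddings.embed_query("linked list")
```

Because all three wrappers delegate to the same `RequestHandlerInterface`, swapping the underlying model is a one-line change at handler construction rather than a change to any chain definition.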