diff --git a/app/llm/basic_request_handler.py b/app/llm/basic_request_handler.py index 001d2dbb..12079901 100644 --- a/app/llm/basic_request_handler.py +++ b/app/llm/basic_request_handler.py @@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel): self.llm_manager = LlmManager() def completion(self, prompt: str, arguments: CompletionArguments) -> str: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmCompletionWrapper): return llm.completion(prompt, arguments) else: @@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str: def chat_completion( self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmChatCompletionWrapper): return llm.chat_completion(messages, arguments) else: @@ -39,7 +39,7 @@ def chat_completion( ) def create_embedding(self, text: str) -> list[float]: - llm = self.llm_manager.get_llm_by_id(self.model).llm + llm = self.llm_manager.get_llm_by_id(self.model) if isinstance(llm, AbstractLlmEmbeddingWrapper): return llm.create_embedding(text) else: diff --git a/app/llm/langchain/__init__.py b/app/llm/langchain/__init__.py index 1f75540b..4deb1372 100644 --- a/app/llm/langchain/__init__.py +++ b/app/llm/langchain/__init__.py @@ -1 +1,3 @@ +from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel +from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py index 41f0df4e..d0e558fa 100644 --- a/app/llm/langchain/iris_langchain_chat_model.py +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -6,17 +6,30 @@ ) from 
langchain_core.messages import BaseMessage from langchain_core.outputs import ChatResult +from langchain_core.outputs.chat_generation import ChatGeneration -from domain import IrisMessage +from domain import IrisMessage, IrisMessageRole from llm import RequestHandlerInterface, CompletionArguments def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: - return BaseMessage(content=iris_message.text, role=iris_message.role) + role_map = { + IrisMessageRole.USER: "human", + IrisMessageRole.ASSISTANT: "ai", + IrisMessageRole.SYSTEM: "system", + } + return BaseMessage(content=iris_message.text, type=role_map[iris_message.role]) def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: - return IrisMessage(text=base_message.content, role=base_message.role) + role_map = { + "human": IrisMessageRole.USER, + "ai": IrisMessageRole.ASSISTANT, + "system": IrisMessageRole.SYSTEM, + } + return IrisMessage( + text=base_message.content, role=IrisMessageRole(role_map[base_message.type]) + ) class IrisLangchainChatModel(BaseChatModel): @@ -25,8 +38,7 @@ class IrisLangchainChatModel(BaseChatModel): request_handler: RequestHandlerInterface def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.request_handler = request_handler + super().__init__(request_handler=request_handler, **kwargs) def _generate( self, @@ -35,11 +47,13 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> ChatResult: + iris_messages = [convert_base_message_to_iris_message(m) for m in messages] iris_message = self.request_handler.chat_completion( - messages, CompletionArguments(stop=stop) + iris_messages, CompletionArguments(stop=stop) ) base_message = convert_iris_message_to_base_message(iris_message) - return ChatResult(generations=[base_message]) + chat_generation = ChatGeneration(message=base_message) + return ChatResult(generations=[chat_generation]) 
@property def _llm_type(self) -> str: diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py new file mode 100644 index 00000000..8a8b2e95 --- /dev/null +++ b/app/llm/langchain/iris_langchain_completion_model.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Any + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import LLMResult +from langchain_core.outputs.generation import Generation + +from llm import RequestHandlerInterface, CompletionArguments + + +class IrisLangchainCompletionModel(BaseLLM): + """Custom langchain chat model for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + super().__init__(request_handler=request_handler, **kwargs) + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> LLMResult: + generations = [] + args = CompletionArguments(stop=stop) + for prompt in prompts: + completion = self.request_handler.completion( + prompt=prompt, arguments=args, **kwargs + ) + generations.append([Generation(text=completion)]) + return LLMResult(generations=generations) + + @property + def _llm_type(self) -> str: + return "Iris" diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py new file mode 100644 index 00000000..01a840ff --- /dev/null +++ b/app/llm/langchain/iris_langchain_embedding.py @@ -0,0 +1,20 @@ +from typing import List, Any + +from langchain_core.embeddings import Embeddings + +from llm import RequestHandlerInterface + + +class IrisLangchainEmbeddingModel(Embeddings): + """Custom langchain embedding for our own request handler""" + + request_handler: RequestHandlerInterface + + def __init__(self, 
request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + self.request_handler = request_handler + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + return self.request_handler.create_embedding(text)