Skip to content

Commit

Permalink
LLM: Add Langchain LLM wrapper classes (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus authored Feb 14, 2024
1 parent 3171048 commit b58ea21
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 4 deletions.
6 changes: 3 additions & 3 deletions app/llm/basic_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel):
self.llm_manager = LlmManager()

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
Expand All @@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
def chat_completion(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
Expand All @@ -39,7 +39,7 @@ def chat_completion(
)

def create_embedding(self, text: str) -> list[float]:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
Expand Down
3 changes: 3 additions & 0 deletions app/llm/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
60 changes: 60 additions & 0 deletions app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration

from domain import IrisMessage, IrisMessageRole
from llm import RequestHandlerInterface, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
role_map = {
IrisMessageRole.USER: "human",
IrisMessageRole.ASSISTANT: "ai",
IrisMessageRole.SYSTEM: "system",
}
return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
role_map = {
"human": IrisMessageRole.USER,
"ai": IrisMessageRole.ASSISTANT,
"system": IrisMessageRole.SYSTEM,
}
return IrisMessage(
text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
)


class IrisLangchainChatModel(BaseChatModel):
"""Custom langchain chat model for our own request handler"""

request_handler: RequestHandlerInterface

def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
super().__init__(request_handler=request_handler, **kwargs)

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> ChatResult:
iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
iris_message = self.request_handler.chat_completion(
iris_messages, CompletionArguments(stop=stop)
)
base_message = convert_iris_message_to_base_message(iris_message)
chat_generation = ChatGeneration(message=base_message)
return ChatResult(generations=[chat_generation])

@property
def _llm_type(self) -> str:
return "Iris"
37 changes: 37 additions & 0 deletions app/llm/langchain/iris_langchain_completion_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs.generation import Generation

from llm import RequestHandlerInterface, CompletionArguments


class IrisLangchainCompletionModel(BaseLLM):
"""Custom langchain chat model for our own request handler"""

request_handler: RequestHandlerInterface

def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
super().__init__(request_handler=request_handler, **kwargs)

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
generations = []
args = CompletionArguments(stop=stop)
for prompt in prompts:
completion = self.request_handler.completion(
prompt=prompt, arguments=args, **kwargs
)
generations.append([Generation(text=completion)])
return LLMResult(generations=generations)

@property
def _llm_type(self) -> str:
return "Iris"
20 changes: 20 additions & 0 deletions app/llm/langchain/iris_langchain_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, Any

from langchain_core.embeddings import Embeddings

from llm import RequestHandlerInterface


class IrisLangchainEmbeddingModel(Embeddings):
"""Custom langchain embedding for our own request handler"""

request_handler: RequestHandlerInterface

def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
super().__init__(request_handler=request_handler, **kwargs)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(text) for text in texts]

def embed_query(self, text: str) -> List[float]:
return self.request_handler.create_embedding(text)
2 changes: 1 addition & 1 deletion app/llm/request_handler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
@abstractmethod
def chat_completion(
self, messages: list[any], arguments: CompletionArguments
) -> [IrisMessage]:
) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ black==24.1.1
flake8==7.0.0
pre-commit==3.6.1
pydantic==2.6.1
langchain==0.1.6

0 comments on commit b58ea21

Please sign in to comment.