Skip to content

Commit

Permalink
Add all required Langchain LLM wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus committed Feb 14, 2024
1 parent 4af8a5f commit 16b72b0
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 10 deletions.
6 changes: 3 additions & 3 deletions app/llm/basic_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel):
self.llm_manager = LlmManager()

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
Expand All @@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
def chat_completion(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
Expand All @@ -39,7 +39,7 @@ def chat_completion(
)

def create_embedding(self, text: str) -> list[float]:
llm = self.llm_manager.get_llm_by_id(self.model).llm
llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
Expand Down
2 changes: 2 additions & 0 deletions app/llm/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
28 changes: 21 additions & 7 deletions app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,30 @@
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration

from domain import IrisMessage
from domain import IrisMessage, IrisMessageRole
from llm import RequestHandlerInterface, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
    """Convert an IrisMessage into a langchain BaseMessage.

    Langchain distinguishes message roles via the ``type`` field
    ("human", "ai", "system"), so the Iris role enum is translated
    accordingly.

    :param iris_message: the Iris-domain message to convert
    :return: an equivalent langchain BaseMessage
    :raises KeyError: if the Iris message carries a role outside the map
    """
    role_map = {
        IrisMessageRole.USER: "human",
        IrisMessageRole.ASSISTANT: "ai",
        IrisMessageRole.SYSTEM: "system",
    }
    return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
    """Convert a langchain BaseMessage back into an IrisMessage.

    Inverse of ``convert_iris_message_to_base_message``: langchain's
    message ``type`` ("human"/"ai"/"system") is mapped onto the Iris
    role enum.

    :param base_message: the langchain message to convert
    :return: an equivalent Iris-domain message
    :raises KeyError: if the message type is not one of the mapped types
    """
    role_map = {
        "human": IrisMessageRole.USER,
        "ai": IrisMessageRole.ASSISTANT,
        "system": IrisMessageRole.SYSTEM,
    }
    # role_map values are already IrisMessageRole members, so wrapping
    # them in IrisMessageRole(...) again was redundant.
    return IrisMessage(text=base_message.content, role=role_map[base_message.type])


class IrisLangchainChatModel(BaseChatModel):
Expand All @@ -25,8 +38,7 @@ class IrisLangchainChatModel(BaseChatModel):
request_handler: RequestHandlerInterface

def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.request_handler = request_handler
super().__init__(request_handler=request_handler, **kwargs)

def _generate(
self,
Expand All @@ -35,11 +47,13 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> ChatResult:
iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
iris_message = self.request_handler.chat_completion(
messages, CompletionArguments(stop=stop)
iris_messages, CompletionArguments(stop=stop)
)
base_message = convert_iris_message_to_base_message(iris_message)
return ChatResult(generations=[base_message])
chat_generation = ChatGeneration(message=base_message)
return ChatResult(generations=[chat_generation])

@property
def _llm_type(self) -> str:
Expand Down
37 changes: 37 additions & 0 deletions app/llm/langchain/iris_langchain_completion_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs.generation import Generation

from llm import RequestHandlerInterface, CompletionArguments


class IrisLangchainCompletionModel(BaseLLM):
    """Langchain-compatible completion model backed by an Iris request handler.

    Delegates every prompt to ``request_handler.completion`` so that any
    Iris-managed LLM can be plugged into langchain pipelines.
    """

    # Handler that performs the actual completion requests.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        # request_handler is declared as a model field above; hand it to
        # the BaseLLM constructor together with any other options.
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> LLMResult:
        """Produce one Generation per prompt via the Iris request handler."""
        arguments = CompletionArguments(stop=stop)
        completions = [
            self.request_handler.completion(
                prompt=prompt, arguments=arguments, **kwargs
            )
            for prompt in prompts
        ]
        # LLMResult expects a list of generation-lists, one per prompt.
        return LLMResult(
            generations=[[Generation(text=text)] for text in completions]
        )

    @property
    def _llm_type(self) -> str:
        """Identifier langchain uses for this model in logs/serialization."""
        return "Iris"
20 changes: 20 additions & 0 deletions app/llm/langchain/iris_langchain_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, Any

from langchain_core.embeddings import Embeddings

from llm import RequestHandlerInterface


class IrisLangchainEmbeddingModel(Embeddings):
    """Langchain-compatible embedding model backed by an Iris request handler."""

    # Handler that performs the actual embedding requests.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        # Embeddings is a plain ABC (not a pydantic model), so keyword
        # arguments must not be forwarded to object.__init__ — doing so
        # raises TypeError on instantiation and never sets the declared
        # request_handler attribute. Assign the handler directly instead.
        self.request_handler = request_handler
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each document individually via embed_query."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text using the Iris request handler."""
        return self.request_handler.create_embedding(text)

0 comments on commit 16b72b0

Please sign in to comment.