
LLM: Add Langchain LLM wrapper classes #57

Merged (13 commits, Feb 14, 2024)
6 changes: 3 additions & 3 deletions app/llm/basic_request_handler.py
@@ -19,7 +19,7 @@ def __init__(self, model: BasicRequestHandlerModel):
        self.llm_manager = LlmManager()

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
        if isinstance(llm, AbstractLlmCompletionWrapper):
            return llm.completion(prompt, arguments)
        else:
@@ -30,7 +30,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
        if isinstance(llm, AbstractLlmChatCompletionWrapper):
            return llm.chat_completion(messages, arguments)
        else:
@@ -39,7 +39,7 @@ def chat_completion(
        )

    def create_embedding(self, text: str) -> list[float]:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
+        llm = self.llm_manager.get_llm_by_id(self.model)
        if isinstance(llm, AbstractLlmEmbeddingWrapper):
            return llm.create_embedding(text)
        else:
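All three call sites change for the same reason: LlmManager.get_llm_by_id is now expected to return the wrapper itself rather than an entry object whose .llm attribute held it. A minimal sketch of the resulting dispatch in one method, with the error branch reconstructed as an assumption because the diff truncates it:

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
    llm = self.llm_manager.get_llm_by_id(self.model)
    if isinstance(llm, AbstractLlmCompletionWrapper):
        return llm.completion(prompt, arguments)
    # Hypothetical error branch; the actual message is cut off in the diff above
    raise NotImplementedError(f"The LLM {self.model} does not support completion")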
3 changes: 3 additions & 0 deletions app/llm/langchain/__init__.py
@@ -0,0 +1,3 @@
from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
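Re-exporting the wrappers here lets pipeline code import all three from one place. As a hedged sketch of the intended payoff, the chat wrapper should slot into an ordinary LangChain chain; BasicRequestHandler comes from this PR, but its export from the llm package and the model id string are assumptions:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from llm import BasicRequestHandler  # assumed to be exported like the other llm names
from llm.langchain import IrisLangchainChatModel

prompt = ChatPromptTemplate.from_messages([("human", "Explain {topic} in one sentence.")])
chat_model = IrisLangchainChatModel(request_handler=BasicRequestHandler(model="gpt35"))  # placeholder id
chain = prompt | chat_model | StrOutputParser()
print(chain.invoke({"topic": "dependency injection"}))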
60 changes: 60 additions & 0 deletions app/llm/langchain/iris_langchain_chat_model.py
@@ -0,0 +1,60 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration

from domain import IrisMessage, IrisMessageRole
from llm import RequestHandlerInterface, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
    role_map = {
        IrisMessageRole.USER: "human",
        IrisMessageRole.ASSISTANT: "ai",
        IrisMessageRole.SYSTEM: "system",
    }
    return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
    role_map = {
        "human": IrisMessageRole.USER,
        "ai": IrisMessageRole.ASSISTANT,
        "system": IrisMessageRole.SYSTEM,
    }
    return IrisMessage(
        text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
    )


class IrisLangchainChatModel(BaseChatModel):
    """Custom langchain chat model for our own request handler"""

    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> ChatResult:
        iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
        iris_message = self.request_handler.chat_completion(
            iris_messages, CompletionArguments(stop=stop)
        )
        base_message = convert_iris_message_to_base_message(iris_message)
        chat_generation = ChatGeneration(message=base_message)
        return ChatResult(generations=[chat_generation])

    @property
    def _llm_type(self) -> str:
        return "Iris"
37 changes: 37 additions & 0 deletions app/llm/langchain/iris_langchain_completion_model.py
@@ -0,0 +1,37 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs.generation import Generation

from llm import RequestHandlerInterface, CompletionArguments


class IrisLangchainCompletionModel(BaseLLM):
"""Custom langchain chat model for our own request handler"""

    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> LLMResult:
        generations = []
        args = CompletionArguments(stop=stop)
        for prompt in prompts:
            completion = self.request_handler.completion(
                prompt=prompt, arguments=args, **kwargs
            )
            generations.append([Generation(text=completion)])
        return LLMResult(generations=generations)

    @property
    def _llm_type(self) -> str:
        return "Iris"
20 changes: 20 additions & 0 deletions app/llm/langchain/iris_langchain_embedding.py
@@ -0,0 +1,20 @@
from typing import List, Any

from langchain_core.embeddings import Embeddings

from llm import RequestHandlerInterface


class IrisLangchainEmbeddingModel(Embeddings):
    """Custom langchain embedding for our own request handler"""

    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        # Embeddings is a plain ABC, so forwarding request_handler to
        # object.__init__ would raise a TypeError; store it on the instance.
        self.request_handler = request_handler
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.request_handler.create_embedding(text)
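A usage sketch for the embedding wrapper. Note that embed_documents issues one create_embedding request per text, so large batches mean many round trips (placeholder model id):

from llm import BasicRequestHandler
from llm.langchain import IrisLangchainEmbeddingModel

embedding_model = IrisLangchainEmbeddingModel(
    request_handler=BasicRequestHandler(model="embedding-model")  # hypothetical id
)
query_vector = embedding_model.embed_query("binary search trees")
doc_vectors = embedding_model.embed_documents(["first doc", "second doc"])  # one request each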
2 changes: 1 addition & 1 deletion app/llm/request_handler_interface.py
@@ -26,7 +26,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
    @abstractmethod
    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
-    ) -> [IrisMessage]:
+    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError

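The annotation fix above is more than cosmetic: [IrisMessage] is a one-element list literal, not a type, so static checkers reject it. A genuine list return would use the generic form instead; a two-line illustration (not part of this PR):

from domain import IrisMessage

def returns_one() -> IrisMessage: ...         # what chat_completion actually returns
def returns_many() -> list[IrisMessage]: ...  # how a list-of-messages return would be spelled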
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ black==24.1.1
flake8==7.0.0
pre-commit==3.6.1
pydantic==2.6.1
+langchain==0.1.6