Skip to content

Commit

Permalink
Merge branch 'main' into refactor/llm-code
Browse files Browse the repository at this point in the history
# Conflicts:
#	app/llm/basic_request_handler.py
#	app/llm/request_handler_interface.py
#	requirements.txt
  • Loading branch information
MichaelOwenDyer committed Feb 14, 2024
2 parents b9d7719 + 626a92c commit b3e5e8d
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 1 deletion.
3 changes: 3 additions & 0 deletions app/llm/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
60 changes: 60 additions & 0 deletions app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration

from domain import IrisMessage, IrisMessageRole
from llm import RequestHandlerInterface, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
    """Translate an IrisMessage into a langchain BaseMessage.

    The Iris role enum is mapped onto langchain's string message types
    ("human" / "ai" / "system"); an unmapped role raises KeyError.
    """
    iris_role_to_langchain_type = {
        IrisMessageRole.USER: "human",
        IrisMessageRole.ASSISTANT: "ai",
        IrisMessageRole.SYSTEM: "system",
    }
    langchain_type = iris_role_to_langchain_type[iris_message.role]
    return BaseMessage(content=iris_message.text, type=langchain_type)


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
    """Translate a langchain BaseMessage back into an IrisMessage.

    Inverse of convert_iris_message_to_base_message: maps langchain's string
    message types ("human" / "ai" / "system") onto the Iris role enum. An
    unmapped message type raises KeyError.
    """
    role_map = {
        "human": IrisMessageRole.USER,
        "ai": IrisMessageRole.ASSISTANT,
        "system": IrisMessageRole.SYSTEM,
    }
    # The map values are already IrisMessageRole members, so the original
    # IrisMessageRole(...) re-wrap was redundant and has been dropped.
    return IrisMessage(text=base_message.content, role=role_map[base_message.type])


class IrisLangchainChatModel(BaseChatModel):
    """Langchain chat model backed by our own request handler."""

    # Handler that actually performs the chat completion.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> ChatResult:
        # Translate the langchain conversation into Iris messages, delegate
        # to the request handler, and wrap the reply back up for langchain.
        converted = [
            convert_base_message_to_iris_message(message) for message in messages
        ]
        response = self.request_handler.chat_completion(
            converted, CompletionArguments(stop=stop)
        )
        generation = ChatGeneration(
            message=convert_iris_message_to_base_message(response)
        )
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "Iris"
37 changes: 37 additions & 0 deletions app/llm/langchain/iris_langchain_completion_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs.generation import Generation

from llm import RequestHandlerInterface, CompletionArguments


class IrisLangchainCompletionModel(BaseLLM):
    """Custom langchain completion model for our own request handler.

    Wraps a RequestHandlerInterface so it can be used wherever langchain
    expects a plain (non-chat) LLM.
    """

    # Handler that actually performs the text completion.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        super().__init__(request_handler=request_handler, **kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> LLMResult:
        """Complete each prompt via the request handler.

        LLMResult expects one inner list of generations per prompt, hence
        the nested list structure.
        """
        args = CompletionArguments(stop=stop)
        generations = []
        for prompt in prompts:
            # NOTE(review): RequestHandlerInterface declares complete();
            # verify that completion() actually exists on the handler.
            completion = self.request_handler.completion(
                prompt=prompt, arguments=args, **kwargs
            )
            generations.append([Generation(text=completion)])
        return LLMResult(generations=generations)

    @property
    def _llm_type(self) -> str:
        return "Iris"
20 changes: 20 additions & 0 deletions app/llm/langchain/iris_langchain_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, Any

from langchain_core.embeddings import Embeddings

from llm import RequestHandlerInterface


class IrisLangchainEmbeddingModel(Embeddings):
    """Custom langchain embedding for our own request handler."""

    # Handler that actually creates the embeddings.
    request_handler: RequestHandlerInterface

    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
        # NOTE(review): langchain_core's Embeddings appears to be a plain ABC
        # (not a pydantic model), so forwarding request_handler=... up the
        # chain would reach object.__init__ and raise TypeError at
        # instantiation. Store the handler on the instance instead — confirm
        # against the pinned langchain version.
        self.request_handler = request_handler
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text individually by delegating to embed_query."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Create an embedding vector for a single text via the handler."""
        return self.request_handler.create_embedding(text)
2 changes: 1 addition & 1 deletion app/llm/request_handler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str:
raise NotImplementedError

@abstractmethod
def chat(
    self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
    """Create a completion from the chat messages.

    Abstract: concrete request handlers must override this. The annotation
    uses IrisMessage rather than the builtin function `any`, which is not
    a type.
    """
    raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ flake8~=7.0.0
pre-commit~=3.6.1
pydantic~=2.6.1
PyYAML~=6.0.1
langchain~=0.1.6

0 comments on commit b3e5e8d

Please sign in to comment.