Refactoring suggestions #59

Merged · 8 commits · Feb 16, 2024
Changes from 3 commits
6 changes: 3 additions & 3 deletions app/llm/__init__.py
@@ -1,3 +1,3 @@
-from llm.request_handler_interface import RequestHandlerInterface
-from llm.generation_arguments import *
-from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
+from llm.request_handler_interface import RequestHandler
+from llm.completion_arguments import *
+from llm.basic_request_handler import BasicRequestHandler, DefaultModelId
52 changes: 14 additions & 38 deletions app/llm/basic_request_handler.py
@@ -1,48 +1,24 @@
 from domain import IrisMessage
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments
 from llm.llm_manager import LlmManager
-from llm.wrapper.abstract_llm_wrapper import (
-    AbstractLlmCompletionWrapper,
-    AbstractLlmChatCompletionWrapper,
-    AbstractLlmEmbeddingWrapper,
-)
-
-type BasicRequestHandlerModel = str


-class BasicRequestHandler(RequestHandlerInterface):
-    model: BasicRequestHandlerModel
+class BasicRequestHandler(RequestHandler):
+    model_id: str
     llm_manager: LlmManager

-    def __init__(self, model: BasicRequestHandlerModel):
-        self.model = model
+    def __init__(self, model_id: str):
+        self.model_id = model_id
         self.llm_manager = LlmManager()

-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmCompletionWrapper):
-            return llm.completion(prompt, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support completion"
-            )
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.complete(prompt, arguments)

-    def chat_completion(
-        self, messages: list[IrisMessage], arguments: CompletionArguments
-    ) -> IrisMessage:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmChatCompletionWrapper):
-            return llm.chat_completion(messages, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support chat completion"
-            )
+    def chat(self, messages: list[IrisMessage], arguments: CompletionArguments) -> IrisMessage:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.chat(messages, arguments)

-    def create_embedding(self, text: str) -> list[float]:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmEmbeddingWrapper):
-            return llm.create_embedding(text)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support embedding"
-            )
+    def embed(self, text: str) -> list[float]:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.embed(text)
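For orientation, a minimal usage sketch of the refactored handler (not part of this diff; the model id "gpt35" is hypothetical and would have to match an entry in the YAML config referenced by LLM_CONFIG_PATH):

```python
from llm import BasicRequestHandler, CompletionArguments

# "gpt35" is a hypothetical model id; it must match the `id` of an
# entry in the YAML file pointed to by LLM_CONFIG_PATH.
handler = BasicRequestHandler("gpt35")
answer = handler.complete("Say hello", CompletionArguments(stop=["\n"]))
```

Note the behavior change: the isinstance capability checks are gone, so calling `complete` on a model that does not implement it now presumably surfaces as an `AttributeError` at call time rather than a `NotImplementedError`.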
File renamed without changes.
8 changes: 4 additions & 4 deletions app/llm/langchain/iris_langchain_chat_model.py
@@ -9,7 +9,7 @@
 from langchain_core.outputs.chat_generation import ChatGeneration

 from domain import IrisMessage, IrisMessageRole
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments


 def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
@@ -35,9 +35,9 @@ def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
 class IrisLangchainChatModel(BaseChatModel):
     """Custom langchain chat model for our own request handler"""

-    request_handler: RequestHandlerInterface
+    request_handler: RequestHandler

-    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
         super().__init__(request_handler=request_handler, **kwargs)

     def _generate(
@@ -48,7 +48,7 @@ def _generate(
         **kwargs: Any
     ) -> ChatResult:
         iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
-        iris_message = self.request_handler.chat_completion(
+        iris_message = self.request_handler.chat(
             iris_messages, CompletionArguments(stop=stop)
         )
         base_message = convert_iris_message_to_base_message(iris_message)
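A quick sketch of how the adapter is meant to be wired up (illustrative only, assuming a recent langchain_core and that the package re-exports the class; otherwise import from `llm.langchain.iris_langchain_chat_model`):

```python
from llm import BasicRequestHandler
from llm.langchain import IrisLangchainChatModel

# "gpt35-chat" is a hypothetical model id from the YAML config.
request_handler = BasicRequestHandler("gpt35-chat")
chat_model = IrisLangchainChatModel(request_handler=request_handler)
response = chat_model.invoke("What is an AVL tree?")
```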
10 changes: 5 additions & 5 deletions app/llm/langchain/iris_langchain_completion_model.py
@@ -5,15 +5,15 @@
 from langchain_core.outputs import LLMResult
 from langchain_core.outputs.generation import Generation

-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments


 class IrisLangchainCompletionModel(BaseLLM):
     """Custom langchain chat model for our own request handler"""

-    request_handler: RequestHandlerInterface
+    request_handler: RequestHandler

-    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
         super().__init__(request_handler=request_handler, **kwargs)

     def _generate(
@@ -26,8 +26,8 @@ def _generate(
         generations = []
         args = CompletionArguments(stop=stop)
         for prompt in prompts:
-            completion = self.request_handler.completion(
-                prompt=prompt, arguments=args, **kwargs
+            completion = self.request_handler.complete(
+                prompt=prompt, arguments=args
             )
             generations.append([Generation(text=completion)])
         return LLMResult(generations=generations)
8 changes: 4 additions & 4 deletions app/llm/langchain/iris_langchain_embedding.py
@@ -2,19 +2,19 @@

 from langchain_core.embeddings import Embeddings

-from llm import RequestHandlerInterface
+from llm import RequestHandler


 class IrisLangchainEmbeddingModel(Embeddings):
     """Custom langchain embedding for our own request handler"""

-    request_handler: RequestHandlerInterface
+    request_handler: RequestHandler

-    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
         super().__init__(request_handler=request_handler, **kwargs)

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         return [self.embed_query(text) for text in texts]

     def embed_query(self, text: str) -> List[float]:
-        return self.request_handler.create_embedding(text)
+        return self.request_handler.embed(text)
38 changes: 19 additions & 19 deletions app/llm/llm_manager.py
@@ -5,32 +5,32 @@
 import yaml

 from common import Singleton
-from llm.wrapper import AbstractLlmWrapper, LlmWrapper
+from llm.wrapper import LanguageModel, AnyLLM


 # Small workaround to get pydantic discriminators working
-class LlmList(BaseModel):
-    llms: list[LlmWrapper] = Field(discriminator="type")
+class LLMList(BaseModel):
+    llms: list[AnyLLM] = Field(discriminator="type")


-class LlmManager(metaclass=Singleton):
-    entries: list[AbstractLlmWrapper]
+def load_llms() -> dict[str, LanguageModel]:
+    path = os.environ.get("LLM_CONFIG_PATH")
+    assert path, "LLM_CONFIG_PATH not set"

-    def __init__(self):
-        self.entries = []
-        self.load_llms()
+    with open(path, "r") as file:
+        yaml_dict = yaml.safe_load(file)

-    def get_llm_by_id(self, llm_id):
-        for llm in self.entries:
-            if llm.id == llm_id:
-                return llm
+    llms = LLMList.model_validate({"llms": yaml_dict}).llms
+    return {
+        llm.id: llm for llm in llms
+    }

-    def load_llms(self):
-        path = os.environ.get("LLM_CONFIG_PATH")
-        if not path:
-            raise Exception("LLM_CONFIG_PATH not set")

-        with open(path, "r") as file:
-            loaded_llms = yaml.safe_load(file)
+class LlmManager(metaclass=Singleton):
+    models_by_id: dict[str, LanguageModel]

-        self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms
+    def __init__(self):
+        self.models_by_id = load_llms()
+
+    def get_by_id(self, llm_id):
+        return self.models_by_id[llm_id]
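For reference, a sketch of what `load_llms` now expects at LLM_CONFIG_PATH: a YAML list of model entries discriminated by `type`. Only the `type` discriminator and the `id` lookup key are visible in this diff; the remaining keys below are illustrative assumptions:

```python
import yaml

# Hypothetical config contents; `type` and `id` are the fields this
# diff relies on (pydantic discriminator and lookup key), the `model`
# key is made up for illustration.
config = """
- type: ollama
  id: local-llama
  model: llama2
"""
entries = yaml.safe_load(config)
assert entries[0]["id"] == "local-llama"
```

Two failure-mode changes worth noting: `get_by_id` now raises `KeyError` for an unknown id where `get_llm_by_id` silently returned `None`, and `assert path` is skipped when Python runs with `-O`, unlike the old explicit `raise`.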
23 changes: 9 additions & 14 deletions app/llm/request_handler_interface.py
@@ -1,36 +1,31 @@
 from abc import ABCMeta, abstractmethod

 from domain import IrisMessage
-from llm.generation_arguments import CompletionArguments
+from llm.completion_arguments import CompletionArguments


-class RequestHandlerInterface(metaclass=ABCMeta):
+class RequestHandler(metaclass=ABCMeta):
     """Interface for the request handlers"""

     @classmethod
-    def __subclasshook__(cls, subclass):
+    def __subclasshook__(cls, subclass) -> bool:
         return (
-            hasattr(subclass, "completion")
-            and callable(subclass.completion)
-            and hasattr(subclass, "chat_completion")
-            and callable(subclass.chat_completion)
-            and hasattr(subclass, "create_embedding")
-            and callable(subclass.create_embedding)
+            hasattr(subclass, "complete") and callable(subclass.complete)
+            and hasattr(subclass, "chat") and callable(subclass.chat)
+            and hasattr(subclass, "embed") and callable(subclass.embed)
         )

     @abstractmethod
-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
         """Create a completion from the prompt"""
         raise NotImplementedError

     @abstractmethod
-    def chat_completion(
-        self, messages: list[any], arguments: CompletionArguments
-    ) -> IrisMessage:
+    def chat(self, messages: list[any], arguments: CompletionArguments) -> IrisMessage:
         """Create a completion from the chat messages"""
         raise NotImplementedError

     @abstractmethod
-    def create_embedding(self, text: str) -> list[float]:
+    def embed(self, text: str) -> list[float]:
         """Create an embedding from the text"""
         raise NotImplementedError
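The renamed hook keeps the interface structural. A minimal sketch (not from the PR) of what this enables: any class providing the three methods passes an `issubclass` check without inheriting from `RequestHandler`:

```python
from llm import RequestHandler

class DuckHandler:
    # No inheritance from RequestHandler; __subclasshook__ recognizes
    # the class purely by the presence of these three methods.
    def complete(self, prompt, arguments): ...
    def chat(self, messages, arguments): ...
    def embed(self, text): ...

assert issubclass(DuckHandler, RequestHandler)
```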
35 changes: 13 additions & 22 deletions app/llm/wrapper/__init__.py
@@ -1,24 +1,15 @@
-from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
-from llm.wrapper.open_ai_completion_wrapper import (
-    OpenAICompletionWrapper,
-    AzureCompletionWrapper,
-)
-from llm.wrapper.open_ai_chat_wrapper import (
-    OpenAIChatCompletionWrapper,
-    AzureChatCompletionWrapper,
-)
-from llm.wrapper.open_ai_embedding_wrapper import (
-    OpenAIEmbeddingWrapper,
-    AzureEmbeddingWrapper,
-)
-from llm.wrapper.ollama_wrapper import OllamaWrapper
+from llm.wrapper.model import LanguageModel
+from llm.wrapper.openai_completion import NativeOpenAICompletionModel, AzureOpenAICompletionModel
+from llm.wrapper.openai_chat import NativeOpenAIChatModel, AzureOpenAIChatModel
+from llm.wrapper.openai_embeddings import NativeOpenAIEmbeddingModel, AzureOpenAIEmbeddingModel
+from llm.wrapper.ollama import OllamaModel

-type LlmWrapper = (
-    OpenAICompletionWrapper
-    | AzureCompletionWrapper
-    | OpenAIChatCompletionWrapper
-    | AzureChatCompletionWrapper
-    | OpenAIEmbeddingWrapper
-    | AzureEmbeddingWrapper
-    | OllamaWrapper
+type AnyLLM = (
+    NativeOpenAICompletionModel
+    | AzureOpenAICompletionModel
+    | NativeOpenAIChatModel
+    | AzureOpenAIChatModel
+    | NativeOpenAIEmbeddingModel
+    | AzureOpenAIEmbeddingModel
+    | OllamaModel
 )
58 changes: 0 additions & 58 deletions app/llm/wrapper/abstract_llm_wrapper.py

This file was deleted.
