Skip to content

Commit

Permalink
Improve LlmManager
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus committed Feb 11, 2024
1 parent 1f924db commit ea6a397
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 64 deletions.
12 changes: 6 additions & 6 deletions app/llm/basic_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from llm import LlmManager
from llm import RequestHandlerInterface, CompletionArguments
from llm.wrapper import (
LlmCompletionWrapperInterface,
LlmChatCompletionWrapperInterface,
LlmEmbeddingWrapperInterface,
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)

type BasicRequestHandlerModel = str
Expand All @@ -20,7 +20,7 @@ def __init__(self, model: BasicRequestHandlerModel):

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, LlmCompletionWrapperInterface):
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
raise NotImplementedError(
Expand All @@ -31,7 +31,7 @@ def chat_completion(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, LlmChatCompletionWrapperInterface):
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
raise NotImplementedError(
Expand All @@ -40,7 +40,7 @@ def chat_completion(

def create_embedding(self, text: str) -> list[float]:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, LlmEmbeddingWrapperInterface):
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
raise NotImplementedError(
Expand Down
42 changes: 23 additions & 19 deletions app/llm/llm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import yaml

from common import Singleton
from llm.wrapper import LlmWrapperInterface
from llm.wrapper import AbstractLlmWrapper


def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
if config["type"] == "openai":
from llm.wrapper import OpenAICompletionWrapper

Expand All @@ -15,6 +15,9 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
from llm.wrapper import AzureCompletionWrapper

return AzureCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
Expand All @@ -25,12 +28,19 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
from llm.wrapper import OpenAIChatCompletionWrapper

return OpenAIChatCompletionWrapper(
model=config["model"], api_key=config["api_key"]
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
api_key=config["api_key"],
)
elif config["type"] == "azure_chat":
from llm.wrapper import AzureChatCompletionWrapper

return AzureChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
Expand All @@ -45,6 +55,9 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
from llm.wrapper import AzureEmbeddingWrapper

return AzureEmbeddingWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
Expand All @@ -55,34 +68,25 @@ def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
from llm.wrapper import OllamaWrapper

return OllamaWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
host=config["host"],
)
else:
raise Exception(f"Unknown LLM type: {config['type']}")


class LlmManagerEntry:
id: str
llm: LlmWrapperInterface

def __init__(self, config: dict):
self.id = config["id"]
self.llm = create_llm_wrapper(config)

def __str__(self):
return f"{self.id}: {self.llm}"


class LlmManager(metaclass=Singleton):
llms: list[LlmManagerEntry]
entries: list[AbstractLlmWrapper]

def __init__(self):
self.llms = []
self.entries = []
self.load_llms()

def get_llm_by_id(self, llm_id):
for llm in self.llms:
for llm in self.entries:
if llm.id == llm_id:
return llm

Expand All @@ -94,4 +98,4 @@ def load_llms(self):
with open(path, "r") as file:
loaded_llms = yaml.safe_load(file)

self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
2 changes: 1 addition & 1 deletion app/llm/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from llm.wrapper.llm_wrapper_interface import *
from llm.wrapper.abstract_llm_wrapper import *
from llm.wrapper.open_ai_completion_wrapper import *
from llm.wrapper.open_ai_chat_wrapper import *
from llm.wrapper.open_ai_embedding_wrapper import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
from domain import IrisMessage
from llm import CompletionArguments

type LlmWrapperInterface = (
LlmCompletionWrapperInterface
| LlmChatCompletionWrapperInterface
| LlmEmbeddingWrapperInterface
)

class AbstractLlmWrapper(metaclass=ABCMeta):
"""Abstract class for the llm wrappers"""

class LlmCompletionWrapperInterface(metaclass=ABCMeta):
"""Interface for the llm completion wrappers"""
id: str
name: str
description: str

def __init__(self, id: str, name: str, description: str):
self.id = id
self.name = name
self.description = description


class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
Expand All @@ -23,8 +30,8 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
raise NotImplementedError


class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):
"""Interface for the llm chat completion wrappers"""
class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm chat completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
Expand All @@ -40,8 +47,8 @@ def chat_completion(
raise NotImplementedError


class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):
"""Interface for the llm embedding wrappers"""
class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm embedding wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
Expand Down
15 changes: 8 additions & 7 deletions app/llm/wrapper/ollama_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper import (
LlmChatCompletionWrapperInterface,
LlmCompletionWrapperInterface,
LlmEmbeddingWrapperInterface,
AbstractLlmChatCompletionWrapper,
AbstractLlmCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)


Expand All @@ -20,12 +20,13 @@ def convert_to_iris_message(message: Message) -> IrisMessage:


class OllamaWrapper(
LlmCompletionWrapperInterface,
LlmChatCompletionWrapperInterface,
LlmEmbeddingWrapperInterface,
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
):

def __init__(self, model: str, host: str):
def __init__(self, model: str, host: str, **kwargs):
super().__init__(**kwargs)
self.client = Client(host=host) # TODO: Add authentication (httpx auth?)
self.model = model

Expand Down
18 changes: 10 additions & 8 deletions app/llm/wrapper/open_ai_chat_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper import LlmChatCompletionWrapperInterface
from llm.wrapper import AbstractLlmChatCompletionWrapper


def convert_to_open_ai_messages(
Expand All @@ -14,15 +14,16 @@ def convert_to_open_ai_messages(
]


def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
# Get IrisMessageRole from the string message.role
message_role = IrisMessageRole(message.role)
return IrisMessage(role=message_role, text=message.content)


class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):

def __init__(self, client, model: str):
def __init__(self, client, model: str, **kwargs):
super().__init__(**kwargs)
self.client = client
self.model = model

Expand All @@ -41,12 +42,12 @@ def chat_completion(

class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):

def __init__(self, model: str, api_key: str):
def __init__(self, model: str, api_key: str, **kwargs):
from openai import OpenAI

client = OpenAI(api_key=api_key)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"OpenAIChat('{self.model}')"
Expand All @@ -61,6 +62,7 @@ def __init__(
azure_deployment: str,
api_version: str,
api_key: str,
**kwargs,
):
client = AzureOpenAI(
azure_endpoint=endpoint,
Expand All @@ -69,7 +71,7 @@ def __init__(
api_key=api_key,
)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"AzureChat('{self.model}')"
14 changes: 8 additions & 6 deletions app/llm/wrapper/open_ai_completion_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from openai.lib.azure import AzureOpenAI

from llm import CompletionArguments
from llm.wrapper import LlmCompletionWrapperInterface
from llm.wrapper import AbstractLlmCompletionWrapper


class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):
class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):

def __init__(self, client, model: str):
def __init__(self, client, model: str, **kwargs):
super().__init__(**kwargs)
self.client = client
self.model = model

Expand All @@ -24,10 +25,10 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:

class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):

def __init__(self, model: str, api_key: str):
def __init__(self, model: str, api_key: str, **kwargs):
client = OpenAI(api_key=api_key)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"OpenAICompletion('{self.model}')"
Expand All @@ -42,6 +43,7 @@ def __init__(
azure_deployment: str,
api_version: str,
api_key: str,
**kwargs,
):
client = AzureOpenAI(
azure_endpoint=endpoint,
Expand All @@ -50,7 +52,7 @@ def __init__(
api_key=api_key,
)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"AzureCompletion('{self.model}')"
14 changes: 8 additions & 6 deletions app/llm/wrapper/open_ai_embedding_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from openai.lib.azure import AzureOpenAI

from llm.wrapper import (
LlmEmbeddingWrapperInterface,
AbstractLlmEmbeddingWrapper,
)


class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):

def __init__(self, client, model: str):
def __init__(self, client, model: str, **kwargs):
super().__init__(**kwargs)
self.client = client
self.model = model

Expand All @@ -23,10 +24,10 @@ def create_embedding(self, text: str) -> list[float]:

class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):

def __init__(self, model: str, api_key: str):
def __init__(self, model: str, api_key: str, **kwargs):
client = OpenAI(api_key=api_key)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"OpenAIEmbedding('{self.model}')"
Expand All @@ -41,6 +42,7 @@ def __init__(
azure_deployment: str,
api_version: str,
api_key: str,
**kwargs,
):
client = AzureOpenAI(
azure_endpoint=endpoint,
Expand All @@ -49,7 +51,7 @@ def __init__(
api_key=api_key,
)
model = model
super().__init__(client, model)
super().__init__(client, model, **kwargs)

def __str__(self):
return f"AzureEmbedding('{self.model}')"

0 comments on commit ea6a397

Please sign in to comment.