Refactoring suggestions #59

Merged · 8 commits · Feb 16, 2024
6 changes: 3 additions & 3 deletions app/llm/__init__.py
@@ -1,3 +1,3 @@
-from llm.request_handler_interface import RequestHandlerInterface
-from llm.generation_arguments import *
-from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
+from llm.request_handler_interface import RequestHandler
+from llm.completion_arguments import *
+from llm.basic_request_handler import BasicRequestHandler, DefaultModelId
50 changes: 14 additions & 36 deletions app/llm/basic_request_handler.py
@@ -1,48 +1,26 @@
 from domain import IrisMessage
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments
 from llm.llm_manager import LlmManager
-from llm.wrapper.abstract_llm_wrapper import (
-    AbstractLlmCompletionWrapper,
-    AbstractLlmChatCompletionWrapper,
-    AbstractLlmEmbeddingWrapper,
-)
-
-type BasicRequestHandlerModel = str
 
 
-class BasicRequestHandler(RequestHandlerInterface):
-    model: BasicRequestHandlerModel
+class BasicRequestHandler(RequestHandler):
+    model_id: str
     llm_manager: LlmManager
 
-    def __init__(self, model: BasicRequestHandlerModel):
-        self.model = model
+    def __init__(self, model_id: str):
+        self.model_id = model_id
         self.llm_manager = LlmManager()
 
-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmCompletionWrapper):
-            return llm.completion(prompt, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support completion"
-            )
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.complete(prompt, arguments)
 
-    def chat_completion(
+    def chat(
         self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmChatCompletionWrapper):
-            return llm.chat_completion(messages, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support chat completion"
-            )
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.chat(messages, arguments)
 
-    def create_embedding(self, text: str) -> list[float]:
-        llm = self.llm_manager.get_llm_by_id(self.model)
-        if isinstance(llm, AbstractLlmEmbeddingWrapper):
-            return llm.create_embedding(text)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support embedding"
-            )
+    def embed(self, text: str) -> list[float]:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.embed(text)
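The per-operation isinstance checks that used to live here now belong to the model hierarchy (see app/llm/external/model.py below). A minimal usage sketch of the slimmed-down handler, under assumptions: the model id is made up, and CompletionArguments is assumed to need only stop.

```python
from domain import IrisMessage, IrisMessageRole
from llm import BasicRequestHandler, CompletionArguments

# "gpt35-chat" is a hypothetical id; it must match an entry known to LlmManager.
handler = BasicRequestHandler("gpt35-chat")
answer = handler.chat(
    [IrisMessage(role=IrisMessageRole("user"), text="Hi!")],  # assumes "user" is a valid role value
    CompletionArguments(stop=None),  # assumes stop is the only required argument
)
print(answer.text)
```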
File renamed without changes.
21 changes: 21 additions & 0 deletions app/llm/external/__init__.py
@@ -0,0 +1,21 @@
+from llm.external.model import LanguageModel
+from llm.external.openai_completion import (
+    DirectOpenAICompletionModel,
+    AzureOpenAICompletionModel,
+)
+from llm.external.openai_chat import DirectOpenAIChatModel, AzureOpenAIChatModel
+from llm.external.openai_embeddings import (
+    DirectOpenAIEmbeddingModel,
+    AzureOpenAIEmbeddingModel,
+)
+from llm.external.ollama import OllamaModel
+
+type AnyLLM = (
+    DirectOpenAICompletionModel
+    | AzureOpenAICompletionModel
+    | DirectOpenAIChatModel
+    | AzureOpenAIChatModel
+    | DirectOpenAIEmbeddingModel
+    | AzureOpenAIEmbeddingModel
+    | OllamaModel
+)
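Since every concrete model carries a type: Literal[...] tag, AnyLLM can be read as a pydantic discriminated union. The sketch below is an assumption about how a config loader (e.g. LlmManager, not shown in this diff) might use it; all field values are placeholders.

```python
from typing import Annotated
from pydantic import Field, TypeAdapter

from llm.external import AnyLLM

# Assumed approach, not taken from the PR: dispatch on the "type" field.
adapter = TypeAdapter(Annotated[AnyLLM, Field(discriminator="type")])
llm = adapter.validate_python({
    "type": "ollama",                   # selects OllamaModel
    "id": "local-llama",                # placeholder values below
    "name": "Llama 2 (local)",
    "description": "Local Ollama instance",
    "model": "llama2",
    "host": "http://localhost:11434",
})
```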
60 changes: 60 additions & 0 deletions app/llm/external/model.py
@@ -0,0 +1,60 @@
+from abc import ABCMeta, abstractmethod
+from pydantic import BaseModel
+
+from domain import IrisMessage
+from llm import CompletionArguments
+
+
+class LanguageModel(BaseModel, metaclass=ABCMeta):
+    """Abstract class for the llm wrappers"""
+
+    id: str
+    name: str
+    description: str
+
+
+class CompletionModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "complete") and callable(subclass.complete)
+
+    @abstractmethod
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
+        """Create a completion from the prompt"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support completion"
+        )
+
+
+class ChatModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm chat completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "chat") and callable(subclass.chat)
+
+    @abstractmethod
+    def chat(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
+        """Create a completion from the chat messages"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support chat completion"
+        )
+
+
+class EmbeddingModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm embedding wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "embed") and callable(subclass.embed)
+
+    @abstractmethod
+    def embed(self, text: str) -> list[float]:
+        """Create an embedding from the text"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support embeddings"
+        )
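Note that the __subclasshook__ overrides make these ABCs structural: isinstance/issubclass succeed for any class exposing the right method, whether or not it inherits from the ABC. A quick illustration (not part of the PR):

```python
from llm.external.model import CompletionModel

class DuckModel:
    # Deliberately does not inherit from CompletionModel.
    def complete(self, prompt, arguments):
        return "stub"

# Both checks pass: __subclasshook__ only looks for a callable `complete`.
assert issubclass(DuckModel, CompletionModel)
assert isinstance(DuckModel(), CompletionModel)
```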
24 changes: 10 additions & 14 deletions app/llm/wrapper/ollama_wrapper.py → app/llm/external/ollama.py
@@ -4,11 +4,7 @@
 
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper.abstract_llm_wrapper import (
-    AbstractLlmChatCompletionWrapper,
-    AbstractLlmCompletionWrapper,
-    AbstractLlmEmbeddingWrapper,
-)
+from llm.external.model import ChatModel, CompletionModel, EmbeddingModel
 
 
 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
@@ -21,10 +17,10 @@ def convert_to_iris_message(message: Message) -> IrisMessage:
     return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])
 
 
-class OllamaWrapper(
-    AbstractLlmCompletionWrapper,
-    AbstractLlmChatCompletionWrapper,
-    AbstractLlmEmbeddingWrapper,
+class OllamaModel(
+    CompletionModel,
+    ChatModel,
+    EmbeddingModel,
 ):
     type: Literal["ollama"]
     model: str
@@ -34,19 +30,19 @@ class OllamaWrapper(
     def model_post_init(self, __context: Any) -> None:
         self._client = Client(host=self.host)  # TODO: Add authentication (httpx auth?)
 
-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
         response = self._client.generate(model=self.model, prompt=prompt)
         return response["response"]
 
-    def chat_completion(
-        self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    def chat(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
         response = self._client.chat(
             model=self.model, messages=convert_to_ollama_messages(messages)
         )
         return convert_to_iris_message(response["message"])
 
-    def create_embedding(self, text: str) -> list[float]:
+    def embed(self, text: str) -> list[float]:
         response = self._client.embeddings(model=self.model, prompt=text)
         return list(response)
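For reference, a hypothetical instantiation of the renamed OllamaModel; every field value is a placeholder, and the final call assumes a reachable Ollama server and that CompletionArguments needs only stop.

```python
from llm import CompletionArguments
from llm.external.ollama import OllamaModel

model = OllamaModel(
    type="ollama",
    id="local-llama",               # inherited LanguageModel fields
    name="Llama 2 (local)",
    description="Local Ollama instance",
    model="llama2",                 # placeholder Ollama model name
    host="http://localhost:11434",  # placeholder host
)
# Requires a running Ollama server at the configured host.
print(model.complete("Say hi", CompletionArguments(stop=None)))
```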
app/llm/external/openai_chat.py
@@ -1,11 +1,12 @@
 from typing import Literal, Any
 
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
 
 from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
-from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper
+from llm.external.model import ChatModel
 
 
 def convert_to_open_ai_messages(
@@ -22,13 +23,13 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
     return IrisMessage(role=message_role, text=message.content)
 
 
-class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
+class OpenAIChatModel(ChatModel):
     model: str
     api_key: str
     _client: OpenAI
 
-    def chat_completion(
-        self, messages: list[any], arguments: CompletionArguments
+    def chat(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
         response = self._client.chat.completions.create(
             model=self.model,
@@ -40,7 +41,7 @@ def chat_completion(
         return convert_to_iris_message(response.choices[0].message)
 
 
-class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+class DirectOpenAIChatModel(OpenAIChatModel):
     type: Literal["openai_chat"]
 
     def model_post_init(self, __context: Any) -> None:
@@ -50,7 +51,7 @@ def __str__(self):
         return f"OpenAIChat('{self.model}')"
 
 
-class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+class AzureOpenAIChatModel(OpenAIChatModel):
     type: Literal["azure_chat"]
     endpoint: str
     azure_deployment: str
app/llm/external/openai_completion.py
@@ -3,15 +3,15 @@
 from openai.lib.azure import AzureOpenAI
 
 from llm import CompletionArguments
-from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper
+from llm.external.model import CompletionModel
 
 
-class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
+class OpenAICompletionModel(CompletionModel):
     model: str
     api_key: str
     _client: OpenAI
 
-    def completion(self, prompt: str, arguments: CompletionArguments) -> any:
+    def complete(self, prompt: str, arguments: CompletionArguments) -> any:
         response = self._client.completions.create(
             model=self.model,
             prompt=prompt,
@@ -22,7 +22,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
         return response
 
 
-class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+class DirectOpenAICompletionModel(OpenAICompletionModel):
     type: Literal["openai_completion"]
 
     def model_post_init(self, __context: Any) -> None:
@@ -32,7 +32,7 @@ def __str__(self):
         return f"OpenAICompletion('{self.model}')"
 
 
-class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
+class AzureOpenAICompletionModel(OpenAICompletionModel):
     type: Literal["azure_completion"]
     endpoint: str
     azure_deployment: str
app/llm/external/openai_embeddings.py
@@ -2,15 +2,15 @@
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 
-from llm.wrapper.abstract_llm_wrapper import AbstractLlmEmbeddingWrapper
+from llm.external.model import EmbeddingModel
 
 
-class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):
+class OpenAIEmbeddingModel(EmbeddingModel):
     model: str
     api_key: str
     _client: OpenAI
 
-    def create_embedding(self, text: str) -> list[float]:
+    def embed(self, text: str) -> list[float]:
         response = self._client.embeddings.create(
             model=self.model,
             input=text,
@@ -19,7 +19,7 @@ def create_embedding(self, text: str) -> list[float]:
         return response.data[0].embedding
 
 
-class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+class DirectOpenAIEmbeddingModel(OpenAIEmbeddingModel):
     type: Literal["openai_embedding"]
 
     def model_post_init(self, __context: Any) -> None:
@@ -29,7 +29,7 @@ def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"
 
 
-class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+class AzureOpenAIEmbeddingModel(OpenAIEmbeddingModel):
     type: Literal["azure_embedding"]
     endpoint: str
     azure_deployment: str
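All three OpenAI modules follow the same pattern: a shared base class holds the request logic, and the Direct/Azure subclasses differ only in their type tag and client setup. A hypothetical instantiation of the direct embedding variant (placeholder values; the Azure variant would additionally need endpoint and azure_deployment):

```python
from llm.external.openai_embeddings import DirectOpenAIEmbeddingModel

embedding_model = DirectOpenAIEmbeddingModel(
    type="openai_embedding",
    id="ada-002",                     # inherited LanguageModel fields
    name="Ada v2",
    description="OpenAI embedding model",
    model="text-embedding-ada-002",
    api_key="sk-placeholder",         # no request is made at construction time
)
vector = embedding_model.embed("some text")  # this call needs a valid API key
```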
2 changes: 1 addition & 1 deletion app/llm/langchain/__init__.py
@@ -1,3 +1,3 @@
 from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
 from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
-from llm.langchain.iris_langchain_embedding import IrisLangchainEmbeddingModel
+from llm.langchain.iris_langchain_embedding_model import IrisLangchainEmbeddingModel
8 changes: 4 additions & 4 deletions app/llm/langchain/iris_langchain_chat_model.py
@@ -9,7 +9,7 @@
 from langchain_core.outputs.chat_generation import ChatGeneration
 
 from domain import IrisMessage, IrisMessageRole
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments
 
 
 def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
@@ -35,9 +35,9 @@ def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
 class IrisLangchainChatModel(BaseChatModel):
     """Custom langchain chat model for our own request handler"""
 
-    request_handler: RequestHandlerInterface
+    request_handler: RequestHandler
 
-    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
         super().__init__(request_handler=request_handler, **kwargs)
 
     def _generate(
@@ -48,7 +48,7 @@ def _generate(
         **kwargs: Any
     ) -> ChatResult:
         iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
-        iris_message = self.request_handler.chat_completion(
+        iris_message = self.request_handler.chat(
             iris_messages, CompletionArguments(stop=stop)
         )
         base_message = convert_iris_message_to_base_message(iris_message)
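With the rename, the adapter forwards to RequestHandler.chat. A rough wiring sketch (the model id is invented) showing the adapter in a standard langchain invoke call:

```python
from langchain_core.messages import HumanMessage

from llm import BasicRequestHandler
from llm.langchain import IrisLangchainChatModel

# Any RequestHandler works here; BasicRequestHandler with a made-up id is used
# for illustration.
chat_model = IrisLangchainChatModel(BasicRequestHandler("some-chat-model"))
result = chat_model.invoke([HumanMessage(content="Hi!")])  # standard Runnable API
print(result.content)
```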
10 changes: 4 additions & 6 deletions app/llm/langchain/iris_langchain_completion_model.py
@@ -5,15 +5,15 @@
 from langchain_core.outputs import LLMResult
 from langchain_core.outputs.generation import Generation
 
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments
 
 
 class IrisLangchainCompletionModel(BaseLLM):
     """Custom langchain chat model for our own request handler"""
 
-    request_handler: RequestHandlerInterface
+    request_handler: RequestHandler
 
-    def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None:
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
         super().__init__(request_handler=request_handler, **kwargs)
 
     def _generate(
@@ -26,9 +26,7 @@ def _generate(
         generations = []
         args = CompletionArguments(stop=stop)
         for prompt in prompts:
-            completion = self.request_handler.completion(
-                prompt=prompt, arguments=args, **kwargs
-            )
+            completion = self.request_handler.complete(prompt=prompt, arguments=args)
             generations.append([Generation(text=completion)])
         return LLMResult(generations=generations)