Skip to content

Commit

Permalink
Merge branch 'main' into feature/datastore
Browse files Browse the repository at this point in the history
  • Loading branch information
yassinsws authored Feb 20, 2024
2 parents 128ea40 + db3c066 commit 0818109
Show file tree
Hide file tree
Showing 18 changed files with 260 additions and 169 deletions.
6 changes: 3 additions & 3 deletions app/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from llm.request_handler_interface import RequestHandlerInterface
from llm.generation_arguments import *
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
from llm.request_handler_interface import RequestHandler
from llm.completion_arguments import *
from llm.basic_request_handler import BasicRequestHandler, DefaultModelId
50 changes: 14 additions & 36 deletions app/llm/basic_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments
from llm import RequestHandler, CompletionArguments
from llm.llm_manager import LlmManager
from llm.wrapper.abstract_llm_wrapper import (
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)

type BasicRequestHandlerModel = str


class BasicRequestHandler(RequestHandlerInterface):
model: BasicRequestHandlerModel
class BasicRequestHandler(RequestHandler):
model_id: str
llm_manager: LlmManager

def __init__(self, model: BasicRequestHandlerModel):
self.model = model
def __init__(self, model_id: str):
self.model_id = model_id
self.llm_manager = LlmManager()

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support completion"
)
def complete(self, prompt: str, arguments: CompletionArguments) -> str:
llm = self.llm_manager.get_by_id(self.model_id)
return llm.complete(prompt, arguments)

def chat_completion(
def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support chat completion"
)
llm = self.llm_manager.get_by_id(self.model_id)
return llm.chat(messages, arguments)

def create_embedding(self, text: str) -> list[float]:
llm = self.llm_manager.get_llm_by_id(self.model).llm
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support embedding"
)
def embed(self, text: str) -> list[float]:
llm = self.llm_manager.get_by_id(self.model_id)
return llm.embed(text)
File renamed without changes.
21 changes: 21 additions & 0 deletions app/llm/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from llm.external.model import LanguageModel
from llm.external.openai_completion import (
DirectOpenAICompletionModel,
AzureOpenAICompletionModel,
)
from llm.external.openai_chat import DirectOpenAIChatModel, AzureOpenAIChatModel
from llm.external.openai_embeddings import (
DirectOpenAIEmbeddingModel,
AzureOpenAIEmbeddingModel,
)
from llm.external.ollama import OllamaModel

type AnyLLM = (
DirectOpenAICompletionModel
| AzureOpenAICompletionModel
| DirectOpenAIChatModel
| AzureOpenAIChatModel
| DirectOpenAIEmbeddingModel
| AzureOpenAIEmbeddingModel
| OllamaModel
)
60 changes: 60 additions & 0 deletions app/llm/external/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from domain import IrisMessage
from llm import CompletionArguments


class LanguageModel(BaseModel, metaclass=ABCMeta):
"""Abstract class for the llm wrappers"""

id: str
name: str
description: str


class CompletionModel(LanguageModel, metaclass=ABCMeta):
"""Abstract class for the llm completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass) -> bool:
return hasattr(subclass, "complete") and callable(subclass.complete)

@abstractmethod
def complete(self, prompt: str, arguments: CompletionArguments) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError(
f"The LLM {self.__str__()} does not support completion"
)


class ChatModel(LanguageModel, metaclass=ABCMeta):
"""Abstract class for the llm chat completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass) -> bool:
return hasattr(subclass, "chat") and callable(subclass.chat)

@abstractmethod
def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError(
f"The LLM {self.__str__()} does not support chat completion"
)


class EmbeddingModel(LanguageModel, metaclass=ABCMeta):
"""Abstract class for the llm embedding wrappers"""

@classmethod
def __subclasshook__(cls, subclass) -> bool:
return hasattr(subclass, "embed") and callable(subclass.embed)

@abstractmethod
def embed(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError(
f"The LLM {self.__str__()} does not support embeddings"
)
24 changes: 10 additions & 14 deletions app/llm/wrapper/ollama_wrapper.py → app/llm/external/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper.abstract_llm_wrapper import (
AbstractLlmChatCompletionWrapper,
AbstractLlmCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)
from llm.external.model import ChatModel, CompletionModel, EmbeddingModel


def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
Expand All @@ -21,10 +17,10 @@ def convert_to_iris_message(message: Message) -> IrisMessage:
return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])


class OllamaWrapper(
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
class OllamaModel(
CompletionModel,
ChatModel,
EmbeddingModel,
):
type: Literal["ollama"]
model: str
Expand All @@ -34,19 +30,19 @@ class OllamaWrapper(
def model_post_init(self, __context: Any) -> None:
self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?)

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
def complete(self, prompt: str, arguments: CompletionArguments) -> str:
response = self._client.generate(model=self.model, prompt=prompt)
return response["response"]

def chat_completion(
self, messages: list[any], arguments: CompletionArguments
) -> any:
def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
response = self._client.chat(
model=self.model, messages=convert_to_ollama_messages(messages)
)
return convert_to_iris_message(response["message"])

def create_embedding(self, text: str) -> list[float]:
def embed(self, text: str) -> list[float]:
response = self._client.embeddings(model=self.model, prompt=text)
return list(response)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Literal, Any

from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper
from llm.external.model import ChatModel


def convert_to_open_ai_messages(
Expand All @@ -22,13 +23,13 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:
return IrisMessage(role=message_role, text=message.content)


class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):
class OpenAIChatModel(ChatModel):
model: str
api_key: str
_client: OpenAI

def chat_completion(
self, messages: list[any], arguments: CompletionArguments
def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
response = self._client.chat.completions.create(
model=self.model,
Expand All @@ -40,7 +41,7 @@ def chat_completion(
return convert_to_iris_message(response.choices[0].message)


class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
class DirectOpenAIChatModel(OpenAIChatModel):
type: Literal["openai_chat"]

def model_post_init(self, __context: Any) -> None:
Expand All @@ -50,7 +51,7 @@ def __str__(self):
return f"OpenAIChat('{self.model}')"


class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
class AzureOpenAIChatModel(OpenAIChatModel):
type: Literal["azure_chat"]
endpoint: str
azure_deployment: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from openai.lib.azure import AzureOpenAI

from llm import CompletionArguments
from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper
from llm.external.model import CompletionModel


class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):
class OpenAICompletionModel(CompletionModel):
model: str
api_key: str
_client: OpenAI

def completion(self, prompt: str, arguments: CompletionArguments) -> any:
def complete(self, prompt: str, arguments: CompletionArguments) -> any:
response = self._client.completions.create(
model=self.model,
prompt=prompt,
Expand All @@ -22,7 +22,7 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
return response


class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
class DirectOpenAICompletionModel(OpenAICompletionModel):
type: Literal["openai_completion"]

def model_post_init(self, __context: Any) -> None:
Expand All @@ -32,7 +32,7 @@ def __str__(self):
return f"OpenAICompletion('{self.model}')"


class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
class AzureOpenAICompletionModel(OpenAICompletionModel):
type: Literal["azure_completion"]
endpoint: str
azure_deployment: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from llm.wrapper.abstract_llm_wrapper import AbstractLlmEmbeddingWrapper
from llm.external.model import EmbeddingModel


class BaseOpenAIEmbeddingWrapper(AbstractLlmEmbeddingWrapper):
class OpenAIEmbeddingModel(EmbeddingModel):
model: str
api_key: str
_client: OpenAI

def create_embedding(self, text: str) -> list[float]:
def embed(self, text: str) -> list[float]:
response = self._client.embeddings.create(
model=self.model,
input=text,
Expand All @@ -19,7 +19,7 @@ def create_embedding(self, text: str) -> list[float]:
return response.data[0].embedding


class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
class DirectOpenAIEmbeddingModel(OpenAIEmbeddingModel):
type: Literal["openai_embedding"]

def model_post_init(self, __context: Any) -> None:
Expand All @@ -29,7 +29,7 @@ def __str__(self):
return f"OpenAIEmbedding('{self.model}')"


class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
class AzureOpenAIEmbeddingModel(OpenAIEmbeddingModel):
type: Literal["azure_embedding"]
endpoint: str
azure_deployment: str
Expand Down
3 changes: 3 additions & 0 deletions app/llm/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
from llm.langchain.iris_langchain_embedding_model import IrisLangchainEmbeddingModel
60 changes: 60 additions & 0 deletions app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration

from domain import IrisMessage, IrisMessageRole
from llm import RequestHandler, CompletionArguments


def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
role_map = {
IrisMessageRole.USER: "human",
IrisMessageRole.ASSISTANT: "ai",
IrisMessageRole.SYSTEM: "system",
}
return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])


def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
role_map = {
"human": IrisMessageRole.USER,
"ai": IrisMessageRole.ASSISTANT,
"system": IrisMessageRole.SYSTEM,
}
return IrisMessage(
text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
)


class IrisLangchainChatModel(BaseChatModel):
"""Custom langchain chat model for our own request handler"""

request_handler: RequestHandler

def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
super().__init__(request_handler=request_handler, **kwargs)

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> ChatResult:
iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
iris_message = self.request_handler.chat(
iris_messages, CompletionArguments(stop=stop)
)
base_message = convert_iris_message_to_base_message(iris_message)
chat_generation = ChatGeneration(message=base_message)
return ChatResult(generations=[chat_generation])

@property
def _llm_type(self) -> str:
return "Iris"
Loading

0 comments on commit 0818109

Please sign in to comment.