Commit: Merge remote-tracking branch 'origin/feature/datastore' into feature/datastore

Showing 18 changed files with 260 additions and 169 deletions.
--- a/llm/__init__.py
+++ b/llm/__init__.py
@@ -1,3 +1,3 @@
-from llm.request_handler_interface import RequestHandlerInterface
-from llm.generation_arguments import *
-from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
+from llm.request_handler_interface import RequestHandler
+from llm.completion_arguments import *
+from llm.basic_request_handler import BasicRequestHandler, DefaultModelId
--- a/llm/basic_request_handler.py
+++ b/llm/basic_request_handler.py
@@ -1,48 +1,26 @@
 from domain import IrisMessage
-from llm import RequestHandlerInterface, CompletionArguments
+from llm import RequestHandler, CompletionArguments
 from llm.llm_manager import LlmManager
-from llm.wrapper.abstract_llm_wrapper import (
-    AbstractLlmCompletionWrapper,
-    AbstractLlmChatCompletionWrapper,
-    AbstractLlmEmbeddingWrapper,
-)
-
-type BasicRequestHandlerModel = str
 
 
-class BasicRequestHandler(RequestHandlerInterface):
-    model: BasicRequestHandlerModel
+class BasicRequestHandler(RequestHandler):
+    model_id: str
     llm_manager: LlmManager
 
-    def __init__(self, model: BasicRequestHandlerModel):
-        self.model = model
+    def __init__(self, model_id: str):
+        self.model_id = model_id
         self.llm_manager = LlmManager()
 
-    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, AbstractLlmCompletionWrapper):
-            return llm.completion(prompt, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support completion"
-            )
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.complete(prompt, arguments)
 
-    def chat_completion(
+    def chat(
         self, messages: list[IrisMessage], arguments: CompletionArguments
     ) -> IrisMessage:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, AbstractLlmChatCompletionWrapper):
-            return llm.chat_completion(messages, arguments)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support chat completion"
-            )
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.chat(messages, arguments)
 
-    def create_embedding(self, text: str) -> list[float]:
-        llm = self.llm_manager.get_llm_by_id(self.model).llm
-        if isinstance(llm, AbstractLlmEmbeddingWrapper):
-            return llm.create_embedding(text)
-        else:
-            raise NotImplementedError(
-                f"The LLM {llm.__str__()} does not support embedding"
-            )
+    def embed(self, text: str) -> list[float]:
+        llm = self.llm_manager.get_by_id(self.model_id)
+        return llm.embed(text)
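The refactored handler no longer checks capabilities itself; it simply forwards to whatever model the LlmManager resolves. A minimal usage sketch, assuming "gpt35-completion" (a placeholder, not from this commit) matches a model id known to the LlmManager, and that CompletionArguments accepts an optional stop argument as seen elsewhere in this diff:

from llm import BasicRequestHandler, CompletionArguments

# "gpt35-completion" is a hypothetical id; use one from your LlmManager config
handler = BasicRequestHandler(model_id="gpt35-completion")
answer = handler.complete("Say hello", CompletionArguments(stop=None))
print(answer)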
File renamed without changes.
--- /dev/null
+++ b/llm/external/__init__.py
@@ -0,0 +1,21 @@
+from llm.external.model import LanguageModel
+from llm.external.openai_completion import (
+    DirectOpenAICompletionModel,
+    AzureOpenAICompletionModel,
+)
+from llm.external.openai_chat import DirectOpenAIChatModel, AzureOpenAIChatModel
+from llm.external.openai_embeddings import (
+    DirectOpenAIEmbeddingModel,
+    AzureOpenAIEmbeddingModel,
+)
+from llm.external.ollama import OllamaModel
+
+type AnyLLM = (
+    DirectOpenAICompletionModel
+    | AzureOpenAICompletionModel
+    | DirectOpenAIChatModel
+    | AzureOpenAIChatModel
+    | DirectOpenAIEmbeddingModel
+    | AzureOpenAIEmbeddingModel
+    | OllamaModel
+)
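The AnyLLM union enumerates every concrete model class, so lookup code can return any of them while a type checker still sees the shared LanguageModel fields. A minimal sketch of such a lookup, assuming every union member subclasses LanguageModel and therefore carries an id; get_by_id here is an illustration, not the LlmManager method from this commit:

def get_by_id(entries: list[AnyLLM], model_id: str) -> AnyLLM:
    # id is declared on the LanguageModel base class
    for entry in entries:
        if entry.id == model_id:
            return entry
    raise ValueError(f"Unknown model id: {model_id}")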
--- /dev/null
+++ b/llm/external/model.py
@@ -0,0 +1,60 @@
+from abc import ABCMeta, abstractmethod
+from pydantic import BaseModel
+
+from domain import IrisMessage
+from llm import CompletionArguments
+
+
+class LanguageModel(BaseModel, metaclass=ABCMeta):
+    """Abstract class for the llm wrappers"""
+
+    id: str
+    name: str
+    description: str
+
+
+class CompletionModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "complete") and callable(subclass.complete)
+
+    @abstractmethod
+    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
+        """Create a completion from the prompt"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support completion"
+        )
+
+
+class ChatModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm chat completion wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "chat") and callable(subclass.chat)
+
+    @abstractmethod
+    def chat(
+        self, messages: list[IrisMessage], arguments: CompletionArguments
+    ) -> IrisMessage:
+        """Create a completion from the chat messages"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support chat completion"
+        )
+
+
+class EmbeddingModel(LanguageModel, metaclass=ABCMeta):
+    """Abstract class for the llm embedding wrappers"""
+
+    @classmethod
+    def __subclasshook__(cls, subclass) -> bool:
+        return hasattr(subclass, "embed") and callable(subclass.embed)
+
+    @abstractmethod
+    def embed(self, text: str) -> list[float]:
+        """Create an embedding from the text"""
+        raise NotImplementedError(
+            f"The LLM {self.__str__()} does not support embeddings"
+        )
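Because each capability class overrides __subclasshook__, issubclass and isinstance checks against CompletionModel, ChatModel, and EmbeddingModel are structural: any class exposing a callable complete, chat, or embed passes, whether or not it inherits from the abstract class. A minimal sketch demonstrating this; EchoModel is invented for illustration and not part of the commit:

class EchoModel(LanguageModel):
    # Deliberately does NOT inherit from CompletionModel
    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
        return prompt  # trivially echo the prompt back

# True: the structural hook only checks for a callable `complete`
print(issubclass(EchoModel, CompletionModel))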
--- /dev/null
+++ b/llm/langchain/__init__.py
@@ -0,0 +1,3 @@
+from llm.langchain.iris_langchain_completion_model import IrisLangchainCompletionModel
+from llm.langchain.iris_langchain_chat_model import IrisLangchainChatModel
+from llm.langchain.iris_langchain_embedding_model import IrisLangchainEmbeddingModel
--- /dev/null
+++ b/llm/langchain/iris_langchain_chat_model.py
@@ -0,0 +1,60 @@
+from typing import List, Optional, Any
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+)
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
+from langchain_core.outputs.chat_generation import ChatGeneration
+
+from domain import IrisMessage, IrisMessageRole
+from llm import RequestHandler, CompletionArguments
+
+
+def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage:
+    role_map = {
+        IrisMessageRole.USER: "human",
+        IrisMessageRole.ASSISTANT: "ai",
+        IrisMessageRole.SYSTEM: "system",
+    }
+    return BaseMessage(content=iris_message.text, type=role_map[iris_message.role])
+
+
+def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessage:
+    role_map = {
+        "human": IrisMessageRole.USER,
+        "ai": IrisMessageRole.ASSISTANT,
+        "system": IrisMessageRole.SYSTEM,
+    }
+    return IrisMessage(
+        text=base_message.content, role=IrisMessageRole(role_map[base_message.type])
+    )
+
+
+class IrisLangchainChatModel(BaseChatModel):
+    """Custom langchain chat model for our own request handler"""
+
+    request_handler: RequestHandler
+
+    def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None:
+        super().__init__(request_handler=request_handler, **kwargs)
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ) -> ChatResult:
+        iris_messages = [convert_base_message_to_iris_message(m) for m in messages]
+        iris_message = self.request_handler.chat(
+            iris_messages, CompletionArguments(stop=stop)
+        )
+        base_message = convert_iris_message_to_base_message(iris_message)
+        chat_generation = ChatGeneration(message=base_message)
+        return ChatResult(generations=[chat_generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "Iris"