-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 15 changed files with 288 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from singleton import Singleton |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
class Singleton(type):
    """Metaclass that caches exactly one instance per class.

    The first instantiation of a class using this metaclass creates and
    stores the instance; every subsequent call returns the stored one.
    Constructor arguments are only honored on the first call.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # EAFP: the common case after warm-up is a cache hit.
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from message import IrisMessage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
class IrisMessage:
    """A single chat message: a sender role plus its text content."""

    def __init__(self, role, message_text):
        # Role of the sender (e.g. "user"/"assistant" — exact values are
        # determined by callers; not validated here).
        self.role, self.message_text = role, message_text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from generation_arguments import CompletionArguments | ||
from request_handler_interface import RequestHandlerInterface | ||
from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel | ||
from llm_manager import LlmManager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from domain import IrisMessage | ||
from llm import LlmManager | ||
from llm import RequestHandlerInterface, CompletionArguments | ||
from llm.wrapper import LlmCompletionWrapperInterface, LlmChatCompletionWrapperInterface, LlmEmbeddingWrapperInterface | ||
|
||
# Identifier used to look a model up in the LlmManager registry.
# PEP 695 type alias — requires Python 3.12+.
type BasicRequestHandlerModel = str
|
||
|
||
class BasicRequestHandler(RequestHandlerInterface):
    """Request handler that forwards every request to one configured LLM.

    The handler resolves the configured model id through the ``LlmManager``
    registry on each call and dispatches to the wrapper interface that
    supports the requested operation. If the resolved wrapper does not
    implement the operation, ``NotImplementedError`` is raised.
    """

    model: BasicRequestHandlerModel
    llm_manager: LlmManager

    def __init__(self, model: BasicRequestHandlerModel):
        self.model = model
        self.llm_manager = LlmManager()

    def _resolve_llm(self):
        """Look up the configured model's wrapper in the LLM registry.

        NOTE(review): ``LlmManager.get_llm_by_id`` yields ``None`` for an
        unknown id, which surfaces here as an ``AttributeError`` — confirm
        every id is registered before use.
        """
        return self.llm_manager.get_llm_by_id(self.model).llm

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a text completion from the prompt.

        Raises NotImplementedError if the wrapper has no completion support.
        """
        llm = self._resolve_llm()
        if isinstance(llm, LlmCompletionWrapperInterface):
            return llm.completion(prompt, arguments)
        raise NotImplementedError

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a chat completion from the message history.

        Raises NotImplementedError if the wrapper has no chat support.
        """
        llm = self._resolve_llm()
        if isinstance(llm, LlmChatCompletionWrapperInterface):
            return llm.chat_completion(messages, arguments)
        raise NotImplementedError

    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding vector for the text.

        Raises NotImplementedError if the wrapper has no embedding support.
        """
        llm = self._resolve_llm()
        if isinstance(llm, LlmEmbeddingWrapperInterface):
            return llm.create_embedding(text)
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
class CompletionArguments:
    """Bundle of generation parameters for a completion request."""

    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
        # Hard cap on the number of tokens the model may generate.
        self.max_tokens = max_tokens
        # Sampling temperature; higher values produce more random output.
        self.temperature = temperature
        # Sequences at which generation should stop.
        self.stop = stop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from common import Singleton | ||
from llm.wrapper import LlmWrapperInterface | ||
|
||
|
||
class LlmManagerEntry:
    """Registry record pairing an LLM wrapper with its identifier."""

    id: str
    llm: LlmWrapperInterface

    def __init__(self, id: str, llm: LlmWrapperInterface):
        # `id` shadows the builtin, but renaming the parameter would break
        # keyword callers, so the public signature stays as-is.
        self.id = id
        self.llm = llm
|
||
|
||
class LlmManager(metaclass=Singleton):
    """Process-wide registry of available LLM wrappers (singleton)."""

    llms: list[LlmManagerEntry]

    def __init__(self):
        self.llms = []

    def get_llm_by_id(self, llm_id):
        """Return the registered entry with the given id, or None."""
        for entry in self.llms:
            if entry.id == llm_id:
                return entry
        return None

    def add_llm(self, id: str, llm: LlmWrapperInterface):
        """Register a wrapper under the given id (duplicates are not checked)."""
        self.llms.append(LlmManagerEntry(id, llm))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
from domain import IrisMessage | ||
from llm.generation_arguments import CompletionArguments | ||
|
||
|
||
class RequestHandlerInterface(metaclass=ABCMeta):
    """Interface for the request handlers.

    Supports structural subclass checks: any class that provides callable
    ``completion``, ``chat_completion`` and ``create_embedding`` attributes
    is treated as a subclass.
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        return (
            hasattr(subclass, "completion")
            and callable(subclass.completion)
            and hasattr(subclass, "chat_completion")
            and callable(subclass.chat_completion)
            and hasattr(subclass, "create_embedding")
            and callable(subclass.create_embedding)
        )

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt."""
        raise NotImplementedError

    @abstractmethod
    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages.

        Fix: the return annotation was ``[IrisMessage]`` (a list literal,
        not a valid type) and the parameter used ``list[any]`` (the builtin
        ``any`` function); both now match the implementing
        ``BasicRequestHandler`` signature.
        """
        raise NotImplementedError

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text."""
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from llm_wrapper_interface import * | ||
from open_ai_chat_wrapper import * | ||
from ollama_wrapper import OllamaWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from openai.lib.azure import AzureOpenAI | ||
|
||
from llm import CompletionArguments | ||
from llm.wrapper import LlmChatCompletionWrapperInterface, convert_to_open_ai_messages | ||
|
||
|
||
class AzureChatCompletionWrapper(LlmChatCompletionWrapperInterface):
    """Chat completion wrapper backed by an Azure OpenAI deployment."""

    def __init__(
        self,
        model: str,
        endpoint: str,
        azure_deployment: str,
        api_version: str,
        api_key: str,
    ):
        self.client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_deployment=azure_deployment,
            api_version=api_version,
            api_key=api_key,
        )
        self.model = model

    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> any:
        """Run a chat completion request and return the raw API response.

        Fix: ``arguments`` was previously ignored; its generation settings
        (max_tokens, temperature, stop) are now forwarded to the API call.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=convert_to_open_ai_messages(messages),
            max_tokens=arguments.max_tokens,
            temperature=arguments.temperature,
            stop=arguments.stop,
        )
        return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
from llm import CompletionArguments | ||
|
||
# Union of all wrapper capability interfaces; a wrapper may implement any
# combination of them. PEP 695 type alias — requires Python 3.12+.
type LlmWrapperInterface = LlmCompletionWrapperInterface | LlmChatCompletionWrapperInterface | LlmEmbeddingWrapperInterface
|
||
|
||
class LlmCompletionWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm completion wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check: anything exposing a callable `completion` qualifies.
        completion = getattr(subclass, "completion", None)
        return callable(completion)

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt."""
        raise NotImplementedError
|
||
|
||
class LlmChatCompletionWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm chat completion wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check: anything exposing a callable `chat_completion` qualifies.
        chat_completion = getattr(subclass, "chat_completion", None)
        return callable(chat_completion)

    @abstractmethod
    def chat_completion(self, messages: list[any], arguments: CompletionArguments) -> any:
        """Create a completion from the chat messages"""
        # NOTE(review): `any` here is the builtin function, not typing.Any —
        # the intended annotations are likely list[IrisMessage]/IrisMessage,
        # but that type is not importable in this module; confirm and fix.
        raise NotImplementedError
|
||
|
||
class LlmEmbeddingWrapperInterface(metaclass=ABCMeta):
    """Interface for the llm embedding wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check: anything exposing a callable `create_embedding` qualifies.
        create_embedding = getattr(subclass, "create_embedding", None)
        return callable(create_embedding)

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text."""
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from ollama import Client, Message | ||
from openai import OpenAI | ||
from openai.types.chat import ChatCompletionMessageParam | ||
|
||
from domain import IrisMessage | ||
from llm import CompletionArguments | ||
from llm.wrapper import ( | ||
LlmChatCompletionWrapperInterface, | ||
LlmCompletionWrapperInterface, | ||
LlmEmbeddingWrapperInterface, | ||
) | ||
|
||
|
||
def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
    """Translate domain messages into the ollama client's Message format."""
    return [Message(role=m.role, content=m.message_text) for m in messages]
|
||
|
||
def convert_to_iris_message(message: Message) -> IrisMessage:
    """Translate an ollama response message back into the domain type."""
    # NOTE(review): assumes attribute access works on the client's Message
    # type — verify against the installed ollama version (older versions
    # return a plain dict).
    return IrisMessage(role=message.role, message_text=message.content)
|
||
|
||
class OllamaWrapper(
    LlmCompletionWrapperInterface,
    LlmChatCompletionWrapperInterface,
    LlmEmbeddingWrapperInterface,
):
    """Wrapper for an Ollama server supporting completion, chat and embeddings."""

    def __init__(self, model: str, host: str):
        self.client = Client(host=host)  # TODO: Add authentication (httpx auth?)
        self.model = model

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Generate a plain-text completion for the prompt.

        NOTE(review): ``arguments`` is not forwarded to the API yet.
        """
        response = self.client.generate(model=self.model, prompt=prompt)
        return response["response"]

    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> any:
        """Run a chat exchange and convert the reply to an IrisMessage.

        NOTE(review): ``arguments`` is not forwarded to the API yet.
        """
        response = self.client.chat(
            model=self.model, messages=convert_to_ollama_messages(messages)
        )
        return convert_to_iris_message(response["message"])

    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding vector for the text.

        Fix: previously returned the whole response mapping even though the
        declared return type is ``list[float]``; extract the embedding
        vector from the response instead.
        """
        response = self.client.embeddings(model=self.model, prompt=text)
        return response["embedding"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from openai import OpenAI | ||
from openai.types.chat import ChatCompletionMessageParam | ||
|
||
from domain import IrisMessage | ||
from llm import CompletionArguments | ||
from llm.wrapper import LlmChatCompletionWrapperInterface | ||
|
||
|
||
def convert_to_open_ai_messages(
    messages: list[IrisMessage],
) -> list[ChatCompletionMessageParam]:
    """Translate domain messages into the OpenAI chat message format.

    Fix: ``ChatCompletionMessageParam`` is a typing union of TypedDicts and
    is not callable; build plain dicts instead (TypedDicts are ordinary
    dicts at runtime, so the annotation still holds).
    """
    return [
        {"role": message.role, "content": message.message_text}
        for message in messages
    ]
|
||
|
||
def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
    # Convert an OpenAI chat message back into the domain IrisMessage type.
    # NOTE(review): the annotation names ChatCompletionMessageParam (a
    # TypedDict, which would require subscript access), yet the body uses
    # attribute access — the actual input is presumably a response message
    # object (e.g. ChatCompletionMessage); confirm and fix the annotation.
    return IrisMessage(role=message.role, message_text=message.content)
|
||
|
||
class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
    """Chat completion wrapper backed by the public OpenAI API."""

    def __init__(self, model: str, api_key: str):
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> any:
        """Run a chat completion request and return the raw API response.

        Fix: ``arguments`` was previously ignored; its generation settings
        (max_tokens, temperature, stop) are now forwarded to the API call.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=convert_to_open_ai_messages(messages),
            max_tokens=arguments.max_tokens,
            temperature=arguments.temperature,
            stop=arguments.stop,
        )
        return response