-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
565 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[flake8]
# Project-wide line-length limit (wider than PEP 8's 79).
max-line-length = 120
# Directories flake8 should never scan.
exclude =
    .git,
    __pycache__,
    .idea
per-file-ignores =
    # imported but unused
    __init__.py: F401, F403
    # F811: redefinition of unused name (overload-style wrapper modules)
    open_ai_chat_wrapper.py: F811
    open_ai_completion_wrapper.py: F811
    open_ai_embedding_wrapper.py: F811
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# actions/labeler config: apply the "component:LLM" label to any PR
# that changes a file under app/llm/.
"component:LLM":
- changed-files:
  - any-glob-to-any-file: app/llm/**
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
name: Pull Request Labeler
# pull_request_target grants the workflow a token with write access so the
# labeler can add labels to PRs from forks — keep this workflow free of
# steps that check out or run untrusted PR code.
on: pull_request_target

jobs:
  label:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/labeler@v5
        with:
          repo-token: "${{ secrets.GITHUB_TOKEN }}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from common.singleton import Singleton |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call to the class constructs and caches the instance; every
    later call returns that same cached object.
    """

    # Shared cache mapping each class to its unique instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct lazily on first use; subsequent calls hit the cache.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from domain.message import IrisMessage, IrisMessageRole |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from enum import Enum | ||
|
||
|
||
class IrisMessageRole(Enum):
    """Role attached to a chat message."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class IrisMessage:
    """A single chat message: a role plus its text content."""

    role: IrisMessageRole
    text: str

    def __init__(self, role: IrisMessageRole, text: str):
        self.role = role
        self.text = text

    def __str__(self):
        return f"IrisMessage(role={self.role.value}, text='{self.text}')"

    # Without __repr__, a list of messages prints as opaque object
    # addresses; reuse __str__ so debugging output stays readable.
    __repr__ = __str__
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from llm.request_handler_interface import RequestHandlerInterface | ||
from llm.generation_arguments import * | ||
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from domain import IrisMessage | ||
from llm import RequestHandlerInterface, CompletionArguments | ||
from llm.llm_manager import LlmManager | ||
from llm.wrapper import ( | ||
AbstractLlmCompletionWrapper, | ||
AbstractLlmChatCompletionWrapper, | ||
AbstractLlmEmbeddingWrapper, | ||
) | ||
|
||
type BasicRequestHandlerModel = str | ||
|
||
|
||
class BasicRequestHandler(RequestHandlerInterface): | ||
model: BasicRequestHandlerModel | ||
llm_manager: LlmManager | ||
|
||
def __init__(self, model: BasicRequestHandlerModel): | ||
self.model = model | ||
self.llm_manager = LlmManager() | ||
|
||
def completion(self, prompt: str, arguments: CompletionArguments) -> str: | ||
llm = self.llm_manager.get_llm_by_id(self.model).llm | ||
if isinstance(llm, AbstractLlmCompletionWrapper): | ||
return llm.completion(prompt, arguments) | ||
else: | ||
raise NotImplementedError( | ||
f"The LLM {llm.__str__()} does not support completion" | ||
) | ||
|
||
def chat_completion( | ||
self, messages: list[IrisMessage], arguments: CompletionArguments | ||
) -> IrisMessage: | ||
llm = self.llm_manager.get_llm_by_id(self.model).llm | ||
if isinstance(llm, AbstractLlmChatCompletionWrapper): | ||
return llm.chat_completion(messages, arguments) | ||
else: | ||
raise NotImplementedError( | ||
f"The LLM {llm.__str__()} does not support chat completion" | ||
) | ||
|
||
def create_embedding(self, text: str) -> list[float]: | ||
llm = self.llm_manager.get_llm_by_id(self.model).llm | ||
if isinstance(llm, AbstractLlmEmbeddingWrapper): | ||
return llm.create_embedding(text) | ||
else: | ||
raise NotImplementedError( | ||
f"The LLM {llm.__str__()} does not support embedding" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
class CompletionArguments: | ||
"""Arguments for the completion request""" | ||
|
||
def __init__( | ||
self, max_tokens: int = None, temperature: float = None, stop: list[str] = None | ||
): | ||
self.max_tokens = max_tokens | ||
self.temperature = temperature | ||
self.stop = stop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
|
||
import yaml | ||
|
||
from common import Singleton | ||
from llm.wrapper import AbstractLlmWrapper | ||
|
||
|
||
# TODO: Replace with pydantic in a future PR | ||
# TODO: Replace with pydantic in a future PR
def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
    """Build the concrete LLM wrapper described by *config*.

    Dispatches on config["type"]; imports are kept inside each branch so
    only the selected wrapper module is loaded. Raises for unknown types.
    """
    llm_type = config["type"]

    def _named_kwargs() -> dict:
        # Identity/model fields shared by all id-carrying wrappers.
        return {
            "id": config["id"],
            "name": config["name"],
            "description": config["description"],
            "model": config["model"],
        }

    def _azure_kwargs() -> dict:
        # Azure wrappers additionally need deployment/endpoint details.
        return {
            **_named_kwargs(),
            "endpoint": config["endpoint"],
            "azure_deployment": config["azure_deployment"],
            "api_version": config["api_version"],
            "api_key": config["api_key"],
        }

    if llm_type == "openai":
        from llm.wrapper import OpenAICompletionWrapper

        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
    if llm_type == "azure":
        from llm.wrapper import AzureCompletionWrapper

        return AzureCompletionWrapper(**_azure_kwargs())
    if llm_type == "openai_chat":
        from llm.wrapper import OpenAIChatCompletionWrapper

        return OpenAIChatCompletionWrapper(**_named_kwargs(), api_key=config["api_key"])
    if llm_type == "azure_chat":
        from llm.wrapper import AzureChatCompletionWrapper

        return AzureChatCompletionWrapper(**_azure_kwargs())
    if llm_type == "openai_embedding":
        from llm.wrapper import OpenAIEmbeddingWrapper

        return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
    if llm_type == "azure_embedding":
        from llm.wrapper import AzureEmbeddingWrapper

        return AzureEmbeddingWrapper(**_azure_kwargs())
    if llm_type == "ollama":
        from llm.wrapper import OllamaWrapper

        return OllamaWrapper(**_named_kwargs(), host=config["host"])
    raise Exception(f"Unknown LLM type: {config['type']}")
|
||
|
||
class LlmManager(metaclass=Singleton):
    """Process-wide registry of configured LLM wrappers (singleton)."""

    entries: list[AbstractLlmWrapper]

    def __init__(self):
        self.entries = []
        self.load_llms()

    def get_llm_by_id(self, llm_id):
        """Return the wrapper whose id equals *llm_id*, or None if absent."""
        return next((entry for entry in self.entries if entry.id == llm_id), None)

    def load_llms(self):
        """Populate self.entries from the YAML file at $LLM_CONFIG_PATH.

        Raises if the environment variable is missing/empty.
        """
        path = os.environ.get("LLM_CONFIG_PATH")
        if not path:
            raise Exception("LLM_CONFIG_PATH not set")

        with open(path, "r") as file:
            configs = yaml.safe_load(file)

        self.entries = [create_llm_wrapper(config) for config in configs]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
from domain import IrisMessage | ||
from llm.generation_arguments import CompletionArguments | ||
|
||
|
||
class RequestHandlerInterface(metaclass=ABCMeta):
    """Interface for the request handlers.

    A class satisfies this interface structurally when it provides callable
    completion, chat_completion, and create_embedding members.
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check so issubclass()/isinstance() accept duck-typed
        # implementations that never inherit from this class.
        return (
            hasattr(subclass, "completion")
            and callable(subclass.completion)
            and hasattr(subclass, "chat_completion")
            and callable(subclass.chat_completion)
            and hasattr(subclass, "create_embedding")
            and callable(subclass.create_embedding)
        )

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError

    @abstractmethod
    def chat_completion(
        # was `list[any]` (the builtin function, not a type) and the
        # return was the list literal `[IrisMessage]`; fixed to match the
        # BasicRequestHandler implementation's signature.
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from llm.wrapper.abstract_llm_wrapper import * | ||
from llm.wrapper.open_ai_completion_wrapper import * | ||
from llm.wrapper.open_ai_chat_wrapper import * | ||
from llm.wrapper.open_ai_embedding_wrapper import * | ||
from llm.wrapper.ollama_wrapper import OllamaWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
from domain import IrisMessage | ||
from llm import CompletionArguments | ||
|
||
|
||
class AbstractLlmWrapper(metaclass=ABCMeta):
    """Abstract base carrying the identity metadata common to llm wrappers."""

    # `id` shadows the builtin, but renaming would break keyword callers.
    id: str
    name: str
    description: str

    def __init__(self, id: str, name: str, description: str):
        self.id, self.name, self.description = id, name, description
|
||
|
||
class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm completion wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check: any class exposing a callable `completion`
        # counts as a completion wrapper.
        return callable(getattr(subclass, "completion", None))

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError
|
||
|
||
class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm chat completion wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check so duck-typed chat wrappers pass issubclass().
        return hasattr(subclass, "chat_completion") and callable(
            subclass.chat_completion
        )

    @abstractmethod
    def chat_completion(
        # was `list[any]` — the builtin function, not a type; the messages
        # are IrisMessage objects (see the file's `from domain import
        # IrisMessage` and the IrisMessage return annotation below).
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError
|
||
|
||
class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm embedding wrappers."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Structural check: a callable `create_embedding` is sufficient.
        return callable(getattr(subclass, "create_embedding", None))

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
Oops, something went wrong.