LLM: Add llm subsystem #51
Merged (9 commits, Feb 12, 2024)
13 changes: 13 additions & 0 deletions .flake8
@@ -0,0 +1,13 @@
[flake8]
max-line-length = 120
exclude =
.git,
__pycache__,
.idea
per-file-ignores =
# imported but unused
__init__.py: F401, F403
open_ai_chat_wrapper.py: F811
open_ai_completion_wrapper.py: F811
open_ai_embedding_wrapper.py: F811

3 changes: 3 additions & 0 deletions .github/labeler.yml
@@ -0,0 +1,3 @@
"component:LLM":
- changed-files:
- any-glob-to-any-file: app/llm/**
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.12"
cache: 'pip'

- name: Install Dependencies from requirements.txt
10 changes: 10 additions & 0 deletions .github/workflows/pullrequest-labeler.yml
@@ -0,0 +1,10 @@
name: Pull Request Labeler
on: pull_request_target

jobs:
label:
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v5
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -5,8 +5,9 @@
rev: stable
hooks:
- id: black
language_version: python3.11
language_version: python3.12
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
- id: flake8
- id: flake8
language_version: python3.12
1 change: 1 addition & 0 deletions app/common/__init__.py
@@ -0,0 +1 @@
from common.singleton import Singleton
7 changes: 7 additions & 0 deletions app/common/singleton.py
@@ -0,0 +1,7 @@
class Singleton(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
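
A minimal usage sketch of the Singleton metaclass above, assuming app/ is on the import path; ConnectionPool is a hypothetical class used only for illustration:

from common import Singleton


class ConnectionPool(metaclass=Singleton):
    # Hypothetical class: every instantiation returns the same object.
    def __init__(self, size: int = 4):
        self.size = size


pool_a = ConnectionPool(size=8)
pool_b = ConnectionPool(size=16)  # constructor arguments are ignored on repeat calls
assert pool_a is pool_b and pool_a.size == 8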
1 change: 1 addition & 0 deletions app/domain/__init__.py
@@ -0,0 +1 @@
from domain.message import IrisMessage, IrisMessageRole
19 changes: 19 additions & 0 deletions app/domain/message.py
@@ -0,0 +1,19 @@
from enum import Enum


class IrisMessageRole(Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"


class IrisMessage:
role: IrisMessageRole
text: str

def __init__(self, role: IrisMessageRole, text: str):
self.role = role
self.text = text

def __str__(self):
return f"IrisMessage(role={self.role.value}, text='{self.text}')"
3 changes: 3 additions & 0 deletions app/llm/__init__.py
@@ -0,0 +1,3 @@
from llm.request_handler_interface import RequestHandlerInterface
from llm.generation_arguments import *
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
48 changes: 48 additions & 0 deletions app/llm/basic_request_handler.py
@@ -0,0 +1,48 @@
from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments
from llm.llm_manager import LlmManager
from llm.wrapper import (
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)

type BasicRequestHandlerModel = str


class BasicRequestHandler(RequestHandlerInterface):
model: BasicRequestHandlerModel
llm_manager: LlmManager

def __init__(self, model: BasicRequestHandlerModel):
self.model = model
self.llm_manager = LlmManager()

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support completion"
)

def chat_completion(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support chat completion"
)

def create_embedding(self, text: str) -> list[float]:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support embedding"
)
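
A hedged usage sketch of the handler above. It assumes LLM_CONFIG_PATH is already set and that the loaded config contains a chat-capable entry with the made-up id "example-chat-model"; neither is part of this PR.

from domain import IrisMessage, IrisMessageRole
from llm import BasicRequestHandler, CompletionArguments

handler = BasicRequestHandler(model="example-chat-model")  # id of an entry in the LLM config
arguments = CompletionArguments(max_tokens=256, temperature=0.2)
response = handler.chat_completion(
    [IrisMessage(role=IrisMessageRole.USER, text="Summarize this pull request.")],
    arguments,
)
print(response.text)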
9 changes: 9 additions & 0 deletions app/llm/generation_arguments.py
@@ -0,0 +1,9 @@
class CompletionArguments:
"""Arguments for the completion request"""

def __init__(
self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
):
self.max_tokens = max_tokens
self.temperature = temperature
self.stop = stop
102 changes: 102 additions & 0 deletions app/llm/llm_manager.py
@@ -0,0 +1,102 @@
import os

import yaml

from common import Singleton
from llm.wrapper import AbstractLlmWrapper


# TODO: Replace with pydantic in a future PR
def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
if config["type"] == "openai":
from llm.wrapper import OpenAICompletionWrapper

return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure":
from llm.wrapper import AzureCompletionWrapper

return AzureCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_chat":
from llm.wrapper import OpenAIChatCompletionWrapper

return OpenAIChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
api_key=config["api_key"],
)
elif config["type"] == "azure_chat":
from llm.wrapper import AzureChatCompletionWrapper

return AzureChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_embedding":
from llm.wrapper import OpenAIEmbeddingWrapper

return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure_embedding":
from llm.wrapper import AzureEmbeddingWrapper

return AzureEmbeddingWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "ollama":
from llm.wrapper import OllamaWrapper

return OllamaWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
host=config["host"],
)
else:
raise Exception(f"Unknown LLM type: {config['type']}")


class LlmManager(metaclass=Singleton):
entries: list[AbstractLlmWrapper]

def __init__(self):
self.entries = []
self.load_llms()

def get_llm_by_id(self, llm_id):
for llm in self.entries:
if llm.id == llm_id:
return llm

def load_llms(self):
path = os.environ.get("LLM_CONFIG_PATH")
if not path:
raise Exception("LLM_CONFIG_PATH not set")

with open(path, "r") as file:
loaded_llms = yaml.safe_load(file)

self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
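
For illustration, a sketch of a single config entry fed through create_llm_wrapper above; it mirrors the "ollama" branch, and the id, model, and host values are made up:

from llm.llm_manager import LlmManager, create_llm_wrapper

# One entry shaped like the "ollama" branch above; every value is a placeholder.
entry = {
    "type": "ollama",
    "id": "example-ollama",
    "name": "Example Ollama Model",
    "description": "Illustrative config entry",
    "model": "llama2",
    "host": "http://localhost:11434",
}
wrapper = create_llm_wrapper(entry)
print(wrapper.name)

# In the application itself, LLM_CONFIG_PATH points at a YAML file containing a
# list of such entries; LlmManager() loads them once (Singleton) and
# get_llm_by_id("example-ollama") would return the matching wrapper.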
36 changes: 36 additions & 0 deletions app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments


class RequestHandlerInterface(metaclass=ABCMeta):
"""Interface for the request handlers"""

@classmethod
def __subclasshook__(cls, subclass):
return (
hasattr(subclass, "completion")
and callable(subclass.completion)
and hasattr(subclass, "chat_completion")
and callable(subclass.chat_completion)
and hasattr(subclass, "create_embedding")
and callable(subclass.create_embedding)
)

@abstractmethod
def completion(self, prompt: str, arguments: CompletionArguments) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError

@abstractmethod
def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError

@abstractmethod
def create_embedding(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError
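
Because __subclasshook__ checks for the three methods, any class that provides them counts as a virtual subclass of the interface; a small sketch with a hypothetical DummyHandler:

from domain import IrisMessage
from llm import CompletionArguments, RequestHandlerInterface


class DummyHandler:
    # Hypothetical handler that satisfies the interface purely by shape (duck typing).
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        return prompt

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        return messages[-1]

    def create_embedding(self, text: str) -> list[float]:
        return [0.0]


assert issubclass(DummyHandler, RequestHandlerInterface)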
5 changes: 5 additions & 0 deletions app/llm/wrapper/__init__.py
@@ -0,0 +1,5 @@
from llm.wrapper.abstract_llm_wrapper import *
from llm.wrapper.open_ai_completion_wrapper import *
from llm.wrapper.open_ai_chat_wrapper import *
from llm.wrapper.open_ai_embedding_wrapper import *
from llm.wrapper.ollama_wrapper import OllamaWrapper
62 changes: 62 additions & 0 deletions app/llm/wrapper/abstract_llm_wrapper.py
@@ -0,0 +1,62 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm import CompletionArguments


class AbstractLlmWrapper(metaclass=ABCMeta):
"""Abstract class for the llm wrappers"""

id: str
name: str
description: str

def __init__(self, id: str, name: str, description: str):
self.id = id
self.name = name
self.description = description


class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "completion") and callable(subclass.completion)

@abstractmethod
def completion(self, prompt: str, arguments: CompletionArguments) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError


class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm chat completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "chat_completion") and callable(
subclass.chat_completion
)

@abstractmethod
def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError


class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm embedding wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "create_embedding") and callable(
subclass.create_embedding
)

@abstractmethod
def create_embedding(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError
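
A minimal concrete wrapper sketched on top of the abstract classes above, for illustration only; EchoCompletionWrapper is hypothetical and does not call any real model:

from llm import CompletionArguments
from llm.wrapper import AbstractLlmCompletionWrapper


class EchoCompletionWrapper(AbstractLlmCompletionWrapper):
    # Hypothetical wrapper that echoes the prompt back instead of calling an API.
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        # A real wrapper would forward `prompt` and `arguments` to a model here.
        return prompt


echo = EchoCompletionWrapper(id="echo", name="Echo", description="Test-only wrapper")
print(echo.completion("Hello!", CompletionArguments(max_tokens=8)))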