LLM: Add llm subsystem (#51)
Hialus authored Feb 12, 2024
1 parent 281b411 commit 94f786a
Showing 20 changed files with 565 additions and 3 deletions.
13 changes: 13 additions & 0 deletions .flake8
@@ -0,0 +1,13 @@
[flake8]
max-line-length = 120
exclude =
.git,
__pycache__,
.idea
per-file-ignores =
# imported but unused
__init__.py: F401, F403
open_ai_chat_wrapper.py: F811
open_ai_completion_wrapper.py: F811
open_ai_embedding_wrapper.py: F811

3 changes: 3 additions & 0 deletions .github/labeler.yml
@@ -0,0 +1,3 @@
"component:LLM":
- changed-files:
- any-glob-to-any-file: app/llm/**
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
-          python-version: "3.11"
+          python-version: "3.12"
cache: 'pip'

- name: Install Dependencies from requirements.txt
10 changes: 10 additions & 0 deletions .github/workflows/pullrequest-labeler.yml
@@ -0,0 +1,10 @@
name: Pull Request Labeler
on: pull_request_target

jobs:
label:
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v5
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -5,8 +5,9 @@
rev: stable
hooks:
- id: black
-        language_version: python3.11
+        language_version: python3.12
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
-      - id: flake8
+      - id: flake8
+        language_version: python3.12
1 change: 1 addition & 0 deletions app/common/__init__.py
@@ -0,0 +1 @@
from common.singleton import Singleton
7 changes: 7 additions & 0 deletions app/common/singleton.py
@@ -0,0 +1,7 @@
class Singleton(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
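
Any class that sets Singleton as its metaclass is instantiated at most once; later constructor calls return the cached instance. A minimal usage sketch (the Config class below is hypothetical, not part of this commit):

class Config(metaclass=Singleton):
    def __init__(self):
        self.values = {}  # only ever initialized once

a = Config()
b = Config()
assert a is b  # both names refer to the same cached instance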
1 change: 1 addition & 0 deletions app/domain/__init__.py
@@ -0,0 +1 @@
from domain.message import IrisMessage, IrisMessageRole
19 changes: 19 additions & 0 deletions app/domain/message.py
@@ -0,0 +1,19 @@
from enum import Enum


class IrisMessageRole(Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"


class IrisMessage:
role: IrisMessageRole
text: str

def __init__(self, role: IrisMessageRole, text: str):
self.role = role
self.text = text

def __str__(self):
return f"IrisMessage(role={self.role.value}, text='{self.text}')"
3 changes: 3 additions & 0 deletions app/llm/__init__.py
@@ -0,0 +1,3 @@
from llm.request_handler_interface import RequestHandlerInterface
from llm.generation_arguments import *
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
48 changes: 48 additions & 0 deletions app/llm/basic_request_handler.py
@@ -0,0 +1,48 @@
from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments
from llm.llm_manager import LlmManager
from llm.wrapper import (
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
)

type BasicRequestHandlerModel = str


class BasicRequestHandler(RequestHandlerInterface):
model: BasicRequestHandlerModel
llm_manager: LlmManager

def __init__(self, model: BasicRequestHandlerModel):
self.model = model
self.llm_manager = LlmManager()

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmCompletionWrapper):
return llm.completion(prompt, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support completion"
)

def chat_completion(
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmChatCompletionWrapper):
return llm.chat_completion(messages, arguments)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support chat completion"
)

def create_embedding(self, text: str) -> list[float]:
        llm = self.llm_manager.get_llm_by_id(self.model)
if isinstance(llm, AbstractLlmEmbeddingWrapper):
return llm.create_embedding(text)
else:
raise NotImplementedError(
f"The LLM {llm.__str__()} does not support embedding"
)
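
A hedged usage sketch for the handler: the model id "gpt35" is a placeholder that would have to match an id in the LLM config file, and LLM_CONFIG_PATH must point at that file before LlmManager is first created:

from domain import IrisMessage, IrisMessageRole
from llm import BasicRequestHandler, CompletionArguments

handler = BasicRequestHandler(model="gpt35")  # placeholder model id
args = CompletionArguments(max_tokens=100, temperature=0.2)
answer = handler.chat_completion(
    [IrisMessage(role=IrisMessageRole.USER, text="What is Iris?")], args
)
print(answer.text)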
9 changes: 9 additions & 0 deletions app/llm/generation_arguments.py
@@ -0,0 +1,9 @@
class CompletionArguments:
"""Arguments for the completion request"""

def __init__(
        self, max_tokens: int | None = None, temperature: float | None = None, stop: list[str] | None = None
):
self.max_tokens = max_tokens
self.temperature = temperature
self.stop = stop
102 changes: 102 additions & 0 deletions app/llm/llm_manager.py
@@ -0,0 +1,102 @@
import os

import yaml

from common import Singleton
from llm.wrapper import AbstractLlmWrapper


# TODO: Replace with pydantic in a future PR
def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
if config["type"] == "openai":
from llm.wrapper import OpenAICompletionWrapper

return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure":
from llm.wrapper import AzureCompletionWrapper

return AzureCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_chat":
from llm.wrapper import OpenAIChatCompletionWrapper

return OpenAIChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
api_key=config["api_key"],
)
elif config["type"] == "azure_chat":
from llm.wrapper import AzureChatCompletionWrapper

return AzureChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_embedding":
from llm.wrapper import OpenAIEmbeddingWrapper

return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure_embedding":
from llm.wrapper import AzureEmbeddingWrapper

return AzureEmbeddingWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "ollama":
from llm.wrapper import OllamaWrapper

return OllamaWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
host=config["host"],
)
else:
raise Exception(f"Unknown LLM type: {config['type']}")


class LlmManager(metaclass=Singleton):
entries: list[AbstractLlmWrapper]

def __init__(self):
self.entries = []
self.load_llms()

def get_llm_by_id(self, llm_id):
for llm in self.entries:
if llm.id == llm_id:
return llm

def load_llms(self):
path = os.environ.get("LLM_CONFIG_PATH")
if not path:
raise Exception("LLM_CONFIG_PATH not set")

with open(path, "r") as file:
loaded_llms = yaml.safe_load(file)

self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
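
load_llms expects the YAML file at LLM_CONFIG_PATH to contain a list of entries whose keys match what create_llm_wrapper reads for the given type. A hypothetical entry of type azure_chat (every value below is a placeholder):

- id: "azure-gpt35-chat"
  name: "GPT 3.5 Chat"
  description: "Azure-hosted chat model"
  type: "azure_chat"
  model: "gpt-35-turbo"
  endpoint: "https://example.openai.azure.com/"
  azure_deployment: "my-deployment"
  api_version: "2023-05-15"
  api_key: "<secret>"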
36 changes: 36 additions & 0 deletions app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments


class RequestHandlerInterface(metaclass=ABCMeta):
"""Interface for the request handlers"""

@classmethod
def __subclasshook__(cls, subclass):
return (
hasattr(subclass, "completion")
and callable(subclass.completion)
and hasattr(subclass, "chat_completion")
and callable(subclass.chat_completion)
and hasattr(subclass, "create_embedding")
and callable(subclass.create_embedding)
)

@abstractmethod
def completion(self, prompt: str, arguments: CompletionArguments) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError

@abstractmethod
def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError

@abstractmethod
def create_embedding(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError
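
The __subclasshook__ makes issubclass structural: any class exposing the three callables counts as a RequestHandlerInterface even without inheriting from it. A small sketch (DuckHandler is hypothetical):

class DuckHandler:
    def completion(self, prompt, arguments): ...
    def chat_completion(self, messages, arguments): ...
    def create_embedding(self, text): ...

assert issubclass(DuckHandler, RequestHandlerInterface)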
5 changes: 5 additions & 0 deletions app/llm/wrapper/__init__.py
@@ -0,0 +1,5 @@
from llm.wrapper.abstract_llm_wrapper import *
from llm.wrapper.open_ai_completion_wrapper import *
from llm.wrapper.open_ai_chat_wrapper import *
from llm.wrapper.open_ai_embedding_wrapper import *
from llm.wrapper.ollama_wrapper import OllamaWrapper
62 changes: 62 additions & 0 deletions app/llm/wrapper/abstract_llm_wrapper.py
@@ -0,0 +1,62 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm import CompletionArguments


class AbstractLlmWrapper(metaclass=ABCMeta):
"""Abstract class for the llm wrappers"""

id: str
name: str
description: str

def __init__(self, id: str, name: str, description: str):
self.id = id
self.name = name
self.description = description


class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "completion") and callable(subclass.completion)

@abstractmethod
def completion(self, prompt: str, arguments: CompletionArguments) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError


class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm chat completion wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "chat_completion") and callable(
subclass.chat_completion
)

@abstractmethod
def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
"""Create a completion from the chat messages"""
raise NotImplementedError


class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm embedding wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "create_embedding") and callable(
subclass.create_embedding
)

@abstractmethod
def create_embedding(self, text: str) -> list[float]:
"""Create an embedding from the text"""
raise NotImplementedError
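
For illustration, a minimal concrete completion wrapper (entirely hypothetical, not part of this commit); it inherits id, name, and description handling from AbstractLlmWrapper:

class EchoCompletionWrapper(AbstractLlmCompletionWrapper):
    """Toy wrapper that echoes the prompt back unchanged."""

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        return prompt

echo = EchoCompletionWrapper(id="echo", name="Echo", description="Echoes prompts")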
(The remaining 4 file diffs did not load; judging by the imports in app/llm/wrapper/__init__.py, they are the wrapper implementation modules.)
