LLM: Use pydantic for config parsing #52

Merged · 11 commits · Feb 12, 2024
Changes from 10 commits
13 changes: 13 additions & 0 deletions .flake8
@@ -0,0 +1,13 @@
[flake8]
max-line-length = 120
exclude =
    .git,
    __pycache__,
    .idea
per-file-ignores =
    # imported but unused
    __init__.py: F401, F403
    open_ai_chat_wrapper.py: F811
    open_ai_completion_wrapper.py: F811
    open_ai_embedding_wrapper.py: F811

3 changes: 3 additions & 0 deletions .github/labeler.yml
@@ -0,0 +1,3 @@
"component:LLM":
- changed-files:
- any-glob-to-any-file: app/llm/**
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -17,7 +17,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          python-version: "3.12"
          cache: 'pip'

      - name: Install Dependencies from requirements.txt
10 changes: 10 additions & 0 deletions .github/workflows/pullrequest-labeler.yml
@@ -0,0 +1,10 @@
name: Pull Request Labeler
on: pull_request_target

jobs:
  label:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/labeler@v5
        with:
          repo-token: "${{ secrets.GITHUB_TOKEN }}"
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -5,8 +5,9 @@
    rev: stable
    hooks:
      - id: black
        language_version: python3.11
        language_version: python3.12
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.0.0
    hooks:
      - id: flake8
      - id: flake8
        language_version: python3.12
1 change: 1 addition & 0 deletions app/common/__init__.py
@@ -0,0 +1 @@
from common.singleton import Singleton
7 changes: 7 additions & 0 deletions app/common/singleton.py
@@ -0,0 +1,7 @@
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
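
For illustration only (not part of this diff): any class that declares Singleton as its metaclass shares a single instance, which is how LlmManager further down avoids re-loading the config on every instantiation. A minimal sketch with a hypothetical class:

from common import Singleton


class Config(metaclass=Singleton):  # hypothetical example class
    def __init__(self):
        self.values = {}


a = Config()
b = Config()
assert a is b  # every call returns the same instance
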
1 change: 1 addition & 0 deletions app/domain/__init__.py
@@ -0,0 +1 @@
from domain.message import IrisMessage, IrisMessageRole
19 changes: 19 additions & 0 deletions app/domain/message.py
@@ -0,0 +1,19 @@
from enum import Enum


class IrisMessageRole(Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class IrisMessage:
    role: IrisMessageRole
    text: str

    def __init__(self, role: IrisMessageRole, text: str):
        self.role = role
        self.text = text

    def __str__(self):
        return f"IrisMessage(role={self.role.value}, text='{self.text}')"
3 changes: 3 additions & 0 deletions app/llm/__init__.py
@@ -0,0 +1,3 @@
from llm.request_handler_interface import RequestHandlerInterface
from llm.generation_arguments import *
from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
48 changes: 48 additions & 0 deletions app/llm/basic_request_handler.py
@@ -0,0 +1,48 @@
from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments
from llm.llm_manager import LlmManager
from llm.wrapper.abstract_llm_wrapper import (
    AbstractLlmCompletionWrapper,
    AbstractLlmChatCompletionWrapper,
    AbstractLlmEmbeddingWrapper,
)

type BasicRequestHandlerModel = str


class BasicRequestHandler(RequestHandlerInterface):
    model: BasicRequestHandlerModel
    llm_manager: LlmManager

    def __init__(self, model: BasicRequestHandlerModel):
        self.model = model
        self.llm_manager = LlmManager()

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, AbstractLlmCompletionWrapper):
            return llm.completion(prompt, arguments)
        else:
            raise NotImplementedError(
                f"The LLM {llm.__str__()} does not support completion"
            )

    def chat_completion(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, AbstractLlmChatCompletionWrapper):
            return llm.chat_completion(messages, arguments)
        else:
            raise NotImplementedError(
                f"The LLM {llm.__str__()} does not support chat completion"
            )

    def create_embedding(self, text: str) -> list[float]:
        llm = self.llm_manager.get_llm_by_id(self.model).llm
        if isinstance(llm, AbstractLlmEmbeddingWrapper):
            return llm.create_embedding(text)
        else:
            raise NotImplementedError(
                f"The LLM {llm.__str__()} does not support embedding"
            )
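
A hypothetical usage sketch (the model id, prompt, and argument values are made up; it assumes LLM_CONFIG_PATH points at a valid config and that the configured entry supports completion in the shown revision of the code):

from llm import BasicRequestHandler, CompletionArguments

# "gpt35-completion" stands for whatever `id` is configured for the target LLM
handler = BasicRequestHandler(model="gpt35-completion")
response = handler.completion(
    prompt="Say hello in one short sentence.",
    arguments=CompletionArguments(max_tokens=50, temperature=0.2),
)
print(response)
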
9 changes: 9 additions & 0 deletions app/llm/generation_arguments.py
@@ -0,0 +1,9 @@
class CompletionArguments:
    """Arguments for the completion request"""

    def __init__(
        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
    ):
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.stop = stop
36 changes: 36 additions & 0 deletions app/llm/llm_manager.py
@@ -0,0 +1,36 @@
import os

from pydantic import BaseModel, Field

import yaml

from common import Singleton
from llm.wrapper import AbstractLlmWrapper, LlmWrapper


# Small workaround to get pydantic discriminators working
class LlmList(BaseModel):
    llms: list[LlmWrapper] = Field(discriminator="type")


class LlmManager(metaclass=Singleton):
    entries: list[AbstractLlmWrapper]

    def __init__(self):
        self.entries = []
        self.load_llms()

    def get_llm_by_id(self, llm_id):
        for llm in self.entries:
            if llm.id == llm_id:
                return llm

    def load_llms(self):
        path = os.environ.get("LLM_CONFIG_PATH")
        if not path:
            raise Exception("LLM_CONFIG_PATH not set")

        with open(path, "r") as file:
            loaded_llms = yaml.safe_load(file)

        self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms
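
For readers unfamiliar with the LlmList trick above, here is a minimal, self-contained sketch of a pydantic discriminated union (illustrative Cat/Dog models, not from this PR): each concrete model carries a Literal "type" tag, and pydantic uses that tag to decide which class to instantiate for every entry in the list.

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class Cat(BaseModel):
    type: Literal["cat"]
    name: str


class Dog(BaseModel):
    type: Literal["dog"]
    name: str


# Annotated union inside a list; LlmList above spells the same idea with
# Field(discriminator="type") directly on the list field.
class PetList(BaseModel):
    pets: list[Annotated[Union[Cat, Dog], Field(discriminator="type")]]


parsed = PetList.parse_obj({"pets": [{"type": "cat", "name": "Felix"}]})
assert isinstance(parsed.pets[0], Cat)  # the "type" tag selected the Cat model
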
36 changes: 36 additions & 0 deletions app/llm/request_handler_interface.py
@@ -0,0 +1,36 @@
from abc import ABCMeta, abstractmethod

from domain import IrisMessage
from llm.generation_arguments import CompletionArguments


class RequestHandlerInterface(metaclass=ABCMeta):
    """Interface for the request handlers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return (
            hasattr(subclass, "completion")
            and callable(subclass.completion)
            and hasattr(subclass, "chat_completion")
            and callable(subclass.chat_completion)
            and hasattr(subclass, "create_embedding")
            and callable(subclass.create_embedding)
        )

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError

    @abstractmethod
    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> [IrisMessage]:
        """Create a completion from the chat messages"""
        raise NotImplementedError

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
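
The __subclasshook__ above makes the interface structural: a class that merely provides the three methods passes issubclass checks even without inheriting from RequestHandlerInterface. A small sketch with a hypothetical class (not part of this PR):

from llm import RequestHandlerInterface


class DuckTypedHandler:  # hypothetical, only implements the three methods
    def completion(self, prompt, arguments): ...

    def chat_completion(self, messages, arguments): ...

    def create_embedding(self, text): ...


assert issubclass(DuckTypedHandler, RequestHandlerInterface)
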
24 changes: 24 additions & 0 deletions app/llm/wrapper/__init__.py
@@ -0,0 +1,24 @@
from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
from llm.wrapper.open_ai_completion_wrapper import (
    OpenAICompletionWrapper,
    AzureCompletionWrapper,
)
from llm.wrapper.open_ai_chat_wrapper import (
    OpenAIChatCompletionWrapper,
    AzureChatCompletionWrapper,
)
from llm.wrapper.open_ai_embedding_wrapper import (
    OpenAIEmbeddingWrapper,
    AzureEmbeddingWrapper,
)
from llm.wrapper.ollama_wrapper import OllamaWrapper

type LlmWrapper = (
    OpenAICompletionWrapper
    | AzureCompletionWrapper
    | OpenAIChatCompletionWrapper
    | AzureChatCompletionWrapper
    | OpenAIEmbeddingWrapper
    | AzureEmbeddingWrapper
    | OllamaWrapper
)
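
The `type LlmWrapper = (...)` statement is the type-alias syntax introduced in Python 3.12 (PEP 695), which is presumably why the lint workflow and pre-commit hooks in this PR move to python3.12. A tiny standalone illustration:

# PEP 695 type alias (Python 3.12+); older code would write
# "Number: typing.TypeAlias = int | float" instead.
type Number = int | float


def double(x: Number) -> Number:
    return x * 2


print(double(3), double(1.5))  # 6 3.0
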
58 changes: 58 additions & 0 deletions app/llm/wrapper/abstract_llm_wrapper.py
@@ -0,0 +1,58 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from domain import IrisMessage
from llm import CompletionArguments


class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta):
    """Abstract class for the llm wrappers"""

    id: str
    name: str
    description: str


class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm completion wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "completion") and callable(subclass.completion)

    @abstractmethod
    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        """Create a completion from the prompt"""
        raise NotImplementedError


class AbstractLlmChatCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm chat completion wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "chat_completion") and callable(
            subclass.chat_completion
        )

    @abstractmethod
    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> IrisMessage:
        """Create a completion from the chat messages"""
        raise NotImplementedError


class AbstractLlmEmbeddingWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
    """Abstract class for the llm embedding wrappers"""

    @classmethod
    def __subclasshook__(cls, subclass):
        return hasattr(subclass, "create_embedding") and callable(
            subclass.create_embedding
        )

    @abstractmethod
    def create_embedding(self, text: str) -> list[float]:
        """Create an embedding from the text"""
        raise NotImplementedError
54 changes: 54 additions & 0 deletions app/llm/wrapper/ollama_wrapper.py
@@ -0,0 +1,54 @@
from typing import Literal, Any

from ollama import Client, Message

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper.abstract_llm_wrapper import (
    AbstractLlmChatCompletionWrapper,
    AbstractLlmCompletionWrapper,
    AbstractLlmEmbeddingWrapper,
)


def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
    return [
        Message(role=message.role.value, content=message.text) for message in messages
    ]


def convert_to_iris_message(message: Message) -> IrisMessage:
    return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"])


class OllamaWrapper(
    AbstractLlmCompletionWrapper,
    AbstractLlmChatCompletionWrapper,
    AbstractLlmEmbeddingWrapper,
):
    type: Literal["ollama"]
    model: str
    host: str
    _client: Client

    def model_post_init(self, __context: Any) -> None:
        self._client = Client(host=self.host)  # TODO: Add authentication (httpx auth?)

    def completion(self, prompt: str, arguments: CompletionArguments) -> str:
        response = self._client.generate(model=self.model, prompt=prompt)
        return response["response"]

    def chat_completion(
        self, messages: list[any], arguments: CompletionArguments
    ) -> any:
        response = self._client.chat(
            model=self.model, messages=convert_to_ollama_messages(messages)
        )
        return convert_to_iris_message(response["message"])

    def create_embedding(self, text: str) -> list[float]:
        response = self._client.embeddings(model=self.model, prompt=text)
        return list(response)

    def __str__(self):
        return f"Ollama('{self.model}')"