Add IrisMessageRole and improve OpenAI wrappers
Hialus committed Feb 10, 2024
1 parent 2f60e1a commit 4f5c8be
Showing 11 changed files with 148 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pullrequest-labeler.yml
@@ -1,5 +1,5 @@
 name: Pull Request Labeler
-on: [pull_request_target]
+on: pull_request_target
 
 jobs:
   label:
2 changes: 1 addition & 1 deletion app/domain/__init__.py
@@ -1 +1 @@
-from domain.message import IrisMessage
+from domain.message import IrisMessage, IrisMessageRole
19 changes: 18 additions & 1 deletion app/domain/message.py
@@ -1,4 +1,21 @@
+from enum import Enum
+
+
+class IrisMessageRole(Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
 class IrisMessage:
-    def __init__(self, role, message_text):
+    role: IrisMessageRole
+    message_text: str
+
+    def __init__(self, role: IrisMessageRole, message_text: str):
         self.role = role
         self.message_text = message_text
+
+    def __str__(self):
+        return (
+            f"IrisMessage(role={self.role.value}, message_text='{self.message_text}')"
+        )
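As a quick illustration (not part of the commit), the role is now a validated enum rather than a bare string, and it round-trips through the value strings the chat backends expect:

from domain import IrisMessage, IrisMessageRole

# Hypothetical usage; IrisMessageRole("moderator") would raise ValueError.
msg = IrisMessage(role=IrisMessageRole.USER, message_text="Hello!")
print(msg)  # IrisMessage(role=user, message_text='Hello!')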
4 changes: 3 additions & 1 deletion app/llm/generation_arguments.py
@@ -1,7 +1,9 @@
 class CompletionArguments:
     """Arguments for the completion request"""
 
-    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
+    def __init__(
+        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
+    ):
         self.max_tokens = max_tokens
         self.temperature = temperature
         self.stop = stop
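With every field defaulting to None, callers can now set only the options they care about. A hypothetical call, not from this commit:

from llm import CompletionArguments

# Unspecified options stay None and are passed through to the client untouched.
args = CompletionArguments(max_tokens=100)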
81 changes: 76 additions & 5 deletions app/llm/llm_manager.py
@@ -1,26 +1,97 @@
+import os
+
+import yaml
+
 from common import Singleton
 from llm.wrapper import LlmWrapperInterface
 
 
+def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
+    if config["type"] == "openai":
+        from llm.wrapper import OpenAICompletionWrapper
+
+        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure":
+        from llm.wrapper import AzureCompletionWrapper
+
+        return AzureCompletionWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_chat":
+        from llm.wrapper import OpenAIChatCompletionWrapper
+
+        return OpenAIChatCompletionWrapper(
+            model=config["model"], api_key=config["api_key"]
+        )
+    elif config["type"] == "azure_chat":
+        from llm.wrapper import AzureChatCompletionWrapper
+
+        return AzureChatCompletionWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_embedding":
+        from llm.wrapper import OpenAIEmbeddingWrapper
+
+        return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure_embedding":
+        from llm.wrapper import AzureEmbeddingWrapper
+
+        return AzureEmbeddingWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "ollama":
+        from llm.wrapper import OllamaWrapper
+
+        return OllamaWrapper(
+            model=config["model"],
+            host=config["host"],
+        )
+    else:
+        raise Exception(f"Unknown LLM type: {config['type']}")
+
+
 class LlmManagerEntry:
     id: str
     llm: LlmWrapperInterface
 
-    def __init__(self, id: str, llm: LlmWrapperInterface):
-        self.id = id
-        self.llm = llm
+    def __init__(self, config: dict):
+        self.id = config["id"]
+        self.llm = create_llm_wrapper(config)
 
     def __str__(self):
         return f"{self.id}: {self.llm}"
 
 
 class LlmManager(metaclass=Singleton):
     llms: list[LlmManagerEntry]
 
+    def __init__(self):
+        self.llms = []
+        self.load_llms()
+
     def get_llm_by_id(self, llm_id):
         for llm in self.llms:
             if llm.id == llm_id:
                 return llm
 
-    def add_llm(self, id: str, llm: LlmWrapperInterface):
-        self.llms.append(LlmManagerEntry(id, llm))
+    def load_llms(self):
+        path = os.environ.get("LLM_CONFIG_PATH")
+        if not path:
+            raise Exception("LLM_CONFIG_PATH not set")
+
+        with open(path, "r") as file:
+            loaded_llms = yaml.safe_load(file)
+
+        self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
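For context, a minimal sketch of what load_llms() would consume: a YAML list whose entries carry the keys create_llm_wrapper() reads. Everything below is made up (ids, file name, key, host); the commit itself ships no example config:

import os

# Hypothetical config; entry keys mirror those read by create_llm_wrapper().
sample_config = """\
- id: gpt35-chat
  type: openai_chat
  model: gpt-3.5-turbo
  api_key: sk-placeholder
- id: local-llama
  type: ollama
  model: llama2
  host: http://localhost:11434
"""

with open("llm_config.yml", "w") as file:
    file.write(sample_config)

os.environ["LLM_CONFIG_PATH"] = "llm_config.yml"

from llm.llm_manager import LlmManager

manager = LlmManager()  # Singleton; loads the YAML on first construction
entry = manager.get_llm_by_id("local-llama")  # -> LlmManagerEntry, or None if absent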
2 changes: 2 additions & 0 deletions app/llm/wrapper/__init__.py
@@ -1,3 +1,5 @@
 from llm.wrapper.llm_wrapper_interface import *
 from llm.wrapper.open_ai_completion_wrapper import *
+from llm.wrapper.open_ai_chat_wrapper import *
+from llm.wrapper.open_ai_embedding_wrapper import *
 from llm.wrapper.ollama_wrapper import OllamaWrapper
3 changes: 2 additions & 1 deletion app/llm/wrapper/llm_wrapper_interface.py
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
 
+from domain import IrisMessage
 from llm import CompletionArguments
 
 type LlmWrapperInterface = (
@@ -34,7 +35,7 @@ def __subclasshook__(cls, subclass):
     @abstractmethod
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    ) -> IrisMessage:
         """Create a completion from the chat messages"""
         raise NotImplementedError
9 changes: 6 additions & 3 deletions app/llm/wrapper/ollama_wrapper.py
@@ -1,6 +1,6 @@
 from ollama import Client, Message
 
-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import (
     LlmChatCompletionWrapperInterface,
@@ -11,12 +11,15 @@
 
 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
     return [
-        Message(role=message.role, content=message.message_text) for message in messages
+        Message(role=message.role.value, content=message.message_text)
+        for message in messages
     ]
 
 
 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(role=message["role"], message_text=message["content"])
+    return IrisMessage(
+        role=IrisMessageRole(message["role"]), message_text=message["content"]
+    )
 
 
 class OllamaWrapper(
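A small illustration of what the round-trip through the enum buys on the inbound side. The payload is made up, and it assumes the ollama Message behaves like a plain mapping, as the subscript access in the diff suggests:

from domain import IrisMessageRole
from llm.wrapper.ollama_wrapper import convert_to_iris_message

# Inbound roles are validated instead of passed through as bare strings.
reply = {"role": "assistant", "content": "Hello!"}
iris_message = convert_to_iris_message(reply)
assert iris_message.role is IrisMessageRole.ASSISTANT
# A role outside user/assistant/system would raise ValueError here.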
31 changes: 19 additions & 12 deletions app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,8 +1,7 @@
-from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import LlmChatCompletionWrapperInterface
 
@@ -11,42 +10,50 @@ def convert_to_open_ai_messages(
     messages: list[IrisMessage],
 ) -> list[ChatCompletionMessageParam]:
     return [
-        ChatCompletionMessageParam(role=message.role, content=message.message_text)
+        {"role": message.role.value, "content": message.message_text}
         for message in messages
     ]
 
 
 def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
-    return IrisMessage(role=message.role, message_text=message.content)
+    # Get IrisMessageRole from the string message.role
+    message_role = IrisMessageRole(message.role)
+    return IrisMessage(role=message_role, message_text=message.content)
 
 
-class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
         self.model = model
 
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    ) -> IrisMessage:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
             temperature=arguments.temperature,
             max_tokens=arguments.max_tokens,
             stop=arguments.stop,
         )
-        return response
+        return convert_to_iris_message(response.choices[0].message)
 
+
+class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        from openai import OpenAI
+
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAIChat('{self.model}')"
 
 
-class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper):
+class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
 
     def __init__(
         self,
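End to end, the wrapper now speaks IrisMessage on both sides. A hypothetical call, with the model name and key as placeholders:

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper import OpenAIChatCompletionWrapper

llm = OpenAIChatCompletionWrapper(model="gpt-3.5-turbo", api_key="sk-placeholder")
answer = llm.chat_completion(
    [IrisMessage(role=IrisMessageRole.USER, message_text="Say hi")],
    CompletionArguments(max_tokens=16),
)
print(answer.message_text)  # a plain IrisMessage; no OpenAI types leak to callers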
16 changes: 10 additions & 6 deletions app/llm/wrapper/open_ai_completion_wrapper.py
@@ -5,11 +5,7 @@
 from llm.wrapper import LlmCompletionWrapperInterface
 
 
-class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
@@ -25,11 +21,19 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
         )
         return response
 
+
+class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAICompletion('{self.model}')"
 
 
-class AzureCompletionWrapper(OpenAICompletionWrapper):
+class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
 
     def __init__(
         self,
16 changes: 10 additions & 6 deletions app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -6,11 +6,7 @@
 )
 
 
-class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
@@ -24,11 +20,19 @@ def create_embedding(self, text: str) -> list[float]:
         )
         return response.data[0].embedding
 
+
+class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"
 
 
-class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
 
     def __init__(
         self,
