diff --git a/.github/workflows/pullrequest-labeler.yml b/.github/workflows/pullrequest-labeler.yml
index 90c20ceb..f7739956 100644
--- a/.github/workflows/pullrequest-labeler.yml
+++ b/.github/workflows/pullrequest-labeler.yml
@@ -1,5 +1,5 @@
 name: Pull Request Labeler
-on: [pull_request_target]
+on: pull_request_target
 
 jobs:
   label:
diff --git a/app/domain/__init__.py b/app/domain/__init__.py
index 270c228a..b73080e7 100644
--- a/app/domain/__init__.py
+++ b/app/domain/__init__.py
@@ -1 +1 @@
-from domain.message import IrisMessage
+from domain.message import IrisMessage, IrisMessageRole
diff --git a/app/domain/message.py b/app/domain/message.py
index b5fe26f7..960750a9 100644
--- a/app/domain/message.py
+++ b/app/domain/message.py
@@ -1,4 +1,21 @@
+from enum import Enum
+
+
+class IrisMessageRole(Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
 class IrisMessage:
-    def __init__(self, role, message_text):
+    role: IrisMessageRole
+    message_text: str
+
+    def __init__(self, role: IrisMessageRole, message_text: str):
         self.role = role
         self.message_text = message_text
+
+    def __str__(self):
+        return (
+            f"IrisMessage(role={self.role.value}, message_text='{self.message_text}')"
+        )
diff --git a/app/llm/generation_arguments.py b/app/llm/generation_arguments.py
index 37a4af19..a540e144 100644
--- a/app/llm/generation_arguments.py
+++ b/app/llm/generation_arguments.py
@@ -1,7 +1,9 @@
 class CompletionArguments:
     """Arguments for the completion request"""
 
-    def __init__(self, max_tokens: int, temperature: float, stop: list[str]):
+    def __init__(
+        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
+    ):
         self.max_tokens = max_tokens
         self.temperature = temperature
         self.stop = stop
diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index 8dd649d8..bc6de680 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -1,14 +1,77 @@
+import os
+
+import yaml
+
 from common import Singleton
 from llm.wrapper import LlmWrapperInterface
 
 
+def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
+    if config["type"] == "openai":
+        from llm.wrapper import OpenAICompletionWrapper
+
+        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure":
+        from llm.wrapper import AzureCompletionWrapper
+
+        return AzureCompletionWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_chat":
+        from llm.wrapper import OpenAIChatCompletionWrapper
+
+        return OpenAIChatCompletionWrapper(
+            model=config["model"], api_key=config["api_key"]
+        )
+    elif config["type"] == "azure_chat":
+        from llm.wrapper import AzureChatCompletionWrapper
+
+        return AzureChatCompletionWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "openai_embedding":
+        from llm.wrapper import OpenAIEmbeddingWrapper
+
+        return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
+    elif config["type"] == "azure_embedding":
+        from llm.wrapper import AzureEmbeddingWrapper
+
+        return AzureEmbeddingWrapper(
+            model=config["model"],
+            endpoint=config["endpoint"],
+            azure_deployment=config["azure_deployment"],
+            api_version=config["api_version"],
+            api_key=config["api_key"],
+        )
+    elif config["type"] == "ollama":
+        from llm.wrapper import OllamaWrapper
+
+        return OllamaWrapper(
+            model=config["model"],
+            host=config["host"],
+        )
+    else:
+        raise Exception(f"Unknown LLM type: {config['type']}")
+
+
 class LlmManagerEntry:
     id: str
     llm: LlmWrapperInterface
 
-    def __init__(self, id: str, llm: LlmWrapperInterface):
-        self.id = id
-        self.llm = llm
+    def __init__(self, config: dict):
+        self.id = config["id"]
+        self.llm = create_llm_wrapper(config)
+
+    def __str__(self):
+        return f"{self.id}: {self.llm}"
 
 
 class LlmManager(metaclass=Singleton):
@@ -16,11 +79,19 @@ class LlmManager(metaclass=Singleton):
 
     def __init__(self):
         self.llms = []
+        self.load_llms()
 
     def get_llm_by_id(self, llm_id):
         for llm in self.llms:
             if llm.id == llm_id:
                 return llm
 
-    def add_llm(self, id: str, llm: LlmWrapperInterface):
-        self.llms.append(LlmManagerEntry(id, llm))
+    def load_llms(self):
+        path = os.environ.get("LLM_CONFIG_PATH")
+        if not path:
+            raise Exception("LLM_CONFIG_PATH not set")
+
+        with open(path, "r") as file:
+            loaded_llms = yaml.safe_load(file)
+
+        self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
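Note: the new load_llms() expects LLM_CONFIG_PATH to point at a YAML file containing a list of LLM entries. A minimal sketch of such a file, with keys inferred from create_llm_wrapper() above (the ids, model names, API key, and host below are placeholder values, not part of this patch):

    # llm_config.yml (hypothetical example)
    - id: oai-chat
      type: openai_chat
      model: gpt-3.5-turbo
      api_key: sk-...
    - id: local-ollama
      type: ollama
      model: llama2
      host: http://localhost:11434
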
diff --git a/app/llm/wrapper/__init__.py b/app/llm/wrapper/__init__.py
index 6ddf4569..4364afa0 100644
--- a/app/llm/wrapper/__init__.py
+++ b/app/llm/wrapper/__init__.py
@@ -1,3 +1,5 @@
 from llm.wrapper.llm_wrapper_interface import *
+from llm.wrapper.open_ai_completion_wrapper import *
 from llm.wrapper.open_ai_chat_wrapper import *
+from llm.wrapper.open_ai_embedding_wrapper import *
 from llm.wrapper.ollama_wrapper import OllamaWrapper
diff --git a/app/llm/wrapper/llm_wrapper_interface.py b/app/llm/wrapper/llm_wrapper_interface.py
index 73aade5d..b1e79acb 100644
--- a/app/llm/wrapper/llm_wrapper_interface.py
+++ b/app/llm/wrapper/llm_wrapper_interface.py
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
 
+from domain import IrisMessage
 from llm import CompletionArguments
 
 type LlmWrapperInterface = (
@@ -34,7 +35,7 @@ def __subclasshook__(cls, subclass):
 
     @abstractmethod
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    ) -> IrisMessage:
         """Create a completion from the chat messages"""
         raise NotImplementedError
diff --git a/app/llm/wrapper/ollama_wrapper.py b/app/llm/wrapper/ollama_wrapper.py
index 9dc8131a..8d0cd397 100644
--- a/app/llm/wrapper/ollama_wrapper.py
+++ b/app/llm/wrapper/ollama_wrapper.py
@@ -1,6 +1,6 @@
 from ollama import Client, Message
 
-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import (
     LlmChatCompletionWrapperInterface,
@@ -11,12 +11,15 @@
 
 def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:
     return [
-        Message(role=message.role, content=message.message_text) for message in messages
+        Message(role=message.role.value, content=message.message_text)
+        for message in messages
     ]
 
 
 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(role=message["role"], message_text=message["content"])
+    return IrisMessage(
+        role=IrisMessageRole(message["role"]), message_text=message["content"]
+    )
 
 
 class OllamaWrapper(
diff --git a/app/llm/wrapper/open_ai_chat_wrapper.py b/app/llm/wrapper/open_ai_chat_wrapper.py
index a21a9951..be1016a0 100644
--- a/app/llm/wrapper/open_ai_chat_wrapper.py
+++ b/app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,8 +1,7 @@
-from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
-from domain import IrisMessage
+from domain import IrisMessage, IrisMessageRole
 from llm import CompletionArguments
 from llm.wrapper import LlmChatCompletionWrapperInterface
 
@@ -11,20 +10,18 @@
 def convert_to_open_ai_messages(
     messages: list[IrisMessage],
 ) -> list[ChatCompletionMessageParam]:
     return [
-        ChatCompletionMessageParam(role=message.role, content=message.message_text)
+        {"role": message.role.value, "content": message.message_text}
         for message in messages
     ]
 
 
 def convert_to_iris_message(message: ChatCompletionMessageParam) -> IrisMessage:
-    return IrisMessage(role=message.role, message_text=message.content)
+    # Get IrisMessageRole from the string message.role
+    message_role = IrisMessageRole(message.role)
+    return IrisMessage(role=message_role, message_text=message.content)
 
 
-class OpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIChatCompletionWrapper(LlmChatCompletionWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
@@ -32,7 +29,7 @@ def __init__(self, client, model: str):
 
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
-    ) -> any:
+    ) -> IrisMessage:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
@@ -40,13 +37,23 @@ def chat_completion(
             max_tokens=arguments.max_tokens,
             stop=arguments.stop,
         )
-        return response
+        return convert_to_iris_message(response.choices[0].message)
+
+
+class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        from openai import OpenAI
+
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
 
     def __str__(self):
         return f"OpenAIChat('{self.model}')"
 
 
-class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper):
+class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
 
     def __init__(
         self,
diff --git a/app/llm/wrapper/open_ai_completion_wrapper.py b/app/llm/wrapper/open_ai_completion_wrapper.py
index 1d2eded3..94ba1ee2 100644
--- a/app/llm/wrapper/open_ai_completion_wrapper.py
+++ b/app/llm/wrapper/open_ai_completion_wrapper.py
@@ -5,11 +5,7 @@
 from llm.wrapper import LlmCompletionWrapperInterface
 
 
-class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAICompletionWrapper(LlmCompletionWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
@@ -25,11 +21,19 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:
         )
         return response
 
+
+class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAICompletion('{self.model}')"
 
 
-class AzureCompletionWrapper(OpenAICompletionWrapper):
+class AzureCompletionWrapper(BaseOpenAICompletionWrapper):
 
     def __init__(
         self,
diff --git a/app/llm/wrapper/open_ai_embedding_wrapper.py b/app/llm/wrapper/open_ai_embedding_wrapper.py
index 4983d3ee..726fb272 100644
--- a/app/llm/wrapper/open_ai_embedding_wrapper.py
+++ b/app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -6,11 +6,7 @@
 )
 
 
-class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
-
-    def __init__(self, model: str, api_key: str):
-        self.client = OpenAI(api_key=api_key)
-        self.model = model
+class BaseOpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
 
     def __init__(self, client, model: str):
         self.client = client
@@ -24,11 +20,19 @@ def create_embedding(self, text: str) -> list[float]:
         )
         return response.data[0].embedding
 
+
+class OpenAIEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
+
+    def __init__(self, model: str, api_key: str):
+        client = OpenAI(api_key=api_key)
+        model = model
+        super().__init__(client, model)
+
     def __str__(self):
         return f"OpenAIEmbedding('{self.model}')"
 
 
-class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+class AzureEmbeddingWrapper(BaseOpenAIEmbeddingWrapper):
 
     def __init__(
         self,
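
Note: putting the pieces together, a minimal usage sketch of the refactored API. The config path, LLM id, and argument values are illustrative assumptions, not part of this patch; import paths follow the app/ layout used in this diff.

    import os

    # load_llms() raises if LLM_CONFIG_PATH is unset, so set it before constructing the manager
    os.environ["LLM_CONFIG_PATH"] = "llm_config.yml"  # hypothetical path

    from domain import IrisMessage, IrisMessageRole
    from llm import CompletionArguments
    from llm.llm_manager import LlmManager

    # LlmManager is a Singleton; it builds a wrapper for every entry in the YAML config
    manager = LlmManager()
    entry = manager.get_llm_by_id("oai-chat")  # a LlmManagerEntry; the wrapper is entry.llm

    messages = [IrisMessage(role=IrisMessageRole.USER, message_text="Hello!")]
    args = CompletionArguments(max_tokens=100, temperature=0.7)  # all arguments now optional

    # chat_completion() now returns an IrisMessage instead of the raw provider response
    answer = entry.llm.chat_completion(messages, args)
    print(answer)  # IrisMessage(role=assistant, message_text='...')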