-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add IrisMessageRole and improve OpenAI wrappers
- Loading branch information
Showing 11 changed files with 148 additions and 37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
name: Pull Request Labeler | ||
on: [pull_request_target] | ||
on: pull_request_target | ||
|
||
jobs: | ||
label: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from domain.message import IrisMessage | ||
from domain.message import IrisMessage, IrisMessageRole |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,21 @@ | ||
from enum import Enum | ||
|
||
|
||
class IrisMessageRole(Enum):
    """Role of the author of a message exchanged with Iris.

    The string values mirror the role names used by chat-completion APIs.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
||
|
||
class IrisMessage:
    """A single chat message exchanged with Iris.

    Attributes:
        role: Author of the message (user / assistant / system).
        message_text: Plain-text content of the message.
    """

    def __init__(self, role: IrisMessageRole, message_text: str):
        self.role = role
        self.message_text = message_text

    def __str__(self):
        # NOTE: assumes role is an IrisMessageRole (reads .value).
        return (
            f"IrisMessage(role={self.role.value}, message_text='{self.message_text}')"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
class CompletionArguments:
    """Arguments for the completion request.

    Every field is optional; ``None`` means "use the provider's default".
    """

    def __init__(
        self,
        max_tokens: "int | None" = None,
        temperature: "float | None" = None,
        stop: "list[str] | None" = None,
    ):
        # Maximum number of tokens to generate.
        self.max_tokens = max_tokens
        # Sampling temperature.
        self.temperature = temperature
        # Sequences at which the provider stops generating.
        self.stop = stop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,97 @@ | ||
import os | ||
|
||
import yaml | ||
|
||
from common import Singleton | ||
from llm.wrapper import LlmWrapperInterface | ||
|
||
|
||
def _azure_kwargs(config: dict) -> dict:
    """Collect the keyword arguments shared by every Azure wrapper type."""
    return {
        "model": config["model"],
        "endpoint": config["endpoint"],
        "azure_deployment": config["azure_deployment"],
        "api_version": config["api_version"],
        "api_key": config["api_key"],
    }


def create_llm_wrapper(config: dict) -> LlmWrapperInterface:
    """Instantiate the concrete LLM wrapper described by ``config``.

    ``config["type"]`` selects the backend; the remaining keys supply the
    backend-specific constructor arguments.

    Raises:
        Exception: if ``config["type"]`` is not a known backend.
    """
    llm_type = config["type"]
    # Wrapper classes are imported lazily so only the configured backend's
    # dependencies need to be importable.
    if llm_type == "openai":
        from llm.wrapper import OpenAICompletionWrapper

        return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
    elif llm_type == "azure":
        from llm.wrapper import AzureCompletionWrapper

        return AzureCompletionWrapper(**_azure_kwargs(config))
    elif llm_type == "openai_chat":
        from llm.wrapper import OpenAIChatCompletionWrapper

        return OpenAIChatCompletionWrapper(
            model=config["model"], api_key=config["api_key"]
        )
    elif llm_type == "azure_chat":
        from llm.wrapper import AzureChatCompletionWrapper

        return AzureChatCompletionWrapper(**_azure_kwargs(config))
    elif llm_type == "openai_embedding":
        from llm.wrapper import OpenAIEmbeddingWrapper

        return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
    elif llm_type == "azure_embedding":
        from llm.wrapper import AzureEmbeddingWrapper

        return AzureEmbeddingWrapper(**_azure_kwargs(config))
    elif llm_type == "ollama":
        from llm.wrapper import OllamaWrapper

        return OllamaWrapper(
            model=config["model"],
            host=config["host"],
        )
    else:
        raise Exception(f"Unknown LLM type: {llm_type}")
|
||
|
||
class LlmManagerEntry:
    """One configured LLM: its identifier plus the wrapper that talks to it."""

    # Unique identifier of this LLM from the config file.
    id: str
    # Wrapper instance created from the config entry.
    llm: LlmWrapperInterface

    def __init__(self, config: dict):
        self.id = config["id"]
        self.llm = create_llm_wrapper(config)

    def __str__(self):
        return f"{self.id}: {self.llm}"
|
||
|
||
class LlmManager(metaclass=Singleton):
    """Singleton registry of all configured LLMs, loaded from a YAML file."""

    # All entries loaded from the config file at LLM_CONFIG_PATH.
    llms: list[LlmManagerEntry]

    def __init__(self):
        self.llms = []
        self.load_llms()

    def get_llm_by_id(self, llm_id):
        """Return the entry whose id equals ``llm_id``, or None if absent."""
        for llm in self.llms:
            if llm.id == llm_id:
                return llm
        return None  # explicit: was an implicit fall-through before

    def load_llms(self):
        """(Re)load the LLM list from the YAML file at ``$LLM_CONFIG_PATH``.

        Raises:
            Exception: if the LLM_CONFIG_PATH environment variable is unset.
        """
        path = os.environ.get("LLM_CONFIG_PATH")
        if not path:
            raise Exception("LLM_CONFIG_PATH not set")

        # safe_load: never execute arbitrary tags from the config file.
        with open(path, "r") as file:
            loaded_llms = yaml.safe_load(file)

        self.llms = [LlmManagerEntry(llm) for llm in loaded_llms]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from llm.wrapper.llm_wrapper_interface import * | ||
from llm.wrapper.open_ai_completion_wrapper import * | ||
from llm.wrapper.open_ai_chat_wrapper import * | ||
from llm.wrapper.open_ai_embedding_wrapper import * | ||
from llm.wrapper.ollama_wrapper import OllamaWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters