Commit 14ae279: Small refinements
Hialus committed Feb 10, 2024 (1 parent: 898d1fc)
Showing 11 changed files with 163 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .flake8
@@ -7,4 +7,7 @@ exclude =
 per-file-ignores =
     # imported but unused
     __init__.py: F401, F403
+    open_ai_chat_wrapper.py: F811
+    open_ai_completion_wrapper.py: F811
+    open_ai_embedding_wrapper.py: F811
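Note: F811 is pyflakes' "redefinition of unused name" warning. The new suppressions line up with the wrapper modules below, each of which defines __init__ twice on the same class. A minimal sketch of what that means at runtime (illustrative class, not from the repo):

    # The second __init__ replaces the first -- Python has no overloading --
    # so only the (client, model) signature survives on the class.
    class Example:
        def __init__(self, model: str, api_key: str):  # shadowed definition
            ...

        def __init__(self, client, model: str):  # this is what F811 flags
            self.client = client
            self.model = model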

2 changes: 1 addition & 1 deletion app/common/__init__.py
@@ -1 +1 @@
-from singleton import Singleton
+from common.singleton import Singleton
2 changes: 1 addition & 1 deletion app/domain/__init__.py
@@ -1 +1 @@
-from message import IrisMessage
+from domain.message import IrisMessage
8 changes: 4 additions & 4 deletions app/llm/__init__.py
@@ -1,4 +1,4 @@
-from generation_arguments import CompletionArguments
-from request_handler_interface import RequestHandlerInterface
-from basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
-from llm_manager import LlmManager
+from llm.generation_arguments import CompletionArguments
+from llm.request_handler_interface import RequestHandlerInterface
+from llm.llm_manager import LlmManager
+from llm.basic_request_handler import BasicRequestHandler, BasicRequestHandlerModel
12 changes: 9 additions & 3 deletions app/llm/basic_request_handler.py
@@ -23,7 +23,9 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> str:
         if isinstance(llm, LlmCompletionWrapperInterface):
             return llm.completion(prompt, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support completion"
+            )

     def chat_completion(
         self, messages: list[IrisMessage], arguments: CompletionArguments
@@ -32,11 +34,15 @@ def chat_completion(
         if isinstance(llm, LlmChatCompletionWrapperInterface):
             return llm.chat_completion(messages, arguments)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support chat completion"
+            )

     def create_embedding(self, text: str) -> list[float]:
         llm = self.llm_manager.get_llm_by_id(self.model).llm
         if isinstance(llm, LlmEmbeddingWrapperInterface):
             return llm.create_embedding(text)
         else:
-            raise NotImplementedError
+            raise NotImplementedError(
+                f"The LLM {llm.__str__()} does not support embedding"
+            )
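A hedged sketch of the dispatch behavior above; the handler's constructor signature and the model id are assumptions, not taken from the repo:

    from llm import BasicRequestHandler

    handler = BasicRequestHandler(model="ollama-llama2")  # hypothetical model id
    try:
        vector = handler.create_embedding("Hello, Iris!")
    except NotImplementedError as err:
        # Raised when the resolved wrapper does not implement
        # LlmEmbeddingWrapperInterface; the message now names the wrapper.
        print(err)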
6 changes: 3 additions & 3 deletions app/llm/wrapper/__init__.py
@@ -1,3 +1,3 @@
-from llm_wrapper_interface import *
-from open_ai_chat_wrapper import *
-from ollama_wrapper import OllamaWrapper
+from llm.wrapper.llm_wrapper_interface import *
+from llm.wrapper.open_ai_chat_wrapper import *
+from llm.wrapper.ollama_wrapper import OllamaWrapper
32 changes: 0 additions & 32 deletions app/llm/wrapper/azure_chat_wrapper.py

This file was deleted.
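(Its Azure chat wrapper appears to have been folded into open_ai_chat_wrapper.py as AzureChatCompletionWrapper below -- an inference from this commit's additions, not stated in the commit itself.)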

7 changes: 5 additions & 2 deletions app/llm/wrapper/ollama_wrapper.py
@@ -16,7 +16,7 @@ def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]:


 def convert_to_iris_message(message: Message) -> IrisMessage:
-    return IrisMessage(role=message.role, message_text=message.content)
+    return IrisMessage(role=message["role"], message_text=message["content"])


 class OllamaWrapper(
@@ -43,4 +43,7 @@ def chat_completion(

     def create_embedding(self, text: str) -> list[float]:
         response = self.client.embeddings(model=self.model, prompt=text)
-        return response
+        return list(response)
+
+    def __str__(self):
+        return f"Ollama('{self.model}')"
34 changes: 34 additions & 0 deletions app/llm/wrapper/open_ai_chat_wrapper.py
@@ -1,4 +1,5 @@
 from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam

 from domain import IrisMessage
@@ -25,11 +26,44 @@ def __init__(self, model: str, api_key: str):
         self.client = OpenAI(api_key=api_key)
         self.model = model

+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
     def chat_completion(
         self, messages: list[any], arguments: CompletionArguments
     ) -> any:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
         )
         return response
+
+    def __str__(self):
+        return f"OpenAIChat('{self.model}')"
+
+
+class AzureChatCompletionWrapper(OpenAIChatCompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureChat('{self.model}')"
52 changes: 52 additions & 0 deletions app/llm/wrapper/open_ai_completion_wrapper.py
@@ -0,0 +1,52 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm import CompletionArguments
+from llm.wrapper import LlmCompletionWrapperInterface
+
+
+class OpenAICompletionWrapper(LlmCompletionWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def completion(self, prompt: str, arguments: CompletionArguments) -> any:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            temperature=arguments.temperature,
+            max_tokens=arguments.max_tokens,
+            stop=arguments.stop,
+        )
+        return response
+
+    def __str__(self):
+        return f"OpenAICompletion('{self.model}')"
+
+
+class AzureCompletionWrapper(OpenAICompletionWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureCompletion('{self.model}')"
51 changes: 51 additions & 0 deletions app/llm/wrapper/open_ai_embedding_wrapper.py
@@ -0,0 +1,51 @@
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from llm.wrapper import (
+    LlmEmbeddingWrapperInterface,
+)
+
+
+class OpenAIEmbeddingWrapper(LlmEmbeddingWrapperInterface):
+
+    def __init__(self, model: str, api_key: str):
+        self.client = OpenAI(api_key=api_key)
+        self.model = model
+
+    def __init__(self, client, model: str):
+        self.client = client
+        self.model = model
+
+    def create_embedding(self, text: str) -> list[float]:
+        response = self.client.embeddings.create(
+            model=self.model,
+            input=text,
+            encoding_format="float",
+        )
+        return response.data[0].embedding
+
+    def __str__(self):
+        return f"OpenAIEmbedding('{self.model}')"
+
+
+class AzureEmbeddingWrapper(OpenAIEmbeddingWrapper):
+
+    def __init__(
+        self,
+        model: str,
+        endpoint: str,
+        azure_deployment: str,
+        api_version: str,
+        api_key: str,
+    ):
+        client = AzureOpenAI(
+            azure_endpoint=endpoint,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            api_key=api_key,
+        )
+        model = model
+        super().__init__(client, model)
+
+    def __str__(self):
+        return f"AzureEmbedding('{self.model}')"
