Skip to content

Commit

Permalink
LLM: Use pydantic for config parsing (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus authored Feb 12, 2024
1 parent 94f786a commit fbb403a
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 180 deletions.
2 changes: 1 addition & 1 deletion app/llm/basic_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from domain import IrisMessage
from llm import RequestHandlerInterface, CompletionArguments
from llm.llm_manager import LlmManager
from llm.wrapper import (
from llm.wrapper.abstract_llm_wrapper import (
AbstractLlmCompletionWrapper,
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
Expand Down
80 changes: 7 additions & 73 deletions app/llm/llm_manager.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,16 @@
import os

from pydantic import BaseModel, Field

import yaml

from common import Singleton
from llm.wrapper import AbstractLlmWrapper


# TODO: Replace with pydantic in a future PR
def create_llm_wrapper(config: dict) -> AbstractLlmWrapper:
if config["type"] == "openai":
from llm.wrapper import OpenAICompletionWrapper

return OpenAICompletionWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure":
from llm.wrapper import AzureCompletionWrapper

return AzureCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_chat":
from llm.wrapper import OpenAIChatCompletionWrapper

return OpenAIChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
api_key=config["api_key"],
)
elif config["type"] == "azure_chat":
from llm.wrapper import AzureChatCompletionWrapper

return AzureChatCompletionWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "openai_embedding":
from llm.wrapper import OpenAIEmbeddingWrapper

return OpenAIEmbeddingWrapper(model=config["model"], api_key=config["api_key"])
elif config["type"] == "azure_embedding":
from llm.wrapper import AzureEmbeddingWrapper
from llm.wrapper import AbstractLlmWrapper, LlmWrapper

return AzureEmbeddingWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
endpoint=config["endpoint"],
azure_deployment=config["azure_deployment"],
api_version=config["api_version"],
api_key=config["api_key"],
)
elif config["type"] == "ollama":
from llm.wrapper import OllamaWrapper

return OllamaWrapper(
id=config["id"],
name=config["name"],
description=config["description"],
model=config["model"],
host=config["host"],
)
else:
raise Exception(f"Unknown LLM type: {config['type']}")
# Small workaround to get pydantic discriminators working
class LlmList(BaseModel):
llms: list[LlmWrapper] = Field(discriminator="type")


class LlmManager(metaclass=Singleton):
Expand All @@ -99,4 +33,4 @@ def load_llms(self):
with open(path, "r") as file:
loaded_llms = yaml.safe_load(file)

self.entries = [create_llm_wrapper(llm) for llm in loaded_llms]
self.entries = LlmList.parse_obj({"llms": loaded_llms}).llms
27 changes: 23 additions & 4 deletions app/llm/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
from llm.wrapper.abstract_llm_wrapper import *
from llm.wrapper.open_ai_completion_wrapper import *
from llm.wrapper.open_ai_chat_wrapper import *
from llm.wrapper.open_ai_embedding_wrapper import *
from llm.wrapper.abstract_llm_wrapper import AbstractLlmWrapper
from llm.wrapper.open_ai_completion_wrapper import (
OpenAICompletionWrapper,
AzureCompletionWrapper,
)
from llm.wrapper.open_ai_chat_wrapper import (
OpenAIChatCompletionWrapper,
AzureChatCompletionWrapper,
)
from llm.wrapper.open_ai_embedding_wrapper import (
OpenAIEmbeddingWrapper,
AzureEmbeddingWrapper,
)
from llm.wrapper.ollama_wrapper import OllamaWrapper

type LlmWrapper = (
OpenAICompletionWrapper
| AzureCompletionWrapper
| OpenAIChatCompletionWrapper
| AzureChatCompletionWrapper
| OpenAIEmbeddingWrapper
| AzureEmbeddingWrapper
| OllamaWrapper
)
8 changes: 2 additions & 6 deletions app/llm/wrapper/abstract_llm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from domain import IrisMessage
from llm import CompletionArguments


class AbstractLlmWrapper(metaclass=ABCMeta):
class AbstractLlmWrapper(BaseModel, metaclass=ABCMeta):
"""Abstract class for the llm wrappers"""

id: str
name: str
description: str

def __init__(self, id: str, name: str, description: str):
self.id = id
self.name = name
self.description = description


class AbstractLlmCompletionWrapper(AbstractLlmWrapper, metaclass=ABCMeta):
"""Abstract class for the llm completion wrappers"""
Expand Down
20 changes: 12 additions & 8 deletions app/llm/wrapper/ollama_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Literal, Any

from ollama import Client, Message

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper import (
from llm.wrapper.abstract_llm_wrapper import (
AbstractLlmChatCompletionWrapper,
AbstractLlmCompletionWrapper,
AbstractLlmEmbeddingWrapper,
Expand All @@ -24,26 +26,28 @@ class OllamaWrapper(
AbstractLlmChatCompletionWrapper,
AbstractLlmEmbeddingWrapper,
):
type: Literal["ollama"]
model: str
host: str
_client: Client

def __init__(self, model: str, host: str, **kwargs):
super().__init__(**kwargs)
self.client = Client(host=host) # TODO: Add authentication (httpx auth?)
self.model = model
def model_post_init(self, __context: Any) -> None:
self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?)

def completion(self, prompt: str, arguments: CompletionArguments) -> str:
response = self.client.generate(model=self.model, prompt=prompt)
response = self._client.generate(model=self.model, prompt=prompt)
return response["response"]

def chat_completion(
self, messages: list[any], arguments: CompletionArguments
) -> any:
response = self.client.chat(
response = self._client.chat(
model=self.model, messages=convert_to_ollama_messages(messages)
)
return convert_to_iris_message(response["message"])

def create_embedding(self, text: str) -> list[float]:
response = self.client.embeddings(model=self.model, prompt=text)
response = self._client.embeddings(model=self.model, prompt=text)
return list(response)

def __str__(self):
Expand Down
51 changes: 21 additions & 30 deletions app/llm/wrapper/open_ai_chat_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Literal, Any
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage

from domain import IrisMessage, IrisMessageRole
from llm import CompletionArguments
from llm.wrapper import AbstractLlmChatCompletionWrapper
from llm.wrapper.abstract_llm_wrapper import AbstractLlmChatCompletionWrapper


def convert_to_open_ai_messages(
Expand All @@ -21,16 +23,14 @@ def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage:


class BaseOpenAIChatCompletionWrapper(AbstractLlmChatCompletionWrapper):

def __init__(self, client, model: str, **kwargs):
super().__init__(**kwargs)
self.client = client
self.model = model
model: str
api_key: str
_client: OpenAI

def chat_completion(
self, messages: list[any], arguments: CompletionArguments
) -> IrisMessage:
response = self.client.chat.completions.create(
response = self._client.chat.completions.create(
model=self.model,
messages=convert_to_open_ai_messages(messages),
temperature=arguments.temperature,
Expand All @@ -41,37 +41,28 @@ def chat_completion(


class OpenAIChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):
type: Literal["openai_chat"]

def __init__(self, model: str, api_key: str, **kwargs):
from openai import OpenAI

client = OpenAI(api_key=api_key)
model = model
super().__init__(client, model, **kwargs)
def model_post_init(self, __context: Any) -> None:
self._client = OpenAI(api_key=self.api_key)

def __str__(self):
return f"OpenAIChat('{self.model}')"


class AzureChatCompletionWrapper(BaseOpenAIChatCompletionWrapper):

def __init__(
self,
model: str,
endpoint: str,
azure_deployment: str,
api_version: str,
api_key: str,
**kwargs,
):
client = AzureOpenAI(
azure_endpoint=endpoint,
azure_deployment=azure_deployment,
api_version=api_version,
api_key=api_key,
type: Literal["azure_chat"]
endpoint: str
azure_deployment: str
api_version: str

def model_post_init(self, __context: Any) -> None:
self._client = AzureOpenAI(
azure_endpoint=self.endpoint,
azure_deployment=self.azure_deployment,
api_version=self.api_version,
api_key=self.api_key,
)
model = model
super().__init__(client, model, **kwargs)

def __str__(self):
return f"AzureChat('{self.model}')"
48 changes: 20 additions & 28 deletions app/llm/wrapper/open_ai_completion_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from typing import Literal, Any
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from llm import CompletionArguments
from llm.wrapper import AbstractLlmCompletionWrapper
from llm.wrapper.abstract_llm_wrapper import AbstractLlmCompletionWrapper


class BaseOpenAICompletionWrapper(AbstractLlmCompletionWrapper):

def __init__(self, client, model: str, **kwargs):
super().__init__(**kwargs)
self.client = client
self.model = model
model: str
api_key: str
_client: OpenAI

def completion(self, prompt: str, arguments: CompletionArguments) -> any:
response = self.client.completions.create(
response = self._client.completions.create(
model=self.model,
prompt=prompt,
temperature=arguments.temperature,
Expand All @@ -24,35 +23,28 @@ def completion(self, prompt: str, arguments: CompletionArguments) -> any:


class OpenAICompletionWrapper(BaseOpenAICompletionWrapper):
type: Literal["openai_completion"]

def __init__(self, model: str, api_key: str, **kwargs):
client = OpenAI(api_key=api_key)
model = model
super().__init__(client, model, **kwargs)
def model_post_init(self, __context: Any) -> None:
self._client = OpenAI(api_key=self.api_key)

def __str__(self):
return f"OpenAICompletion('{self.model}')"


class AzureCompletionWrapper(BaseOpenAICompletionWrapper):

def __init__(
self,
model: str,
endpoint: str,
azure_deployment: str,
api_version: str,
api_key: str,
**kwargs,
):
client = AzureOpenAI(
azure_endpoint=endpoint,
azure_deployment=azure_deployment,
api_version=api_version,
api_key=api_key,
type: Literal["azure_completion"]
endpoint: str
azure_deployment: str
api_version: str

def model_post_init(self, __context: Any) -> None:
self._client = AzureOpenAI(
azure_endpoint=self.endpoint,
azure_deployment=self.azure_deployment,
api_version=self.api_version,
api_key=self.api_key,
)
model = model
super().__init__(client, model, **kwargs)

def __str__(self):
return f"AzureCompletion('{self.model}')"
Loading

0 comments on commit fbb403a

Please sign in to comment.