-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
311 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from llm.completion_arguments import * | ||
from llm.external import * | ||
from llm.request_handler import * | ||
from llm.capability import RequirementList |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from llm.capability.capability_list import CapabilityList | ||
from llm.capability.requirement_list import RequirementList | ||
from llm.capability.capability_checker import capabilities_fulfill_requirements |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from .capability_list import ( | ||
CapabilityList, | ||
capability_weights, | ||
always_considered_capabilities_with_default, | ||
) | ||
from .requirement_list import RequirementList | ||
|
||
|
||
def capabilities_fulfill_requirements(
    capability: CapabilityList, requirements: RequirementList
) -> bool:
    """Tell whether a capability list satisfies every requirement that is set.

    Requirement fields whose value is ``None`` are ignored. For each remaining
    field, the capability attribute of the same name is asked to ``matches``
    the required value; a falsy score (such as ``0``) means the requirement is
    not met and the check fails.
    """
    for name in requirements.__dict__:
        wanted = getattr(requirements, name)
        if wanted is None:
            # Unset requirement: nothing to check for this field
            continue
        if not getattr(capability, name).matches(wanted):
            return False
    return True
|
||
|
||
def calculate_capability_scores(
    capabilities: list[CapabilityList],
    requirements: RequirementList,
    invert_cost: bool = False,
) -> list[float]:
    """Calculate the scores of the capabilities against the requirements

    Every requirement that is set (or that has an entry in
    ``always_considered_capabilities_with_default``) is scored for each
    capability list via the ``matches`` method of the matching attribute. The
    raw scores of one requirement are then replaced by their rank among the
    distinct scores, scaled to ``(0, 1]`` and multiplied by the requirement's
    weight from ``capability_weights``. The result for a capability list is
    the sum of its weighted ranks over all scored requirements.

    Args:
        capabilities: The capability lists to score.
        requirements: The requirements to score against. Fields that are
            ``None`` are skipped unless a default exists for them.
        invert_cost: If true, non-zero ``cost`` scores are replaced by their
            reciprocal before ranking.

    Returns:
        One score per entry of ``capabilities``, in the same order. The list
        is empty if there are no capabilities or no requirement was scored.

    Raises:
        KeyError: If a scored requirement has no entry in
            ``capability_weights``.
    """
    all_scores = []

    for requirement, requirement_value in vars(requirements).items():
        if requirement_value is None:
            if requirement not in always_considered_capabilities_with_default:
                continue
            # If the requirement is not set, use the default value if necessary
            requirement_value = always_considered_capabilities_with_default[
                requirement
            ]

        # Calculate the scores for each capability
        scores = []
        for capability in capabilities:
            score = getattr(capability, requirement).matches(requirement_value)
            # Invert the cost if required
            # The cost is a special case, as depending on how you want to use the scores
            # the cost needs to be considered differently
            if requirement == "cost" and invert_cost and score != 0:
                score = 1 / score
            scores.append(score)

        # Normalize the scores between 0 and 1 and multiply by the weight modifier
        # The normalization here is based on the position of the score in the sorted list
        # to balance out the different ranges of the capabilities.
        # The rank table makes each lookup O(1) instead of a linear list search.
        rank_of_score = {
            score: rank for rank, score in enumerate(sorted(set(scores)), start=1)
        }
        weight_modifier = capability_weights[requirement]
        all_scores.append(
            [
                (rank_of_score[score] / len(rank_of_score)) * weight_modifier
                for score in scores
            ]
        )

    # Sum up the scores for each capability to get the final score for each list of capabilities
    # zip(*...) walks the per-requirement lists column by column, i.e. per capability list
    return [sum(scores_of_capability) for scores_of_capability in zip(*all_scores)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from abc import ABCMeta
from typing import Any

from pydantic import BaseModel, Field, model_validator
|
||
|
||
class Capability(metaclass=ABCMeta):
    """A capability to match a generic value

    Any class that exposes a callable ``matches`` attribute is treated as a
    virtual subclass of this class (see ``__subclasshook__``), so concrete
    capabilities do not have to inherit from it.
    """

    @classmethod
    def __subclasshook__(cls, subclass) -> bool:
        # Structural check used by issubclass()/isinstance(): having a
        # callable `matches` is enough to count as a Capability
        return hasattr(subclass, "matches") and callable(subclass.matches)

    def matches(self, other: Any) -> int:
        """Return a score for how well the capability matches the input

        Raises:
            NotImplementedError: Always; concrete capabilities provide the
                actual implementation.
        """
        raise NotImplementedError
|
||
|
||
class TextCapability(BaseModel):
    """Capability holding one fixed string that has to be equalled exactly"""

    value: str

    def matches(self, text: str) -> int:
        """Score 1 when ``text`` equals the stored value, otherwise 0"""
        return 1 if self.value == text else 0

    def __str__(self):
        # Wrap the string form of the base class in the class name
        inner = super().__str__()
        return f"TextCapability({inner})"
|
||
|
||
class OrderedNumberCapability(BaseModel):
    """Numeric capability where a larger stored value is the better one"""

    value: int | float

    def matches(self, number: int | float) -> int | float:
        """Score the stored value against a required minimum

        Returns 0 when the stored value is below ``number``; otherwise the
        surplus over ``number`` plus one, so an exact match scores 1.
        """
        return 0 if self.value < number else self.value - number + 1

    def __str__(self):
        # Wrap the string form of the base class in the class name
        inner = super().__str__()
        return f"OrderedNumberCapability({inner})"
|
||
|
||
class InverseOrderedNumberCapability(BaseModel):
    """Numeric capability where a smaller stored value is the better one"""

    value: int | float

    def matches(self, number: int | float) -> int | float:
        """Score the stored value against an allowed maximum

        Returns 0 when the stored value is above ``number``; otherwise the
        headroom below ``number`` plus one, so an exact match scores 1.
        """
        return 0 if self.value > number else number - self.value + 1

    def __str__(self):
        # Wrap the string form of the base class in the class name
        inner = super().__str__()
        return f"InverseOrderedNumberCapability({inner})"
|
||
|
||
class BooleanCapability(BaseModel):
    """Capability that is either present or absent"""

    value: bool

    def matches(self, boolean: bool) -> int:
        """Score 1 when ``boolean`` equals the stored flag, otherwise 0"""
        return 1 if self.value == boolean else 0

    def __str__(self):
        # Only the flag itself is shown, not the field name
        return f"BooleanCapability({self.value})"
|
||
|
||
class CapabilityList(BaseModel):
    """A list of capabilities for a model

    Every field is a capability object exposing ``matches``. Plain values
    passed for a field (for example ``cost=5``) are wrapped into the
    ``{"value": ...}`` form expected by the capability models before
    validation, see ``from_dict``.
    """

    cost: InverseOrderedNumberCapability = Field(
        default=InverseOrderedNumberCapability(value=0)
    )
    gpt_version_equivalent: OrderedNumberCapability = Field(
        default=OrderedNumberCapability(value=2)
    )
    context_length: OrderedNumberCapability = Field(
        default=OrderedNumberCapability(value=0)
    )
    vendor: TextCapability = Field(default=TextCapability(value=""))
    privacy_compliance: BooleanCapability = Field(
        default=BooleanCapability(value=False)
    )
    self_hosted: BooleanCapability = Field(default=BooleanCapability(value=False))
    image_recognition: BooleanCapability = Field(default=BooleanCapability(value=False))
    json_mode: BooleanCapability = Field(default=BooleanCapability(value=False))

    @model_validator(mode="before")
    @classmethod
    def from_dict(cls, data: Any) -> Any:
        """Prepare the data for handling by Pydantic

        Wraps every plain value in ``{"value": value}``. Values that are
        already a mapping or a model instance are passed through unchanged,
        and the input itself is never modified; a new dict is returned.
        """
        if not isinstance(data, dict):
            # A "before" validator receives the raw input, which is not
            # guaranteed to be a dict; leave anything else to Pydantic
            return data
        return {
            key: value if isinstance(value, (dict, BaseModel)) else {"value": value}
            for key, value in data.items()
        }
|
||
|
||
# Multiplier applied to the normalized score of each capability during scoring.
# Capabilities weighted 0 do not contribute to the final score.
capability_weights = {
    "cost": 1,
    "gpt_version_equivalent": 4,
    "context_length": 0.1,
    "vendor": 1,
    "privacy_compliance": 0,
    "self_hosted": 0,
    "image_recognition": 0,
    "json_mode": 0,
}

# Capabilities that are scored even when no requirement is given for them,
# together with the requirement value used in that case.
always_considered_capabilities_with_default = {"cost": 100_000_000_000_000}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
class RequirementList: | ||
"""A class to represent the requirements you want to match against""" | ||
|
||
cost: int | None | ||
gpt_version_equivalent: float | None | ||
context_length: int | None | ||
vendor: str | None | ||
privacy_compliance: bool | None | ||
self_hosted: bool | None | ||
image_recognition: bool | None | ||
json_mode: bool | None | ||
|
||
def __init__( | ||
self, | ||
cost: int | None = None, | ||
gpt_version_equivalent: float | None = None, | ||
context_length: int | None = None, | ||
vendor: str | None = None, | ||
privacy_compliance: bool | None = None, | ||
self_hosted: bool | None = None, | ||
image_recognition: bool | None = None, | ||
json_mode: bool | None = None, | ||
) -> None: | ||
self.cost = cost | ||
self.gpt_version_equivalent = gpt_version_equivalent | ||
self.context_length = context_length | ||
self.vendor = vendor | ||
self.privacy_compliance = privacy_compliance | ||
self.self_hosted = self_hosted | ||
self.image_recognition = image_recognition | ||
self.json_mode = json_mode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
from basic_request_handler import BasicRequestHandler | ||
from request_handler_interface import RequestHandler | ||
from llm.request_handler.request_handler_interface import RequestHandler | ||
from llm.request_handler.basic_request_handler import BasicRequestHandler | ||
from llm.request_handler.capability_request_handler import ( | ||
CapabilityRequestHandler, | ||
CapabilityRequestHandlerSelectionMode, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from enum import Enum | ||
|
||
from domain import IrisMessage | ||
from llm.capability import RequirementList | ||
from llm.external.model import ChatModel, CompletionModel, EmbeddingModel, LanguageModel | ||
from llm.request_handler import RequestHandler | ||
from llm.completion_arguments import CompletionArguments | ||
from llm.llm_manager import LlmManager | ||
|
||
|
||
class CapabilityRequestHandlerSelectionMode(Enum):
    """Selection strategies understood by the capability request handler"""

    # The member values are the lowercase names of the strategies
    BEST = "best"
    WORST = "worst"
|
||
|
||
class CapabilityRequestHandler(RequestHandler):
    """Request handler that selects the best/worst model based on the requirements

    A model is selected anew for every request: the models known to the
    ``LlmManager`` are ordered by their capability score for the given
    requirements, filtered by the model type the request needs, and the first
    or last entry is used depending on the selection mode.
    """

    requirements: RequirementList
    selection_mode: CapabilityRequestHandlerSelectionMode
    llm_manager: LlmManager

    def __init__(
        self,
        requirements: RequirementList,
        selection_mode: CapabilityRequestHandlerSelectionMode = (
            CapabilityRequestHandlerSelectionMode.WORST
        ),
    ) -> None:
        self.requirements = requirements
        self.selection_mode = selection_mode
        self.llm_manager = LlmManager()

    def complete(self, prompt: str, arguments: CompletionArguments) -> str:
        """Complete the prompt with the selected completion model"""
        llm = self._select_model(CompletionModel)
        return llm.complete(prompt, arguments)

    def chat(
        self, messages: list[IrisMessage], arguments: CompletionArguments
    ) -> IrisMessage:
        """Answer the messages with the selected chat model"""
        llm = self._select_model(ChatModel)
        return llm.chat(messages, arguments)

    def embed(self, text: str) -> list[float]:
        """Embed the text with the selected embedding model"""
        llm = self._select_model(EmbeddingModel)
        return llm.embed(text)

    def _select_model(self, type_filter: type) -> LanguageModel:
        """Select the best/worst model based on the requirements and the selection mode

        Args:
            type_filter: The model class the selected model has to be an
                instance of.

        Raises:
            ValueError: If no known model is an instance of ``type_filter``.
        """
        # The second argument asks for inverted cost scores in WORST mode
        # NOTE(review): the result is presumably ordered best-first - confirm
        # against LlmManager.get_llms_sorted_by_capabilities_score
        llms = self.llm_manager.get_llms_sorted_by_capabilities_score(
            self.requirements,
            self.selection_mode == CapabilityRequestHandlerSelectionMode.WORST,
        )
        llms = [llm for llm in llms if isinstance(llm, type_filter)]

        if not llms:
            # Fail with a clear message instead of an IndexError from the
            # list access below
            raise ValueError(
                f"No model of type {type_filter.__name__} is available for selection"
            )

        if self.selection_mode == CapabilityRequestHandlerSelectionMode.BEST:
            llm = llms[0]
        else:
            llm = llms[-1]

        # Print the selected model for the logs
        print(f"Selected {llm.description}")
        return llm