Development: Fix the LLM selection choosing invalid models
Hialus committed May 3, 2024
1 parent fb2b63b commit 7b01f18
Showing 3 changed files with 26 additions and 6 deletions.
15 changes: 12 additions & 3 deletions app/llm/external/ollama.py
@@ -1,6 +1,7 @@
 import base64
 from datetime import datetime
 from typing import Literal, Any, Optional
+from pydantic import Field
 
 from ollama import Client, Message

@@ -76,6 +77,7 @@ class OllamaModel(
     type: Literal["ollama"]
     model: str
     host: str
+    options: dict[str, Any] = Field(default={})
     _client: Client
 
     def model_post_init(self, __context: Any) -> None:
@@ -88,20 +90,27 @@ def complete(
         image: Optional[ImageMessageContentDTO] = None,
     ) -> str:
         response = self._client.generate(
-            model=self.model, prompt=prompt, images=[image.base64] if image else None
+            model=self.model,
+            prompt=prompt,
+            images=[image.base64] if image else None,
+            options=self.options,
         )
         return response["response"]
 
     def chat(
         self, messages: list[PyrisMessage], arguments: CompletionArguments
     ) -> PyrisMessage:
         response = self._client.chat(
-            model=self.model, messages=convert_to_ollama_messages(messages)
+            model=self.model,
+            messages=convert_to_ollama_messages(messages),
+            options=self.options,
         )
         return convert_to_iris_message(response["message"])
 
     def embed(self, text: str) -> list[float]:
-        response = self._client.embeddings(model=self.model, prompt=text)
+        response = self._client.embeddings(
+            model=self.model, prompt=text, options=self.options
+        )
         return list(response)
 
     def __str__(self):
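In ollama.py the commit threads a new per-model `options` dict (a pydantic field defaulting to `{}`) through all three Ollama client calls. A minimal sketch of how such options reach the client, assuming a local Ollama host and illustrative option values that are not part of the commit:

from ollama import Client

# Hypothetical tuning values; any option keys Ollama accepts (temperature, num_ctx, ...) fit here.
options = {"temperature": 0.2, "num_ctx": 4096}

# Assumed default local host; the commit takes the host from the model config instead.
client = Client(host="http://localhost:11434")

# generate, chat, and embeddings all accept the same options dict, which is
# why the commit forwards self.options in each of the three methods.
response = client.generate(model="llama3", prompt="Say hello.", options=options)
print(response["response"])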
14 changes: 11 additions & 3 deletions app/llm/llm_manager.py
@@ -6,7 +6,10 @@
 
 from ..common import Singleton
 from ..llm.capability import RequirementList
-from ..llm.capability.capability_checker import calculate_capability_scores
+from ..llm.capability.capability_checker import (
+    calculate_capability_scores,
+    capabilities_fulfill_requirements,
+)
 from ..llm.external import LanguageModel, AnyLLM

@@ -41,9 +44,14 @@ def load_llms(self):
     def get_llms_sorted_by_capabilities_score(
         self, requirements: RequirementList, invert_cost: bool = False
     ):
         """Get the llms sorted by their capability to requirement scores"""
+        valid_llms = [
+            llm
+            for llm in self.entries
+            if capabilities_fulfill_requirements(llm.capabilities, requirements)
+        ]
         scores = calculate_capability_scores(
-            [llm.capabilities for llm in self.entries], requirements, invert_cost
+            [llm.capabilities for llm in valid_llms], requirements, invert_cost
         )
-        sorted_llms = sorted(zip(scores, self.entries), key=lambda pair: -pair[0])
+        sorted_llms = sorted(zip(scores, valid_llms), key=lambda pair: -pair[0])
         return [llm for _, llm in sorted_llms]
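The llm_manager.py change is the core of the fix: candidates are first filtered with `capabilities_fulfill_requirements`, and only the surviving `valid_llms` are scored and ranked, so a model that fails a hard requirement can no longer be selected just because it scores well on the remaining criteria. A self-contained sketch of this filter-then-rank pattern, with toy capability dicts and stand-in checker functions (not the project's real ones):

def fulfills(capabilities: dict, requirements: dict) -> bool:
    # A model is valid only if it meets or exceeds every requirement.
    return all(capabilities.get(key, 0) >= value for key, value in requirements.items())

def score(capabilities: dict, requirements: dict) -> int:
    # Toy score: total capability on the required axes.
    return sum(capabilities.get(key, 0) for key in requirements)

entries = {
    "small-model": {"context_length": 4096, "image_recognition": 0},
    "large-model": {"context_length": 32768, "image_recognition": 1},
}
requirements = {"context_length": 8192, "image_recognition": 1}

# Filter first (the fix), then rank the survivors by score, best first.
valid = {name: caps for name, caps in entries.items() if fulfills(caps, requirements)}
ranked = sorted(valid, key=lambda name: -score(valid[name], requirements))
print(ranked)  # ['large-model'] -- 'small-model' can no longer be chosen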
3 changes: 3 additions & 0 deletions app/web/status/status_update.py
@@ -128,6 +128,9 @@ def error(self, message: str):
         # Update the status after setting the stages to SKIPPED
         self.stage = self.status.stages[-1]
         self.on_status_update()
+        logger.error(
+            f"Error occurred in job {self.run_id} in stage {self.stage.name}: {message}"
+        )
 
     def skip(self, message: Optional[str] = None):
         """
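The status_update.py change adds an error log whenever a job fails. A minimal sketch of what the added f-string emits, with placeholder values standing in for self.run_id, self.stage.name, and message:

import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Placeholder values; the commit reads these from the status handler's state.
run_id, stage_name, message = "run-1234", "Generating response", "LLM request failed"
logger.error(f"Error occurred in job {run_id} in stage {stage_name}: {message}")
# Prints: ERROR:__main__:Error occurred in job run-1234 in stage Generating response: LLM request failed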
