From 7b01f18aa2b9e39b48dcbac8c71fa6157466bca9 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Fri, 3 May 2024 23:37:12 +0200
Subject: [PATCH] `Development`: Fix the LLM selection choosing invalid models

---
 app/llm/external/ollama.py      | 15 ++++++++++++---
 app/llm/llm_manager.py          | 14 +++++++++++---
 app/web/status/status_update.py |  3 +++
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py
index f2363b23..d6946e31 100644
--- a/app/llm/external/ollama.py
+++ b/app/llm/external/ollama.py
@@ -1,6 +1,7 @@
 import base64
 from datetime import datetime
 from typing import Literal, Any, Optional
+from pydantic import Field
 
 from ollama import Client, Message
 
@@ -76,6 +77,7 @@ class OllamaModel(
     type: Literal["ollama"]
     model: str
     host: str
+    options: dict[str, Any] = Field(default={})
     _client: Client
 
     def model_post_init(self, __context: Any) -> None:
@@ -88,7 +90,10 @@ def complete(
         image: Optional[ImageMessageContentDTO] = None,
     ) -> str:
         response = self._client.generate(
-            model=self.model, prompt=prompt, images=[image.base64] if image else None
+            model=self.model,
+            prompt=prompt,
+            images=[image.base64] if image else None,
+            options=self.options,
         )
         return response["response"]
 
@@ -96,12 +101,16 @@ def chat(
         self, messages: list[PyrisMessage], arguments: CompletionArguments
     ) -> PyrisMessage:
         response = self._client.chat(
-            model=self.model, messages=convert_to_ollama_messages(messages)
+            model=self.model,
+            messages=convert_to_ollama_messages(messages),
+            options=self.options,
         )
         return convert_to_iris_message(response["message"])
 
     def embed(self, text: str) -> list[float]:
-        response = self._client.embeddings(model=self.model, prompt=text)
+        response = self._client.embeddings(
+            model=self.model, prompt=text, options=self.options
+        )
         return list(response)
 
     def __str__(self):

diff --git a/app/llm/llm_manager.py b/app/llm/llm_manager.py
index 0a112257..71879488 100644
--- a/app/llm/llm_manager.py
+++ b/app/llm/llm_manager.py
@@ -6,7 +6,10 @@
 
 from ..common import Singleton
 from ..llm.capability import RequirementList
-from ..llm.capability.capability_checker import calculate_capability_scores
+from ..llm.capability.capability_checker import (
+    calculate_capability_scores,
+    capabilities_fulfill_requirements,
+)
 from ..llm.external import LanguageModel, AnyLLM
 
 
@@ -41,9 +44,14 @@ def load_llms(self):
     def get_llms_sorted_by_capabilities_score(
         self, requirements: RequirementList, invert_cost: bool = False
     ):
+        valid_llms = [
+            llm
+            for llm in self.entries
+            if capabilities_fulfill_requirements(llm.capabilities, requirements)
+        ]
         """Get the llms sorted by their capability to requirement scores"""
         scores = calculate_capability_scores(
-            [llm.capabilities for llm in self.entries], requirements, invert_cost
+            [llm.capabilities for llm in valid_llms], requirements, invert_cost
        )
-        sorted_llms = sorted(zip(scores, self.entries), key=lambda pair: -pair[0])
+        sorted_llms = sorted(zip(scores, valid_llms), key=lambda pair: -pair[0])
         return [llm for _, llm in sorted_llms]

diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py
index 2997409a..dbd7590c 100644
--- a/app/web/status/status_update.py
+++ b/app/web/status/status_update.py
@@ -128,6 +128,9 @@ def error(self, message: str):
         # Update the status after setting the stages to SKIPPED
         self.stage = self.status.stages[-1]
         self.on_status_update()
+        logger.error(
+            f"Error occurred in job {self.run_id} in stage {self.stage.name}: {message}"
+        )
 
     def skip(self, message: Optional[str] = None):
         """
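
Note for reviewers: the core of this patch is that `get_llms_sorted_by_capabilities_score` now filters out models that cannot fulfill the requirements *before* scoring, so a model with a high raw score but a missing capability can no longer be selected. The sketch below illustrates that behavior in isolation; `FakeLLM` and the two helper functions are hypothetical stand-ins for the project's real capability types and checker, not the Pyris implementation.

    # Minimal, runnable sketch of the fixed selection logic.
    # Assumption: capabilities/requirements are simplified to dict[str, int];
    # the real RequirementList and capability checker are richer than this.
    from dataclasses import dataclass


    @dataclass
    class FakeLLM:
        name: str
        capabilities: dict[str, int]  # e.g. {"context_length": 8192, "speed": 3}


    def capabilities_fulfill_requirements(
        capabilities: dict[str, int], requirements: dict[str, int]
    ) -> bool:
        # Stand-in checker: every required capability must be present
        # and at least as large as requested.
        return all(capabilities.get(k, 0) >= v for k, v in requirements.items())


    def calculate_capability_scores(
        capability_lists: list[dict[str, int]], requirements: dict[str, int]
    ) -> list[int]:
        # Stand-in scorer: sum of the capability values the requirements mention.
        return [sum(caps.get(k, 0) for k in requirements) for caps in capability_lists]


    def get_llms_sorted_by_capabilities_score(
        entries: list[FakeLLM], requirements: dict[str, int]
    ) -> list[FakeLLM]:
        # The fix: drop models that cannot fulfill the requirements before
        # scoring, so an unsuitable model can never win on raw score alone.
        valid_llms = [
            llm
            for llm in entries
            if capabilities_fulfill_requirements(llm.capabilities, requirements)
        ]
        scores = calculate_capability_scores(
            [llm.capabilities for llm in valid_llms], requirements
        )
        sorted_llms = sorted(zip(scores, valid_llms), key=lambda pair: -pair[0])
        return [llm for _, llm in sorted_llms]


    if __name__ == "__main__":
        models = [
            FakeLLM("small-but-valid", {"context_length": 8192, "speed": 2}),
            FakeLLM("fast-but-invalid", {"context_length": 2048, "speed": 9}),
        ]
        requirements = {"context_length": 4096}
        # Before the patch, "fast-but-invalid" was scored alongside valid
        # models and could be chosen; now it is excluded up front.
        print([m.name for m in get_llms_sorted_by_capabilities_score(models, requirements)])

The `options` field added to `OllamaModel` follows the same theme: it threads per-model settings through to every `generate`, `chat`, and `embeddings` call, with `Field(default={})` keeping the field optional for existing configs (Pydantic copies field defaults per instance, so the shared-mutable-default pitfall does not apply here).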