diff --git a/app/config.py b/app/config.py
index ac1880a8..c40c5781 100644
--- a/app/config.py
+++ b/app/config.py
@@ -1,7 +1,8 @@
 import os
+from guidance.llms import OpenAI
 
 from pyaml_env import parse_config
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, validator, Field, typing
 
 
 class LLMModelSpecs(BaseModel):
@@ -13,10 +14,14 @@ class LLMModelConfig(BaseModel):
     name: str
     description: str
 
+    def get_instance(self):
+        raise NotImplementedError()
+
 
 class OpenAIConfig(LLMModelConfig):
     spec: LLMModelSpecs
     llm_credentials: dict
+    instance: typing.Any = Field(default=None, repr=False)
 
     @validator("type")
     def check_type(cls, v):
@@ -24,9 +29,16 @@ def check_type(cls, v):
             raise ValueError("Invalid type:" + v + " != openai")
         return v
 
+    def get_instance(self):
+        if self.instance is not None:
+            return self.instance
+        self.instance = OpenAI(**self.llm_credentials)
+        return self.instance
+
 
 class StrategyLLMConfig(LLMModelConfig):
     llms: list[str]
+    instance: typing.Any = Field(default=None, repr=False)
 
     @validator("type")
     def check_type(cls, v):
@@ -34,6 +46,14 @@ def check_type(cls, v):
             raise ValueError("Invalid type:" + v + " != strategy")
         return v
 
+    def get_instance(self):
+        if self.instance is not None:
+            return self.instance
+        # Local import needed to avoid circular dependency
+        from app.llms.strategy_llm import StrategyLLM
+        self.instance = StrategyLLM(self.llms)
+        return self.instance
+
 
 class APIKeyConfig(BaseModel):
     token: str
@@ -69,3 +89,5 @@ def get_settings(cls):
 
 
 settings = Settings.get_settings()
+for llm_config in settings.pyris.llms.values():
+    llm_config.get_instance()
diff --git a/app/llms/strategy_llm.py b/app/llms/strategy_llm.py
index 87e87384..35c25e4e 100644
--- a/app/llms/strategy_llm.py
+++ b/app/llms/strategy_llm.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Dict, Any
 
-from guidance.llms import LLM, OpenAI, LLMSession
+from guidance.llms import LLM, LLMSession
 
 from app.config import settings, OpenAIConfig
 from app.services.cache import cache_store
@@ -30,7 +30,7 @@ def __init__(self, llm_keys: list[str]):
             llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
         }
         self.llm_instances = {
-            llm_key: OpenAI(**llm_config.llm_credentials)
+            llm_key: llm_config.get_instance()
             for llm_key, llm_config in self.llm_configs.items()
         }
         self.llm_sessions = {}
@@ -89,35 +89,25 @@ def get_total_context_length(
         )
         return prompt_token_length + max_tokens
 
-    def get_llms_with_context_length(
-        self, prompt: str, max_tokens: int
-    ) -> [str]:
-        return [
-            llm_key
-            for llm_key in self.llm.llm_keys
-            if self.llm.llm_configs[llm_key].spec.context_length
-            >= self.get_total_context_length(prompt, llm_key, max_tokens)
-        ]
-
-    def get_available_llms(self, llms: [str]) -> [str]:
-        return [
-            llm_key
-            for llm_key in llms
-            if cache_store.get(llm_key + ":status") != "OPEN"
-        ]
-
     def set_correct_session(
         self, prompt: str, max_tokens: int, exclude_llms=None
    ):
         if exclude_llms is None:
             exclude_llms = []
 
-        viable_llms = self.get_llms_with_context_length(prompt, max_tokens)
-        viable_llms = self.get_available_llms(viable_llms)
-        viable_llms = [
-            llm_key for llm_key in viable_llms if llm_key not in exclude_llms
-        ]
-        if viable_llms.__len__() == 0:
+        selected_llm = None
+
+        for llm_key in self.llm.llm_keys:
+            if (llm_key not in exclude_llms
+                    and cache_store.get(llm_key + ":status") != "OPEN"
+                    and self.llm.llm_configs[llm_key].spec.context_length
+                    >= self.get_total_context_length(prompt,
+                                                     llm_key,
+                                                     max_tokens)):
+                selected_llm = llm_key
+                break
+
+        if selected_llm is None:
             log.error(
                 "No viable LLMs found! Using LLM with longest context length."
             )
@@ -132,8 +122,7 @@ def set_correct_session(
                 llm_configs,
                 key=lambda llm_key: llm_configs[llm_key].spec.context_length,
             )
-        else:
-            selected_llm = viable_llms[0]
+
         log.info("Selected LLM: " + selected_llm)
 
         self.current_session_key = selected_llm
diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py
index 803387d3..34b88802 100644
--- a/app/services/guidance_wrapper.py
+++ b/app/services/guidance_wrapper.py
@@ -1,11 +1,8 @@
-from typing import cast
-
 import guidance
 
-from app.config import LLMModelConfig, OpenAIConfig, StrategyLLMConfig
+from app.config import LLMModelConfig
 from app.models.dtos import Content, ContentType
 from app.services.guidance_functions import truncate
-from app.llms.strategy_llm import StrategyLLM
 
 
 class GuidanceWrapper:
@@ -72,11 +69,4 @@ def is_up(self) -> bool:
         return content == "1"
 
     def _get_llm(self):
-        if isinstance(self.model, OpenAIConfig):
-            return guidance.llms.OpenAI(
-                **cast(OpenAIConfig, self.model).llm_credentials
-            )
-        elif isinstance(self.model, StrategyLLMConfig):
-            return StrategyLLM(cast(StrategyLLMConfig, self.model).llms)
-        else:
-            raise ValueError("Invalid model type: " + str(type(self.model)))
+        return self.model.get_instance()
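A note on the lazy `get_instance()` pattern introduced in `app/config.py`: each config model builds its LLM client once and hands out the cached object afterwards, so `StrategyLLM` and `GuidanceWrapper` now share a single client per config instead of each constructing their own. A minimal standalone sketch of the pattern, assuming pydantic v1 (matching the `@validator` usage above) and using the stdlib `typing` module directly; `FakeClient` is a hypothetical stand-in for `guidance.llms.OpenAI`:

    import typing

    from pydantic import BaseModel, Field


    class FakeClient:
        # Hypothetical stand-in for an expensive client such as
        # guidance.llms.OpenAI.
        def __init__(self, **credentials):
            self.credentials = credentials


    class ClientConfig(BaseModel):
        llm_credentials: dict
        # default=None keeps the field optional so configs parsed from YAML
        # validate; repr=False keeps the client object out of printed configs.
        instance: typing.Any = Field(default=None, repr=False)

        def get_instance(self):
            # Build the client on first use, then return the cached instance.
            if self.instance is not None:
                return self.instance
            self.instance = FakeClient(**self.llm_credentials)
            return self.instance


    config = ClientConfig(llm_credentials={"api_key": "..."})
    assert config.get_instance() is config.get_instance()  # one client per config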
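Similarly, the rewritten `set_correct_session` in `app/llms/strategy_llm.py` collapses the old three-pass filtering (`get_llms_with_context_length`, `get_available_llms`, exclusion list) into a single loop that short-circuits at the first viable LLM: one that is not excluded, whose circuit-breaker status in the cache is not "OPEN", and whose context window fits prompt plus completion tokens. A sketch of that selection logic using plain dicts in place of the Pyris config objects (all names and data shapes below are illustrative, not the Pyris API):

    def select_llm(llm_keys, context_lengths, statuses,
                   required_tokens, exclude=()):
        # First viable key wins, so the order of llm_keys is the priority order.
        for key in llm_keys:
            if (key not in exclude
                    and statuses.get(key) != "OPEN"  # breaker not tripped
                    and context_lengths[key] >= required_tokens):
                return key
        # Fallback as in the patch: nothing viable, so pick the LLM with the
        # longest context window.
        return max(llm_keys, key=lambda key: context_lengths[key])


    keys = ["gpt35", "gpt4"]
    lengths = {"gpt35": 4096, "gpt4": 8192}
    assert select_llm(keys, lengths, {"gpt35": "OPEN"}, 3000) == "gpt4"
    assert select_llm(keys, lengths, {}, 99999) == "gpt4"  # fallback path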