Showing 12 changed files with 757 additions and 502 deletions.
@@ -0,0 +1,175 @@
import logging
from typing import Dict, Any

from guidance.llms import LLM, OpenAI, LLMSession

from app.config import settings, OpenAIConfig
from app.services.cache import cache_store
from app.services.circuit_breaker import CircuitBreaker

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

class StrategyLLM(LLM):
    llm_keys: list[str]
    llm_configs: dict[str, OpenAIConfig]
    llm_instances: dict[str, OpenAI]
    llm_sessions: dict[str, LLMSession]

    def __init__(self, llm_keys: list[str]):
        super().__init__()
        if not llm_keys:
            raise ValueError("No LLMs configured")
        # Keep only the keys that are present in the settings and are
        # backed by an OpenAI configuration.
        self.llm_keys = [
            llm_key
            for llm_key in llm_keys
            if llm_key in settings.pyris.llms
            and isinstance(settings.pyris.llms[llm_key], OpenAIConfig)
        ]
        self.llm_configs = {
            llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
        }
        self.llm_instances = {
            llm_key: OpenAI(**llm_config.llm_credentials)
            for llm_key, llm_config in self.llm_configs.items()
        }
        self.llm_sessions = {}

    # The methods below delegate tokenizer and chat-role plumbing to the
    # first configured LLM; per-request selection between LLMs happens in
    # StrategyLLMSession.
    def first_llm(self):
        return self.llm_instances[self.llm_keys[0]]

    def extract_function_call(self, text):
        return self.first_llm().extract_function_call(text)

    def __call__(self, *args, asynchronous=False, **kwargs):
        return self.first_llm()(*args, asynchronous=asynchronous, **kwargs)

    def __getitem__(self, key):
        return self.first_llm()[key]

    def session(self, asynchronous=False):
        return StrategyLLMSession(self)

    def encode(self, string, **kwargs):
        return self.first_llm().encode(string)

    def decode(self, tokens, **kwargs):
        return self.first_llm().decode(tokens)

    def id_to_token(self, id):
        return self.first_llm().id_to_token(id)

    def token_to_id(self, token):
        return self.first_llm().token_to_id(token)

    def role_start(self, role_name, **kwargs):
        return self.first_llm().role_start(role_name, **kwargs)

    def role_end(self, role=None):
        return self.first_llm().role_end(role)

    def end_of_text(self):
        return self.first_llm().end_of_text()

class StrategyLLMSession(LLMSession):
    llm: StrategyLLM
    current_session_key: str
    current_session: LLMSession

    def __init__(self, llm: StrategyLLM):
        super().__init__(llm)
        self.llm = llm

    def get_total_context_length(
        self, prompt: str, llm_key: str, max_tokens: int
    ) -> int:
        # Total context needed = prompt tokens + requested completion tokens.
        prompt_token_length = len(
            self.llm.llm_instances[llm_key].encode(prompt)
        )
        return prompt_token_length + max_tokens

    def get_llms_with_context_length(
        self, prompt: str, max_tokens: int
    ) -> list[str]:
        # Keep only the LLMs whose context window can fit this request.
        return [
            llm_key
            for llm_key in self.llm.llm_keys
            if self.llm.llm_configs[llm_key].spec.context_length
            >= self.get_total_context_length(prompt, llm_key, max_tokens)
        ]

    def get_available_llms(self, llms: list[str]) -> list[str]:
        # Skip LLMs whose circuit breaker is currently open.
        return [
            llm_key
            for llm_key in llms
            if cache_store.get(llm_key + ":status") != "OPEN"
        ]

    def set_correct_session(
        self, prompt: str, max_tokens: int, exclude_llms=None
    ):
        if exclude_llms is None:
            exclude_llms = []
        # Filter by context length, then by circuit-breaker status, then
        # drop any LLMs that already failed for this request.
        viable_llms = self.get_llms_with_context_length(prompt, max_tokens)
        viable_llms = self.get_available_llms(viable_llms)
        viable_llms = [
            llm_key for llm_key in viable_llms if llm_key not in exclude_llms
        ]

        if not viable_llms:
            log.error(
                "No viable LLMs found! Using LLM with longest context length."
            )
            llm_configs = self.llm.llm_configs
            selected_llm = max(
                llm_configs,
                key=lambda llm_key: llm_configs[llm_key].spec.context_length,
            )
        else:
            selected_llm = viable_llms[0]
        log.info("Selected LLM: %s", selected_llm)

        self.current_session_key = selected_llm
        # Reuse a cached session for the selected LLM if one exists;
        # otherwise open a new asynchronous session and cache it.
        if selected_llm in self.llm.llm_sessions:
            self.current_session = self.llm.llm_sessions[selected_llm]
        else:
            self.current_session = self.llm.llm_instances[
                selected_llm
            ].session(asynchronous=True)
            self.llm.llm_sessions[selected_llm] = self.current_session

    async def __call__(self, *args, **kwargs):
        prompt = args[0]
        max_tokens = kwargs["max_tokens"]
        self.set_correct_session(prompt, max_tokens)

        def call():
            # Reads self.current_session at call time, so a retry after
            # set_correct_session() automatically targets the new session.
            return self.current_session(*args, **kwargs)

        exclude_llms = []

        try:
            response = await CircuitBreaker.protected_call(
                func=call, cache_key=self.current_session_key
            )
        except Exception as e:
            log.error(
                f"Exception {e} while making request! "
                f"Trying again with a different LLM."
            )
            # Exclude the failed LLM and retry once on the next viable one.
            exclude_llms.append(self.current_session_key)
            self.set_correct_session(prompt, max_tokens, exclude_llms)
            response = await CircuitBreaker.protected_call(
                func=call, cache_key=self.current_session_key
            )
        return response

    def __exit__(self, exc_type, exc_value, traceback):
        return self.current_session.__exit__(exc_type, exc_value, traceback)

    def _gen_key(self, args_dict):
        return self.current_session._gen_key(args_dict)

    def _cache_params(self, args_dict) -> Dict[str, Any]:
        return self.current_session._cache_params(args_dict)
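
For context, here is a minimal sketch of how this strategy class might be driven from a guidance program, matching the guidance.llms API imported above. The key names "gpt35-16k" and "gpt4-32k" and the template text are hypothetical placeholders, not taken from this commit; they stand in for OpenAI-backed entries in settings.pyris.llms.

# Hypothetical wiring sketch; key names below are assumptions.
import guidance

llm = StrategyLLM(["gpt35-16k", "gpt4-32k"])

# guidance opens a session via llm.session(); StrategyLLMSession then
# picks, per request, the first configured LLM whose context window fits
# the prompt and whose circuit breaker is not open.
program = guidance(
    "{{#user}}Summarize: {{text}}{{/user}}"
    "{{#assistant}}{{gen 'summary' max_tokens=256}}{{/assistant}}",
    llm=llm,
)
result = program(text="Some long document ...")

Because the fallback in __call__ re-resolves self.current_session before retrying, a single transient failure on one backend costs one extra request instead of surfacing an error to the caller.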