Some small performance tuning
Hialus committed Sep 26, 2023
1 parent 3cb0886 commit a54b90e
Showing 3 changed files with 41 additions and 40 deletions.
app/config.py: 24 changes (23 additions, 1 deletion)
@@ -1,7 +1,8 @@
 import os
 
+from guidance.llms import OpenAI
 from pyaml_env import parse_config
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, validator, Field, typing
 
 
 class LLMModelSpecs(BaseModel):
@@ -13,27 +14,46 @@ class LLMModelConfig(BaseModel):
     name: str
     description: str
 
+    def get_instance(cls):
+        raise NotImplementedError()
+
 
 class OpenAIConfig(LLMModelConfig):
     spec: LLMModelSpecs
     llm_credentials: dict
+    instance: typing.Any = Field(repr=False)
 
     @validator("type")
     def check_type(cls, v):
         if v != "openai":
             raise ValueError("Invalid type:" + v + " != openai")
         return v
 
+    def get_instance(cls):
+        if cls.instance is not None:
+            return cls.instance
+        cls.instance = OpenAI(**cls.llm_credentials)
+        return cls.instance
+
 
 class StrategyLLMConfig(LLMModelConfig):
     llms: list[str]
+    instance: typing.Any = Field(repr=False)
 
     @validator("type")
     def check_type(cls, v):
         if v != "strategy":
             raise ValueError("Invalid type:" + v + " != strategy")
         return v
 
+    def get_instance(cls):
+        if cls.instance is not None:
+            return cls.instance
+        # Local import needed to avoid circular dependency
+        from app.llms.strategy_llm import StrategyLLM
+        cls.instance = StrategyLLM(cls.llms)
+        return cls.instance
+
 
 class APIKeyConfig(BaseModel):
     token: str
@@ -69,3 +89,5 @@ def get_settings(cls):
 
 
 settings = Settings.get_settings()
+for value in enumerate(settings.pyris.llms.values()):
+    value[1].get_instance()
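
The changes in this file introduce a lazy, cached client per model config: get_instance() constructs the underlying LLM object on first call, stores it on the new instance field, and returns the stored object on every later call, while the loop added after Settings.get_settings() warms every configured LLM once at startup. A minimal self-contained sketch of the same cache-on-first-use pattern (the class name and make_client factory are illustrative, not from this repository):

from typing import Any, Optional


def make_client(**credentials: Any) -> object:
    # Stand-in for an expensive constructor such as OpenAI(**credentials).
    return object()


class CachedModelConfig:
    """Illustrative config with a cache-on-first-use client."""

    def __init__(self, credentials: dict):
        self.credentials = credentials
        self.instance: Optional[Any] = None  # populated on first use

    def get_instance(self) -> Any:
        # Build the client once; all later calls reuse the same object.
        if self.instance is not None:
            return self.instance
        self.instance = make_client(**self.credentials)
        return self.instance


configs = {"gpt": CachedModelConfig({"token": "..."})}

# Eager warm-up at startup, like the loop added after get_settings():
for config in configs.values():
    config.get_instance()

assert configs["gpt"].get_instance() is configs["gpt"].get_instance()

Paying the construction cost once at startup moves it out of the request path, which is presumably the performance win the commit title refers to.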
app/llms/strategy_llm.py: 43 changes (16 additions, 27 deletions)
@@ -1,7 +1,7 @@
 import logging
 from typing import Dict, Any
 
-from guidance.llms import LLM, OpenAI, LLMSession
+from guidance.llms import LLM, LLMSession
 
 from app.config import settings, OpenAIConfig
 from app.services.cache import cache_store
@@ -30,7 +30,7 @@ def __init__(self, llm_keys: list[str]):
             llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
         }
         self.llm_instances = {
-            llm_key: OpenAI(**llm_config.llm_credentials)
+            llm_key: llm_config.get_instance()
             for llm_key, llm_config in self.llm_configs.items()
         }
         self.llm_sessions = {}
@@ -89,35 +89,25 @@ def get_total_context_length(
         )
         return prompt_token_length + max_tokens
 
-    def get_llms_with_context_length(
-        self, prompt: str, max_tokens: int
-    ) -> [str]:
-        return [
-            llm_key
-            for llm_key in self.llm.llm_keys
-            if self.llm.llm_configs[llm_key].spec.context_length
-            >= self.get_total_context_length(prompt, llm_key, max_tokens)
-        ]
-
-    def get_available_llms(self, llms: [str]) -> [str]:
-        return [
-            llm_key
-            for llm_key in llms
-            if cache_store.get(llm_key + ":status") != "OPEN"
-        ]
-
     def set_correct_session(
         self, prompt: str, max_tokens: int, exclude_llms=None
     ):
         if exclude_llms is None:
             exclude_llms = []
-        viable_llms = self.get_llms_with_context_length(prompt, max_tokens)
-        viable_llms = self.get_available_llms(viable_llms)
-        viable_llms = [
-            llm_key for llm_key in viable_llms if llm_key not in exclude_llms
-        ]
 
-        if viable_llms.__len__() == 0:
+        selected_llm = None
+
+        for llm_key in self.llm.llm_keys:
+            if (llm_key not in exclude_llms
+                    and cache_store.get(llm_key + ":status") != "OPEN"
+                    and self.llm.llm_configs[llm_key].spec.context_length
+                    >= self.get_total_context_length(prompt,
+                                                     llm_key,
+                                                     max_tokens)):
+                selected_llm = llm_key
+                break
+
+        if selected_llm is None:
             log.error(
                 "No viable LLMs found! Using LLM with longest context length."
             )
@@ -132,8 +122,7 @@ def set_correct_session(
             llm_configs,
             key=lambda llm_key: llm_configs[llm_key].spec.context_length,
         )
-        else:
-            selected_llm = viable_llms[0]
 
         log.info("Selected LLM: " + selected_llm)
 
         self.current_session_key = selected_llm
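
The rewrite above collapses what used to be three full passes over the LLM keys (a context-length filter, an availability filter, and an exclusion filter, followed by taking the first survivor) into a single loop that stops at the first key satisfying all three conditions. A standalone sketch of that single-pass, short-circuiting selection (the predicate arguments are illustrative stand-ins for the cache and context-length checks):

from typing import Callable, Iterable, Optional


def select_first_viable(
    keys: Iterable[str],
    excluded: set,
    is_available: Callable[[str], bool],
    fits_context: Callable[[str], bool],
) -> Optional[str]:
    # Single pass with short-circuiting `and`: no intermediate lists,
    # and the costlier context check only runs for keys that already
    # passed the cheaper exclusion and availability checks.
    for key in keys:
        if key not in excluded and is_available(key) and fits_context(key):
            return key
    return None


print(select_first_viable(
    ["gpt-3.5", "gpt-4"],
    excluded={"gpt-3.5"},
    is_available=lambda key: True,
    fits_context=lambda key: True,
))  # prints: gpt-4

Ordering the conditions from cheapest to most expensive matters here, since `and` short-circuits: the token-counting call in get_total_context_length is only reached for keys that are non-excluded and available.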
app/services/guidance_wrapper.py: 14 changes (2 additions, 12 deletions)
@@ -1,11 +1,8 @@
-from typing import cast
-
 import guidance
 
-from app.config import LLMModelConfig, OpenAIConfig, StrategyLLMConfig
+from app.config import LLMModelConfig
 from app.models.dtos import Content, ContentType
 from app.services.guidance_functions import truncate
-from app.llms.strategy_llm import StrategyLLM
 
 
 class GuidanceWrapper:
@@ -72,11 +69,4 @@ def is_up(self) -> bool:
         return content == "1"
 
     def _get_llm(self):
-        if isinstance(self.model, OpenAIConfig):
-            return guidance.llms.OpenAI(
-                **cast(OpenAIConfig, self.model).llm_credentials
-            )
-        elif isinstance(self.model, StrategyLLMConfig):
-            return StrategyLLM(cast(StrategyLLMConfig, self.model).llms)
-        else:
-            raise ValueError("Invalid model type: " + str(type(self.model)))
+        return self.model.get_instance()
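
With get_instance() declared on the shared base class, _get_llm no longer needs isinstance checks, cast, or imports of the concrete config types: each config subclass knows how to build (and now cache) its own client. A minimal sketch of that polymorphic dispatch (class names are illustrative, not the repository's):

class ModelConfig:
    def get_instance(self):
        raise NotImplementedError()


class OpenAILikeConfig(ModelConfig):
    def get_instance(self):
        # Real code would construct and cache the OpenAI client here.
        return "openai client"


class StrategyLikeConfig(ModelConfig):
    def get_instance(self):
        return "strategy client"


def get_llm(model: ModelConfig):
    # No isinstance chain: the config supplies its own client.
    return model.get_instance()


print(get_llm(OpenAILikeConfig()))   # prints: openai client
print(get_llm(StrategyLikeConfig()))  # prints: strategy client

This dispatch also relieves the circular-import pressure that previously forced guidance_wrapper.py to import StrategyLLM directly; the one remaining local import now lives inside StrategyLLMConfig.get_instance() in app/config.py.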
