Add strategy llm
Hialus committed Sep 24, 2023
1 parent 94d78e5 commit ea95e13
Showing 12 changed files with 757 additions and 502 deletions.
29 changes: 27 additions & 2 deletions app/config.py
@@ -1,14 +1,39 @@
import os

from pyaml_env import parse_config
from pydantic import BaseModel
from pydantic import BaseModel, validator


class LLMModelSpecs(BaseModel):
context_length: int


class LLMModelConfig(BaseModel):
type: str
name: str
description: str


class OpenAIConfig(LLMModelConfig):
spec: LLMModelSpecs
llm_credentials: dict

@validator("type")
def check_type(cls, v):
if v != "openai":
raise ValueError("Invalid type:" + v + " != openai")
return v


class StrategyLLMConfig(LLMModelConfig):
llms: list[str]

@validator("type")
def check_type(cls, v):
if v != "strategy":
raise ValueError("Invalid type:" + v + " != strategy")
return v


class APIKeyConfig(BaseModel):
token: str
@@ -27,7 +52,7 @@ class CacheParams(BaseModel):
class Settings(BaseModel):
class PyrisSettings(BaseModel):
api_keys: list[APIKeyConfig]
llms: dict[str, LLMModelConfig]
llms: dict[str, OpenAIConfig | StrategyLLMConfig]
cache: CacheSettings

pyris: PyrisSettings
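With the union in Settings, each entry under pyris.llms is now parsed as either an OpenAIConfig or a StrategyLLMConfig, discriminated by the check_type validators. The example files below only show an openai entry, so here is a sketch of what a strategy entry could look like — the STRATEGY key and the referenced model keys are hypothetical; only the field names come from StrategyLLMConfig:

  STRATEGY:
    type: "strategy"
    name: "Strategy model"
    description: "Tries the listed models in order, falling back on failure"
    llms:
      - "GPT35_TURBO"
      - "GPT4"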
Empty file added app/llms/__init__.py
175 changes: 175 additions & 0 deletions app/llms/strategy_llm.py
@@ -0,0 +1,175 @@
import logging
from typing import Dict, Any

from guidance.llms import LLM, OpenAI, LLMSession

from app.config import settings, OpenAIConfig
from app.services.cache import cache_store
from app.services.circuit_breaker import CircuitBreaker

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)


class StrategyLLM(LLM):
    llm_keys: list[str]
    llm_configs: dict[str, OpenAIConfig]
    llm_instances: dict[str, OpenAI]
    llm_sessions: dict[str, LLMSession]

def __init__(self, llm_keys: list[str]):
super().__init__()
        if len(llm_keys) == 0:
            raise ValueError("No LLMs configured")
        # Keep only keys that resolve to configured OpenAI models.
        self.llm_keys = [
            llm_key
            for llm_key in llm_keys
            if llm_key in settings.pyris.llms
            and isinstance(settings.pyris.llms[llm_key], OpenAIConfig)
        ]
        if not self.llm_keys:
            raise ValueError("No configured OpenAI LLMs among: " + str(llm_keys))
self.llm_configs = {
llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
}
self.llm_instances = {
llm_key: OpenAI(**llm_config.llm_credentials)
for llm_key, llm_config in self.llm_configs.items()
}
self.llm_sessions = {}

def first_llm(self):
return self.llm_instances[self.llm_keys[0]]

def extract_function_call(self, text):
return self.first_llm().extract_function_call(text)

def __call__(self, *args, asynchronous=False, **kwargs):
return self.first_llm()(*args, asynchronous=asynchronous, **kwargs)

def __getitem__(self, key):
return self.first_llm()[key]

def session(self, asynchronous=False):
return StrategyLLMSession(self)

def encode(self, string, **kwargs):
return self.first_llm().encode(string)

def decode(self, tokens, **kwargs):
return self.first_llm().decode(tokens)

def id_to_token(self, id):
return self.first_llm().id_to_token(id)

def token_to_id(self, token):
return self.first_llm().token_to_id(token)

def role_start(self, role_name, **kwargs):
return self.first_llm().role_start(role_name, **kwargs)

def role_end(self, role=None):
return self.first_llm().role_end(role)

def end_of_text(self):
return self.first_llm().end_of_text()


class StrategyLLMSession(LLMSession):
llm: StrategyLLM
current_session_key: str
current_session: LLMSession

def __init__(self, llm: StrategyLLM):
super().__init__(llm)
self.llm = llm

    def get_total_context_length(
        self, prompt: str, llm_key: str, max_tokens: int
    ) -> int:
        prompt_token_length = len(
            self.llm.llm_instances[llm_key].encode(prompt)
        )
        return prompt_token_length + max_tokens

    def get_llms_with_context_length(
        self, prompt: str, max_tokens: int
    ) -> list[str]:
return [
llm_key
for llm_key in self.llm.llm_keys
if self.llm.llm_configs[llm_key].spec.context_length
>= self.get_total_context_length(prompt, llm_key, max_tokens)
]

    def get_available_llms(self, llms: list[str]) -> list[str]:
return [
llm_key
for llm_key in llms
if cache_store.get(llm_key + ":status") != "OPEN"
]

def set_correct_session(
self, prompt: str, max_tokens: int, exclude_llms=None
):
if exclude_llms is None:
exclude_llms = []
viable_llms = self.get_llms_with_context_length(prompt, max_tokens)
viable_llms = self.get_available_llms(viable_llms)
viable_llms = [
llm_key for llm_key in viable_llms if llm_key not in exclude_llms
]

        if not viable_llms:
log.error(
"No viable LLMs found! Using LLM with longest context length."
)
llm_configs = self.llm.llm_configs
selected_llm = max(
llm_configs,
key=lambda llm_key: llm_configs[llm_key].spec.context_length,
)
else:
selected_llm = viable_llms[0]
log.info("Selected LLM: " + selected_llm)

self.current_session_key = selected_llm
if selected_llm in self.llm.llm_sessions:
self.current_session = self.llm.llm_sessions[selected_llm]
else:
self.current_session = self.llm.llm_instances[
selected_llm
].session(asynchronous=True)
self.llm.llm_sessions[selected_llm] = self.current_session

async def __call__(self, *args, **kwargs):
prompt = args[0]
max_tokens = kwargs["max_tokens"]
self.set_correct_session(prompt, max_tokens)

def call():
return self.current_session(*args, **kwargs)

exclude_llms = []

try:
response = await CircuitBreaker.protected_call(
func=call, cache_key=self.current_session_key
)
except Exception as e:
log.error(
f"Exception {e} while making request! "
f"Trying again with a different LLM."
)
exclude_llms.append(self.current_session_key)
self.set_correct_session(prompt, max_tokens, exclude_llms)
response = await CircuitBreaker.protected_call(
func=call, cache_key=self.current_session_key
)
return response

def __exit__(self, exc_type, exc_value, traceback):
return self.current_session.__exit__(exc_type, exc_value, traceback)

def _gen_key(self, args_dict):
return self.current_session._gen_key(args_dict)

def _cache_params(self, args_dict) -> Dict[str, Any]:
return self.current_session._cache_params(args_dict)
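Taken together, StrategyLLM is a drop-in guidance LLM: each call is routed to the first configured model whose context window fits prompt plus max_tokens and whose circuit breaker is not OPEN, with one retry on a different model if the call fails. A minimal usage sketch, assuming two hypothetical model keys and the guidance 0.0.x call style used elsewhere in this codebase:

import guidance

from app.llms.strategy_llm import StrategyLLM

# Hypothetical keys; each must exist under pyris.llms as an OpenAI config.
llm = StrategyLLM(["GPT35_TURBO", "GPT4"])

program = guidance(
    "{{#user}}{{query}}{{/user}}"
    "{{#assistant}}{{gen 'answer' max_tokens=256}}{{/assistant}}",
    llm=llm,
)
# StrategyLLMSession.set_correct_session picks the model per call.
result = program(query="Summarize the strategy pattern in one sentence.")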
10 changes: 8 additions & 2 deletions app/routes/messages.py
@@ -1,6 +1,7 @@
from datetime import datetime, timezone

from fastapi import APIRouter, Depends
from guidance._program_executor import SyntaxException
from parsimonious.exceptions import IncompleteParseError

from app.core.custom_exceptions import (
@@ -37,11 +38,16 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
content = CircuitBreaker.protected_call(
func=guidance.query,
cache_key=body.preferred_model,
accepted_exceptions=(KeyError, SyntaxError, IncompleteParseError),
accepted_exceptions=(
KeyError,
SyntaxError,
SyntaxException,
IncompleteParseError,
),
)
except KeyError as e:
raise MissingParameterException(str(e))
except (SyntaxError, IncompleteParseError) as e:
except (SyntaxError, SyntaxException, IncompleteParseError) as e:
raise InvalidTemplateException(str(e))
except Exception as e:
raise InternalServerException(str(e))
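This change assumes a specific contract from CircuitBreaker.protected_call: exceptions listed in accepted_exceptions propagate to the caller without counting as a model failure, while anything else may trip the breaker. A sketch of that contract — illustrative only; the real implementation lives in app/services/circuit_breaker.py and is not part of this commit:

class CircuitBreaker:
    @staticmethod
    def protected_call(func, cache_key, accepted_exceptions=()):
        try:
            return func()
        except accepted_exceptions:
            # Template/parameter errors are the caller's fault and
            # must not open the circuit for this model.
            raise
        except Exception:
            # Presumably counts as a failure, e.g. by eventually setting
            # cache_key + ":status" to "OPEN" (the key StrategyLLM checks).
            raise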
15 changes: 12 additions & 3 deletions app/services/guidance_wrapper.py
@@ -1,8 +1,11 @@
from typing import cast

import guidance

from app.config import LLMModelConfig
from app.config import LLMModelConfig, OpenAIConfig, StrategyLLMConfig
from app.models.dtos import Content, ContentType
from app.services.guidance_functions import truncate
from app.llms.strategy_llm import StrategyLLM


class GuidanceWrapper:
@@ -69,5 +72,11 @@ def is_up(self) -> bool:
return content == "1"

def _get_llm(self):
llm_credentials = self.model.llm_credentials
return guidance.llms.OpenAI(**llm_credentials)
if isinstance(self.model, OpenAIConfig):
return guidance.llms.OpenAI(
**cast(OpenAIConfig, self.model).llm_credentials
)
elif isinstance(self.model, StrategyLLMConfig):
return StrategyLLM(cast(StrategyLLMConfig, self.model).llms)
else:
raise ValueError("Invalid model type: " + str(type(self.model)))
3 changes: 3 additions & 0 deletions application.example.yml
@@ -11,8 +11,11 @@ pyris:

llms:
DUMMY:
type: "openai"
name: "Dummy model"
description: "Dummy model for testing"
spec:
context_length: 100
llm_credentials:
api_base: ""
api_type: ""
3 changes: 3 additions & 0 deletions application.test.yml
@@ -11,8 +11,11 @@ pyris:

llms:
DUMMY:
type: "openai"
name: "Dummy model"
description: "Dummy model for testing"
spec:
context_length: 100
llm_credentials:
api_base: ""
api_type: ""