LLMs: Add strategy LLM #16

Merged
16 commits merged on Oct 6, 2023
2 changes: 1 addition & 1 deletion Dockerfile
@@ -7,4 +7,4 @@ RUN poetry install
COPY app /app/app

EXPOSE 8000
CMD ["poetry", "run", "gunicorn", "app.main:app", "-w", "8", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
CMD ["poetry", "run", "gunicorn", "app.main:app", "--workers", "8", "--threads", "8", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
54 changes: 52 additions & 2 deletions app/config.py
@@ -1,13 +1,59 @@
import os

from guidance.llms import OpenAI
from pyaml_env import parse_config
from pydantic import BaseModel
from pydantic import BaseModel, validator, Field, typing


class LLMModelSpecs(BaseModel):
    context_length: int


class LLMModelConfig(BaseModel):
    type: str
    name: str
    description: str

    def get_instance(cls):
        raise NotImplementedError()


class OpenAIConfig(LLMModelConfig):
    spec: LLMModelSpecs
    llm_credentials: dict
    instance: typing.Any = Field(repr=False)

    @validator("type")
    def check_type(cls, v):
        if v != "openai":
            raise ValueError("Invalid type:" + v + " != openai")
        return v

    def get_instance(cls):
        if cls.instance is not None:
            return cls.instance
        cls.instance = OpenAI(**cls.llm_credentials)
        return cls.instance


class StrategyLLMConfig(LLMModelConfig):
    llms: list[str]
    instance: typing.Any = Field(repr=False)

    @validator("type")
    def check_type(cls, v):
        if v != "strategy":
            raise ValueError("Invalid type:" + v + " != strategy")
        return v

    def get_instance(cls):
        if cls.instance is not None:
            return cls.instance
        # Local import needed to avoid circular dependency
        from app.llms.strategy_llm import StrategyLLM

        cls.instance = StrategyLLM(cls.llms)
        return cls.instance


class APIKeyConfig(BaseModel):
@@ -27,7 +73,7 @@ class CacheParams(BaseModel):
class Settings(BaseModel):
    class PyrisSettings(BaseModel):
        api_keys: list[APIKeyConfig]
        llms: dict[str, LLMModelConfig]
        llms: dict[str, OpenAIConfig | StrategyLLMConfig]
        cache: CacheSettings

    pyris: PyrisSettings
@@ -44,3 +90,7 @@ def get_settings(cls):


settings = Settings.get_settings()

# Initialize instances eagerly so the first request does not pay the construction cost
for llm_config in settings.pyris.llms.values():
    llm_config.get_instance()
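
Each LLMModelConfig subclass builds its backing client lazily and caches it on the model, so repeated get_instance() calls return the same object. A minimal sketch of that behavior, assuming the DUMMY entry from application.example.yml is present in the loaded application.yml:

from app.config import settings

config = settings.pyris.llms["DUMMY"]  # an OpenAIConfig parsed from application.yml
first = config.get_instance()          # builds the guidance OpenAI client on first use
second = config.get_instance()         # subsequent calls return the cached instance
assert first is second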
Empty file added app/llms/__init__.py
Empty file.
172 changes: 172 additions & 0 deletions app/llms/strategy_llm.py
@@ -0,0 +1,172 @@
import logging
from typing import Dict, Any

from guidance.llms import LLM, LLMSession

from app.config import settings, OpenAIConfig
from app.services.cache import cache_store
from app.services.circuit_breaker import CircuitBreaker

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)


class StrategyLLM(LLM):
    llm_keys: list[str]
    llm_configs: dict[str, OpenAIConfig]
    llm_sessions: dict[str, LLMSession]

    def __init__(self, llm_keys: list[str]):
        super().__init__()
        if len(llm_keys) == 0:
            raise ValueError("No LLMs configured")
        self.llm_keys = [
            llm_key
            for llm_key in llm_keys
            if llm_key in settings.pyris.llms
            and isinstance(settings.pyris.llms[llm_key], OpenAIConfig)
        ]
        self.llm_configs = {
            llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
        }
        self.llm_instances = {
            llm_key: llm_config.get_instance()
            for llm_key, llm_config in self.llm_configs.items()
        }
        self.llm_sessions = {}

    def first_llm(self):
        return self.llm_instances[self.llm_keys[0]]

    def extract_function_call(self, text):
        return self.first_llm().extract_function_call(text)

    def __call__(self, *args, asynchronous=False, **kwargs):
        return self.first_llm()(*args, asynchronous=asynchronous, **kwargs)

    def __getitem__(self, key):
        return self.first_llm()[key]

    def session(self, asynchronous=False):
        return StrategyLLMSession(self)

    def encode(self, string, **kwargs):
        return self.first_llm().encode(string)

    def decode(self, tokens, **kwargs):
        return self.first_llm().decode(tokens)

    def id_to_token(self, id):
        return self.first_llm().id_to_token(id)

    def token_to_id(self, token):
        return self.first_llm().token_to_id(token)

    def role_start(self, role_name, **kwargs):
        return self.first_llm().role_start(role_name, **kwargs)

    def role_end(self, role=None):
        return self.first_llm().role_end(role)

    def end_of_text(self):
        return self.first_llm().end_of_text()


class StrategyLLMSession(LLMSession):
    llm: StrategyLLM
    current_session_key: str
    current_session: LLMSession

    def __init__(self, llm: StrategyLLM):
        super().__init__(llm)
        self.llm = llm

    def get_total_context_length(
        self, prompt: str, llm_key: str, max_tokens: int
    ):
        prompt_token_length = len(
            self.llm.llm_instances[llm_key].encode(prompt)
        )
        return prompt_token_length + max_tokens

    def set_correct_session(
        self, prompt: str, max_tokens: int, exclude_llms=None
    ):
        if exclude_llms is None:
            exclude_llms = []

        selected_llm = None

        for llm_key in self.llm.llm_keys:
            if (
                llm_key not in exclude_llms
                and cache_store.get(llm_key + ":status") != "OPEN"
                and self.llm.llm_configs[llm_key].spec.context_length
                >= self.get_total_context_length(prompt, llm_key, max_tokens)
            ):
                selected_llm = llm_key
                break

        if selected_llm is None:
            log.error(
                "No viable LLMs found! Using LLM with longest context length."
            )
            llm_configs = {
                llm_key: config
                for llm_key, config in self.llm.llm_configs.items()
                if llm_key not in exclude_llms
            }
            if len(llm_configs) == 0:
                raise ValueError("All LLMs are excluded!")
            selected_llm = max(
                llm_configs,
                key=lambda llm_key: llm_configs[llm_key].spec.context_length,
            )

        log.info("Selected LLM: " + selected_llm)

        self.current_session_key = selected_llm
        if selected_llm in self.llm.llm_sessions:
            self.current_session = self.llm.llm_sessions[selected_llm]
        else:
            self.current_session = self.llm.llm_instances[
                selected_llm
            ].session(asynchronous=True)
            self.llm.llm_sessions[selected_llm] = self.current_session

    async def __call__(self, *args, **kwargs):
        prompt = args[0]
        max_tokens = kwargs["max_tokens"]
        log.info(
            f"Trying to make request using strategy "
            f"llm with llms [{self.llm.llm_keys}]"
        )
        self.set_correct_session(prompt, max_tokens)

        def call():
            return self.current_session(*args, **kwargs)

        exclude_llms = []

        while True:
            try:
                response = await CircuitBreaker.protected_call_async(
                    func=call, cache_key=self.current_session_key
                )
                return response
            except Exception as e:
                log.error(
                    f"Exception '{e}' while making request! "
                    f"Trying again with a different LLM."
                )
                exclude_llms.append(self.current_session_key)
                self.set_correct_session(prompt, max_tokens, exclude_llms)

    def __exit__(self, exc_type, exc_value, traceback):
        return self.current_session.__exit__(exc_type, exc_value, traceback)

    def _gen_key(self, args_dict):
        return self.current_session._gen_key(args_dict)

    def _cache_params(self, args_dict) -> Dict[str, Any]:
        return self.current_session._cache_params(args_dict)
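
Per request, StrategyLLMSession selects the first configured LLM whose circuit breaker is not OPEN and whose context window fits the encoded prompt plus max_tokens; if the call still fails, that LLM is excluded and the next candidate is tried. A hedged usage sketch; the STRATEGY key and its underlying LLM entries are illustrative and not defined in this change:

import asyncio

from app.config import settings


async def main():
    # Assumes an application.yml entry of type "strategy" under the hypothetical
    # key STRATEGY whose llms list names two openai-type entries.
    strategy_llm = settings.pyris.llms["STRATEGY"].get_instance()
    session = strategy_llm.session()
    # Picks the first healthy LLM whose context window fits the prompt plus
    # max_tokens, and retries with the next candidate if the request raises.
    response = await session("Explain the circuit breaker pattern.", max_tokens=256)
    print(response)


asyncio.run(main())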
10 changes: 8 additions & 2 deletions app/routes/messages.py
@@ -1,6 +1,7 @@
from datetime import datetime, timezone

from fastapi import APIRouter, Depends
from guidance._program_executor import SyntaxException
from parsimonious.exceptions import IncompleteParseError

from app.core.custom_exceptions import (
@@ -37,11 +38,16 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
        content = CircuitBreaker.protected_call(
            func=guidance.query,
            cache_key=body.preferred_model,
            accepted_exceptions=(KeyError, SyntaxError, IncompleteParseError),
            accepted_exceptions=(
                KeyError,
                SyntaxError,
                SyntaxException,
                IncompleteParseError,
            ),
        )
    except KeyError as e:
        raise MissingParameterException(str(e))
    except (SyntaxError, IncompleteParseError) as e:
    except (SyntaxError, SyntaxException, IncompleteParseError) as e:
        raise InvalidTemplateException(str(e))
    except Exception as e:
        raise InternalServerException(str(e))
51 changes: 45 additions & 6 deletions app/services/circuit_breaker.py
@@ -48,14 +48,53 @@ def protected_call(
        except accepted_exceptions as e:
            raise e
        except Exception as e:
            num_failures = cache_store.incr(num_failures_key)
            cache_store.expire(num_failures_key, cls.OPEN_TTL_SECONDS)
            if num_failures >= cls.MAX_FAILURES:
                cache_store.set(
                    status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS
                )
            cls.handle_exception(e, num_failures_key, status_key)

    @classmethod
    async def protected_call_async(
        cls, func, cache_key: str, accepted_exceptions: tuple = ()
    ):
        """Wrap function call to avoid too many failures in a row.
        Async version.

        Params:
            func: function to be called
            cache_key: key to be used in the cache store
            accepted_exceptions: exceptions that are not considered failures

        Raises:
            ValueError: if within the last OPEN_TTL_SECONDS seconds,
            the function throws an exception for the MAX_FAILURES-th time.
        """

        num_failures_key = f"{cache_key}:num_failures"
        status_key = f"{cache_key}:status"

        status = cache_store.get(status_key)
        if status == cls.Status.OPEN:
            raise ValueError("Too many failures! Please try again later.")

        try:
            response = await func()
            cache_store.set(
                status_key, cls.Status.CLOSED, ex=cls.CLOSED_TTL_SECONDS
            )
            return response
        except accepted_exceptions as e:
            log.error("Accepted error in protected_call for " + cache_key)
            raise e
        except Exception as e:
            cls.handle_exception(e, num_failures_key, status_key)

    @classmethod
    def handle_exception(cls, e, num_failures_key, status_key):
        num_failures = cache_store.incr(num_failures_key)
        cache_store.expire(num_failures_key, cls.OPEN_TTL_SECONDS)
        if num_failures >= cls.MAX_FAILURES:
            cache_store.set(
                status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS
            )
        raise e

    @classmethod
    def get_status(cls, checkhealth_func, cache_key: str):
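
protected_call_async mirrors the synchronous wrapper: accepted exceptions are re-raised without counting as failures, while any other exception increments the failure counter and opens the circuit for OPEN_TTL_SECONDS once MAX_FAILURES is reached. A minimal usage sketch, assuming a reachable cache store; the cache key and coroutine are illustrative:

import asyncio

from app.services.circuit_breaker import CircuitBreaker


async def query():
    # Stand-in for an actual asynchronous LLM call
    return "ok"


async def main():
    response = await CircuitBreaker.protected_call_async(
        func=query,
        cache_key="SOME_LLM",             # hypothetical cache key
        accepted_exceptions=(KeyError,),  # re-raised without counting as a failure
    )
    print(response)


asyncio.run(main())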
3 changes: 1 addition & 2 deletions app/services/guidance_wrapper.py
@@ -69,5 +69,4 @@ def is_up(self) -> bool:
        return content == "1"

    def _get_llm(self):
        llm_credentials = self.model.llm_credentials
        return guidance.llms.OpenAI(**llm_credentials)
        return self.model.get_instance()
3 changes: 3 additions & 0 deletions application.example.yml
@@ -11,8 +11,11 @@ pyris:

  llms:
    DUMMY:
      type: "openai"
      name: "Dummy model"
      description: "Dummy model for testing"
      spec:
        context_length: 100
      llm_credentials:
        api_base: ""
        api_type: ""
5 changes: 4 additions & 1 deletion application.test.yml
@@ -11,12 +11,15 @@ pyris:

  llms:
    DUMMY:
      type: "openai"
      name: "Dummy model"
      description: "Dummy model for testing"
      spec:
        context_length: 100
      llm_credentials:
        api_base: ""
        api_type: ""
        api_version: ""
        deployment_id: ""
        model: ""
        model: "gpt-3.5-turbo"
        token: ""