Commit

Merge branch 'main' into dependabot/github_actions/actions/setup-python-4
Hialus authored Oct 6, 2023
2 parents 62b0edf + ca33d8e commit 602cad0
Showing 14 changed files with 815 additions and 509 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -7,4 +7,4 @@ RUN poetry install
COPY app /app/app

EXPOSE 8000
CMD ["poetry", "run", "gunicorn", "app.main:app", "-w", "8", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
CMD ["poetry", "run", "gunicorn", "app.main:app", "--workers", "8", "--threads", "8", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
54 changes: 52 additions & 2 deletions app/config.py
@@ -1,13 +1,59 @@
import os

from guidance.llms import OpenAI
from pyaml_env import parse_config
from pydantic import BaseModel
from pydantic import BaseModel, validator, Field, typing


class LLMModelSpecs(BaseModel):
    context_length: int


class LLMModelConfig(BaseModel):
    type: str
    name: str
    description: str

    def get_instance(cls):
        raise NotImplementedError()


class OpenAIConfig(LLMModelConfig):
    spec: LLMModelSpecs
    llm_credentials: dict
    instance: typing.Any = Field(repr=False)

    @validator("type")
    def check_type(cls, v):
        if v != "openai":
            raise ValueError("Invalid type:" + v + " != openai")
        return v

    def get_instance(cls):
        if cls.instance is not None:
            return cls.instance
        cls.instance = OpenAI(**cls.llm_credentials)
        return cls.instance


class StrategyLLMConfig(LLMModelConfig):
    llms: list[str]
    instance: typing.Any = Field(repr=False)

    @validator("type")
    def check_type(cls, v):
        if v != "strategy":
            raise ValueError("Invalid type:" + v + " != strategy")
        return v

    def get_instance(cls):
        if cls.instance is not None:
            return cls.instance
        # Local import needed to avoid circular dependency
        from app.llms.strategy_llm import StrategyLLM

        cls.instance = StrategyLLM(cls.llms)
        return cls.instance


class APIKeyConfig(BaseModel):
@@ -27,7 +73,7 @@ class CacheParams(BaseModel):
class Settings(BaseModel):
    class PyrisSettings(BaseModel):
        api_keys: list[APIKeyConfig]
        llms: dict[str, LLMModelConfig]
        llms: dict[str, OpenAIConfig | StrategyLLMConfig]
        cache: CacheSettings

    pyris: PyrisSettings
@@ -44,3 +90,7 @@ def get_settings(cls):


settings = Settings.get_settings()

# Init instance, so it is faster during requests
for value in enumerate(settings.pyris.llms.values()):
    value[1].get_instance()
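For reference, a parsed "strategy" entry would combine these fields roughly as follows — a minimal sketch, assuming two OpenAI-backed keys named GPT35_TURBO and GPT35_TURBO_16K are configured; neither key name appears in this commit:

# Hypothetical sketch of a StrategyLLMConfig instance, based on the fields above.
from app.config import StrategyLLMConfig

strategy_config = StrategyLLMConfig(
    type="strategy",  # the check_type validator accepts exactly "strategy"
    name="GPT-3.5 fallback strategy",
    description="Tries GPT35_TURBO first, then falls back to GPT35_TURBO_16K",
    llms=["GPT35_TURBO", "GPT35_TURBO_16K"],
)

# get_instance() lazily builds a StrategyLLM over those keys and caches it,
# which is what the eager-init loop above triggers once at startup.
llm = strategy_config.get_instance()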
Empty file added app/llms/__init__.py
Empty file.
172 changes: 172 additions & 0 deletions app/llms/strategy_llm.py
@@ -0,0 +1,172 @@
import logging
from typing import Dict, Any

from guidance.llms import LLM, LLMSession

from app.config import settings, OpenAIConfig
from app.services.cache import cache_store
from app.services.circuit_breaker import CircuitBreaker

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)


class StrategyLLM(LLM):
    llm_keys: list[str]
    llm_configs: dict[str, OpenAIConfig]
    llm_sessions: dict[str, LLMSession]

    def __init__(self, llm_keys: list[str]):
        super().__init__()
        if llm_keys.__len__() == 0:
            raise ValueError("No LLMs configured")
        self.llm_keys = [
            llm_keys
            for llm_keys in llm_keys
            if llm_keys in settings.pyris.llms
            and isinstance(settings.pyris.llms[llm_keys], OpenAIConfig)
        ]
        self.llm_configs = {
            llm_key: settings.pyris.llms[llm_key] for llm_key in self.llm_keys
        }
        self.llm_instances = {
            llm_key: llm_config.get_instance()
            for llm_key, llm_config in self.llm_configs.items()
        }
        self.llm_sessions = {}

    def first_llm(self):
        return self.llm_instances[self.llm_keys[0]]

    def extract_function_call(self, text):
        return self.first_llm().extract_function_call(text)

    def __call__(self, *args, asynchronous=False, **kwargs):
        return self.first_llm()(*args, asynchronous=asynchronous, **kwargs)

    def __getitem__(self, key):
        return self.first_llm()[key]

    def session(self, asynchronous=False):
        return StrategyLLMSession(self)

    def encode(self, string, **kwargs):
        return self.first_llm().encode(string)

    def decode(self, tokens, **kwargs):
        return self.first_llm().decode(tokens)

    def id_to_token(self, id):
        return self.first_llm().id_to_token(id)

    def token_to_id(self, token):
        return self.first_llm().token_to_id(token)

    def role_start(self, role_name, **kwargs):
        return self.first_llm().role_start(role_name, **kwargs)

    def role_end(self, role=None):
        return self.first_llm().role_end(role)

    def end_of_text(self):
        return self.first_llm().end_of_text()


class StrategyLLMSession(LLMSession):
    llm: StrategyLLM
    current_session_key: str
    current_session: LLMSession

    def __init__(self, llm: StrategyLLM):
        super().__init__(llm)
        self.llm = llm

    def get_total_context_length(
        self, prompt: str, llm_key: str, max_tokens: int
    ):
        prompt_token_length = (
            self.llm.llm_instances[llm_key].encode(prompt).__len__()
        )
        return prompt_token_length + max_tokens

    def set_correct_session(
        self, prompt: str, max_tokens: int, exclude_llms=None
    ):
        if exclude_llms is None:
            exclude_llms = []

        selected_llm = None

        for llm_key in self.llm.llm_keys:
            if (
                llm_key not in exclude_llms
                and cache_store.get(llm_key + ":status") != "OPEN"
                and self.llm.llm_configs[llm_key].spec.context_length
                >= self.get_total_context_length(prompt, llm_key, max_tokens)
            ):
                selected_llm = llm_key
                break

        if selected_llm is None:
            log.error(
                "No viable LLMs found! Using LLM with longest context length."
            )
            llm_configs = {
                llm_key: config
                for llm_key, config in self.llm.llm_configs.items()
                if llm_key not in exclude_llms
            }
            if llm_configs.__len__() == 0:
                raise ValueError("All LLMs are excluded!")
            selected_llm = max(
                llm_configs,
                key=lambda llm_key: llm_configs[llm_key].spec.context_length,
            )

        log.info("Selected LLM: " + selected_llm)

        self.current_session_key = selected_llm
        if selected_llm in self.llm.llm_sessions:
            self.current_session = self.llm.llm_sessions[selected_llm]
        else:
            self.current_session = self.llm.llm_instances[
                selected_llm
            ].session(asynchronous=True)
            self.llm.llm_sessions[selected_llm] = self.current_session

    async def __call__(self, *args, **kwargs):
        prompt = args[0]
        max_tokens = kwargs["max_tokens"]
        log.info(
            f"Trying to make request using strategy "
            f"llm with llms [{self.llm.llm_keys}]"
        )
        self.set_correct_session(prompt, max_tokens)

        def call():
            return self.current_session(*args, **kwargs)

        exclude_llms = []

        while True:
            try:
                response = await CircuitBreaker.protected_call_async(
                    func=call, cache_key=self.current_session_key
                )
                return response
            except Exception as e:
                log.error(
                    f"Exception '{e}' while making request! "
                    f"Trying again with a different LLM."
                )
                exclude_llms.append(self.current_session_key)
                self.set_correct_session(prompt, max_tokens, exclude_llms)

    def __exit__(self, exc_type, exc_value, traceback):
        return self.current_session.__exit__(exc_type, exc_value, traceback)

    def _gen_key(self, args_dict):
        return self.current_session._gen_key(args_dict)

    def _cache_params(self, args_dict) -> Dict[str, Any]:
        return self.current_session._cache_params(args_dict)
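Usage-wise, a StrategyLLM behaves like any other guidance LLM. A minimal, hypothetical sketch follows; the key name "MY_STRATEGY" and the template are assumptions, not part of this commit:

# Assumes "MY_STRATEGY" points at a StrategyLLMConfig entry in application.yml.
import guidance

from app.config import settings

strategy_llm = settings.pyris.llms["MY_STRATEGY"].get_instance()

program = guidance(
    "{{#user~}}{{query}}{{~/user}}"
    "{{#assistant~}}{{gen 'response' max_tokens=200}}{{~/assistant}}",
    llm=strategy_llm,
)
result = program(query="Summarize the circuit breaker pattern in one sentence.")
print(result["response"])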
10 changes: 8 additions & 2 deletions app/routes/messages.py
@@ -1,6 +1,7 @@
from datetime import datetime, timezone

from fastapi import APIRouter, Depends
from guidance._program_executor import SyntaxException
from parsimonious.exceptions import IncompleteParseError

from app.core.custom_exceptions import (
@@ -37,11 +38,16 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
        content = CircuitBreaker.protected_call(
            func=guidance.query,
            cache_key=body.preferred_model,
            accepted_exceptions=(KeyError, SyntaxError, IncompleteParseError),
            accepted_exceptions=(
                KeyError,
                SyntaxError,
                SyntaxException,
                IncompleteParseError,
            ),
        )
    except KeyError as e:
        raise MissingParameterException(str(e))
    except (SyntaxError, IncompleteParseError) as e:
    except (SyntaxError, SyntaxException, IncompleteParseError) as e:
        raise InvalidTemplateException(str(e))
    except Exception as e:
        raise InternalServerException(str(e))
51 changes: 45 additions & 6 deletions app/services/circuit_breaker.py
@@ -48,14 +48,53 @@ def protected_call(
        except accepted_exceptions as e:
            raise e
        except Exception as e:
            num_failures = cache_store.incr(num_failures_key)
            cache_store.expire(num_failures_key, cls.OPEN_TTL_SECONDS)
            if num_failures >= cls.MAX_FAILURES:
                cache_store.set(
                    status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS
                )
            cls.handle_exception(e, num_failures_key, status_key)

    @classmethod
    async def protected_call_async(
        cls, func, cache_key: str, accepted_exceptions: tuple = ()
    ):
        """Wrap function call to avoid too many failures in a row.
        Async version.
        Params:
            func: function to be called
            cache_key: key to be used in the cache store
            accepted_exceptions: exceptions that are not considered failures
        Raises:
            ValueError: if within the last OPEN_TTL_SECONDS seconds,
            the function throws an exception for the MAX_FAILURES-th time.
        """

        num_failures_key = f"{cache_key}:num_failures"
        status_key = f"{cache_key}:status"

        status = cache_store.get(status_key)
        if status == cls.Status.OPEN:
            raise ValueError("Too many failures! Please try again later.")

        try:
            response = await func()
            cache_store.set(
                status_key, cls.Status.CLOSED, ex=cls.CLOSED_TTL_SECONDS
            )
            return response
        except accepted_exceptions as e:
            log.error("Accepted error in protected_call for " + cache_key)
            raise e
        except Exception as e:
            cls.handle_exception(e, num_failures_key, status_key)

    @classmethod
    def handle_exception(cls, e, num_failures_key, status_key):
        num_failures = cache_store.incr(num_failures_key)
        cache_store.expire(num_failures_key, cls.OPEN_TTL_SECONDS)
        if num_failures >= cls.MAX_FAILURES:
            cache_store.set(
                status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS
            )
        raise e

    @classmethod
    def get_status(cls, checkhealth_func, cache_key: str):
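For context, the async variant is meant to wrap an awaitable call just like the synchronous one — a short sketch mirroring the call site in StrategyLLMSession above; the cache key and wrapped session call are illustrative assumptions:

# Illustrative only; "GPT35_TURBO" is an assumed cache key, not from this commit.
from app.services.circuit_breaker import CircuitBreaker

async def guarded_generate(session, prompt: str, **kwargs):
    def call():
        # protected_call_async awaits whatever this returns
        return session(prompt, **kwargs)

    # Raises ValueError immediately while the breaker for this key is OPEN;
    # after MAX_FAILURES unexpected errors it opens for OPEN_TTL_SECONDS.
    return await CircuitBreaker.protected_call_async(
        func=call,
        cache_key="GPT35_TURBO",
        accepted_exceptions=(KeyError,),
    )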
3 changes: 1 addition & 2 deletions app/services/guidance_wrapper.py
@@ -69,5 +69,4 @@ def is_up(self) -> bool:
        return content == "1"

    def _get_llm(self):
        llm_credentials = self.model.llm_credentials
        return guidance.llms.OpenAI(**llm_credentials)
        return self.model.get_instance()
3 changes: 3 additions & 0 deletions application.example.yml
@@ -11,8 +11,11 @@ pyris:

  llms:
    DUMMY:
      type: "openai"
      name: "Dummy model"
      description: "Dummy model for testing"
      spec:
        context_length: 100
      llm_credentials:
        api_base: ""
        api_type: ""
5 changes: 4 additions & 1 deletion application.test.yml
@@ -11,12 +11,15 @@ pyris:

  llms:
    DUMMY:
      type: "openai"
      name: "Dummy model"
      description: "Dummy model for testing"
      spec:
        context_length: 100
      llm_credentials:
        api_base: ""
        api_type: ""
        api_version: ""
        deployment_id: ""
        model: ""
        model: "gpt-3.5-turbo"
        token: ""