Commit

feat: add caching if prompt request fails (#148)
Matthieu-OD authored Nov 29, 2024
1 parent 5f4c92a commit 87274ae
Showing 11 changed files with 261 additions and 52 deletions.
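
In short: successful prompt fetches are now written to a process-wide SharedCache, and `get_prompt` falls back to that cache (with a warning) when the GraphQL request fails. A minimal usage sketch of the intended behaviour; the API key, URL, and prompt name below are illustrative placeholders, not values from this commit:

from literalai.api import LiteralAPI

# Placeholder credentials; the constructor takes (api_key, url) as used elsewhere in this diff.
api = LiteralAPI("my-api-key", "https://my-literal-instance.example")

# A successful fetch stores the prompt in the shared cache under its id,
# its name, and "name-version".
prompt = api.get_prompt(name="my-prompt")

# If a later fetch fails (timeout, outage), get_prompt logs a warning and
# returns the cached copy instead of raising.
prompt = api.get_prompt(name="my-prompt")
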
51 changes: 31 additions & 20 deletions literalai/api/asynchronous.py
@@ -3,7 +3,6 @@

from typing_extensions import deprecated
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -106,9 +105,6 @@
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import PaginatedResponse, User
@@ -145,7 +141,7 @@ class AsyncLiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

async def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
) -> Dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
@@ -158,7 +154,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
)

try:
@@ -179,13 +175,12 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for _, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
value.get('message')}"""
)

return json

async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
@@ -211,15 +206,15 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
)

async def gql_helper(
self,
query: str,
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = 10,
) -> R:
response = await self.make_gql_call(description, query, variables)
response = await self.make_gql_call(description, query, variables, timeout)
return process_response(response)

##################################################################################
@@ -447,7 +442,7 @@ async def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -838,16 +833,32 @@ async def get_prompt(
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = None,
) -> "Prompt":
) -> Prompt:
if not (id or name):
raise ValueError("At least the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
if id:
return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
elif name:
return await self.gql_helper(
*get_prompt_helper(sync_api, name=name, version=version)
)
else:
raise ValueError("Either the `id` or the `name` must be provided.")
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
api=sync_api, id=id, name=name, version=version, cache=self.cache
)

try:
if id:
prompt = await self.gql_helper(
get_prompt_query, description, variables, process_response, timeout
)
elif name:
prompt = await self.gql_helper(
get_prompt_query, description, variables, process_response, timeout
)

return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt
raise e

async def update_prompt_ab_testing(
self, name: str, rollouts: List["PromptRollout"]
7 changes: 5 additions & 2 deletions literalai/api/base.py
@@ -13,6 +13,7 @@

from literalai.my_types import Environment

from literalai.cache.shared_cache import SharedCache
from literalai.evaluation.dataset import DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperimentItem,
@@ -95,6 +96,8 @@ def __init__(
self.graphql_endpoint = self.url + "/api/graphql"
self.rest_endpoint = self.url + "/api"

self.cache = SharedCache()

@property
def headers(self):
from literalai.version import __version__
@@ -1011,9 +1014,9 @@ def get_prompt(
"""
Gets a prompt either by:
- `id`
- or `name` and (optional) `version`
- `name` and (optional) `version`
Either the `id` or the `name` must be provided.
At least the `id` or the `name` must be passed to the function.
If both are provided, the `id` is used.
Args:
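
base.py now instantiates a SharedCache that both the synchronous and asynchronous API classes rely on. The literalai/cache/shared_cache.py file itself is not part of this excerpt; judging only by the calls made in this diff (a no-argument constructor plus `get` and `put`), a minimal stand-in could look like the sketch below. It is for orientation only, not the library's actual implementation:

from typing import Any, Dict, Optional


class IllustrativeSharedCache:
    """In-memory key/value store mirroring the SharedCache calls used in this diff."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def get(self, key: str) -> Optional[Any]:
        # Returns the cached value, or None when the key has never been stored.
        return self._store.get(key)

    def put(self, key: str, value: Any) -> None:
        self._store[key] = value
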
51 changes: 39 additions & 12 deletions literalai/api/helpers/prompt_helpers.py
@@ -1,10 +1,13 @@
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
from typing import TYPE_CHECKING, Optional, TypedDict, Callable

from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

from literalai.cache.prompt_helpers import put_prompt

if TYPE_CHECKING:
from literalai.api import LiteralAPI
from literalai.cache.shared_cache import SharedCache

from literalai.api.helpers import gql

@@ -36,9 +39,9 @@ def process_response(response):
def create_prompt_helper(
api: "LiteralAPI",
lineage_id: str,
template_messages: List[GenerationMessage],
template_messages: list[GenerationMessage],
settings: Optional[ProviderSettings] = None,
tools: Optional[List[Dict]] = None,
tools: Optional[list[dict]] = None,
):
variables = {
"lineageId": lineage_id,
@@ -56,28 +59,52 @@ def process_response(response):
return gql.CREATE_PROMPT_VERSION, description, variables, process_response


def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
if id:
return id
elif name and version:
return f"{name}-{version}"
elif name:
return name
else:
raise ValueError("Either the `id` or the `name` must be provided.")


def get_prompt_helper(
api: "LiteralAPI",
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = 0,
):
cache: Optional["SharedCache"] = None,
) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
"""Helper function for getting prompts with caching logic"""

cached_prompt = None
timeout = 10

if cache:
cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
timeout = 1 if cached_prompt else timeout

variables = {"id": id, "name": name, "version": version}

def process_response(response):
prompt = response["data"]["promptVersion"]
return Prompt.from_dict(api, prompt) if prompt else None
prompt_version = response["data"]["promptVersion"]
prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
if cache and prompt:
put_prompt(cache, prompt)
return prompt

description = "get prompt"

return gql.GET_PROMPT_VERSION, description, variables, process_response
return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout, cached_prompt


def create_prompt_variant_helper(
from_lineage_id: Optional[str] = None,
template_messages: List[GenerationMessage] = [],
template_messages: list[GenerationMessage] = [],
settings: Optional[ProviderSettings] = None,
tools: Optional[List[Dict]] = None,
tools: Optional[list[dict]] = None,
):
variables = {
"fromLineageId": from_lineage_id,
Expand Down Expand Up @@ -105,7 +132,7 @@ def get_prompt_ab_testing_helper(
):
variables = {"lineageName": name}

def process_response(response) -> List[PromptRollout]:
def process_response(response) -> list[PromptRollout]:
response_data = response["data"]["promptLineageRollout"]
return list(map(lambda x: x["node"], response_data["edges"]))

Expand All @@ -114,10 +141,10 @@ def process_response(response) -> List[PromptRollout]:
return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
def update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]):
variables = {"name": name, "rollouts": rollouts}

def process_response(response) -> Dict:
def process_response(response) -> dict:
return response["data"]["updatePromptLineageRollout"]

description = "update prompt A/B testing"
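
Two behaviours introduced above are worth spelling out: the cache key mirrors the lookup style (`id`, `name`, or `name-version`), and get_prompt_helper drops the GraphQL timeout from 10 s to 1 s once a cached copy exists, so a degraded API falls back quickly instead of blocking. A small, runnable illustration of the keying; identifiers are placeholders:

from literalai.api.helpers.prompt_helpers import get_prompt_cache_key

assert get_prompt_cache_key(id="pv-123", name=None, version=None) == "pv-123"
assert get_prompt_cache_key(id=None, name="my-prompt", version=2) == "my-prompt-2"
assert get_prompt_cache_key(id=None, name="my-prompt", version=None) == "my-prompt"
# Passing neither id nor name raises a ValueError.
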
45 changes: 28 additions & 17 deletions literalai/api/synchronous.py
@@ -3,7 +3,6 @@

from typing_extensions import deprecated
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -105,9 +104,6 @@
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import PaginatedResponse, User
@@ -144,8 +140,8 @@ class LiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
) -> Dict:
self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10
) -> dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
raise Exception(error)
@@ -156,7 +152,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
)

try:
@@ -177,7 +173,7 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for _, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
@@ -186,7 +182,6 @@ def raise_error(error):

return json


def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
with httpx.Client(follow_redirects=True) as client:
response = client.post(
@@ -217,8 +212,9 @@ def gql_helper(
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = None,
) -> R:
response = self.make_gql_call(description, query, variables)
response = self.make_gql_call(description, query, variables, timeout)
return process_response(response)

##################################################################################
@@ -441,7 +437,7 @@ def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -805,12 +801,27 @@ def get_prompt(
name: Optional[str] = None,
version: Optional[int] = None,
) -> "Prompt":
if id:
return self.gql_helper(*get_prompt_helper(self, id=id))
elif name:
return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
else:
raise ValueError("Either the `id` or the `name` must be provided.")
if not (id or name):
raise ValueError("At least the `id` or the `name` must be provided.")

get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
api=self,id=id, name=name, version=version, cache=self.cache
)

try:
if id:
prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
elif name:
prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)

return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt

raise e

def create_prompt_variant(
self,
Empty file added literalai/cache/__init__.py
8 changes: 8 additions & 0 deletions literalai/cache/prompt_helpers.py
@@ -0,0 +1,8 @@
from literalai.prompt_engineering.prompt import Prompt
from literalai.cache.shared_cache import SharedCache


def put_prompt(cache: SharedCache, prompt: Prompt):
cache.put(prompt.id, prompt)
cache.put(prompt.name, prompt)
cache.put(f"{prompt.name}-{prompt.version}", prompt)