feat: add caching if prompt request fails #148

Merged Nov 29, 2024 (22 commits)

Commits
4ddda3e
feat: create the dict cache and the method to go with it
Matthieu-OD Nov 14, 2024
fd9f462
feat: get_prompt add caching
Matthieu-OD Nov 14, 2024
8774a00
feat: implement caching on get_prompt
Matthieu-OD Nov 14, 2024
5d8b5f7
fix: ci
Matthieu-OD Nov 14, 2024
e7589c6
feat: add timeout if prompt cached
Matthieu-OD Nov 14, 2024
723f7fd
feat: improve caching
Matthieu-OD Nov 14, 2024
32b971f
feat: improve logging
Matthieu-OD Nov 14, 2024
5bcdce4
fix: ci errors
Matthieu-OD Nov 14, 2024
2476ebc
feat: improve the prompt cache class
Matthieu-OD Nov 15, 2024
0aec701
refactor: remove useless code
Matthieu-OD Nov 15, 2024
32b4e48
feat: implement the new SharedCachePrompt class
Matthieu-OD Nov 15, 2024
f5d460b
refactor: improve typing and move some logic
Matthieu-OD Nov 15, 2024
49fd140
feat: adds memory management to the SharedCachePrompt class
Matthieu-OD Nov 18, 2024
3e139f2
feat: add unit tests for SharedCachePrompt
Matthieu-OD Nov 18, 2024
3730581
feat: adds tests and updates run-test.sh
Matthieu-OD Nov 18, 2024
5318751
refactor: finishes the simplification
Matthieu-OD Nov 19, 2024
85c72d1
fix: test and implementation
Matthieu-OD Nov 20, 2024
6dfce9c
fix: add typing for sharedcache typing
Matthieu-OD Nov 20, 2024
06c5047
feat: align with literalai-typescript changes
Matthieu-OD Nov 21, 2024
e3c7ea0
Merge branch 'main' into matt/eng-2115-add-client-caching-for-prompts
Matthieu-OD Nov 28, 2024
cf98d74
fix: ci
Matthieu-OD Nov 28, 2024
c5faa02
fix: more ci fixes
Matthieu-OD Nov 28, 2024
51 changes: 31 additions & 20 deletions literalai/api/asynchronous.py
@@ -3,7 +3,6 @@

from typing_extensions import deprecated
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -106,9 +105,6 @@
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import PaginatedResponse, User
@@ -145,7 +141,7 @@ class AsyncLiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

async def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
) -> Dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
@@ -158,7 +154,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
)

try:
@@ -179,13 +175,12 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for _, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
value.get('message')}"""
)

return json

async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
@@ -211,15 +206,15 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
)

async def gql_helper(
self,
query: str,
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = 10,
) -> R:
response = await self.make_gql_call(description, query, variables)
response = await self.make_gql_call(description, query, variables, timeout)
return process_response(response)

##################################################################################
@@ -447,7 +442,7 @@ async def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -838,16 +833,32 @@ async def get_prompt(
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = None,
) -> "Prompt":
) -> Prompt:
if not (id or name):
raise ValueError("At least the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
if id:
return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
elif name:
return await self.gql_helper(
*get_prompt_helper(sync_api, name=name, version=version)
)
else:
raise ValueError("Either the `id` or the `name` must be provided.")
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
api=sync_api, id=id, name=name, version=version, cache=self.cache
)

try:
if id:
prompt = await self.gql_helper(
get_prompt_query, description, variables, process_response, timeout
)
elif name:
prompt = await self.gql_helper(
get_prompt_query, description, variables, process_response, timeout
)

return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt
raise e

async def update_prompt_ab_testing(
self, name: str, rollouts: List["PromptRollout"]
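Taken together, the asynchronous changes make get_prompt degrade gracefully: the first successful fetch fills the shared cache, later calls use a shorter timeout, and a failed request returns the cached prompt with a warning instead of raising. A minimal usage sketch, assuming the SDK is consumed through AsyncLiteralClient and that client.api exposes the AsyncLiteralAPI above (the client wiring and the "support-agent" prompt name are illustrative, not part of this diff):

import asyncio

from literalai import AsyncLiteralClient  # assumed public entry point, not part of this diff


async def main() -> None:
    client = AsyncLiteralClient(api_key="lsk_...")  # placeholder key

    # First call hits the GraphQL API (10s timeout) and stores the prompt in
    # the shared cache under its id, name, and "name-version" keys.
    prompt = await client.api.get_prompt(name="support-agent")

    # A cached copy now exists, so this call uses a 1s timeout; if the request
    # fails (timeout, outage), the cached Prompt is returned instead of the
    # exception propagating.
    prompt_again = await client.api.get_prompt(name="support-agent")
    print(prompt_again.name, prompt_again.version)


asyncio.run(main())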
7 changes: 5 additions & 2 deletions literalai/api/base.py
@@ -13,6 +13,7 @@

from literalai.my_types import Environment

from literalai.cache.shared_cache import SharedCache
from literalai.evaluation.dataset import DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperimentItem,
@@ -95,6 +96,8 @@ def __init__(
self.graphql_endpoint = self.url + "/api/graphql"
self.rest_endpoint = self.url + "/api"

self.cache = SharedCache()

@property
def headers(self):
from literalai.version import __version__
@@ -1011,9 +1014,9 @@ def get_prompt(
"""
Gets a prompt either by:
- `id`
- or `name` and (optional) `version`
- `name` and (optional) `version`

Either the `id` or the `name` must be provided.
At least the `id` or the `name` must be passed to the function.
If both are provided, the `id` is used.

Args:
51 changes: 39 additions & 12 deletions literalai/api/helpers/prompt_helpers.py
@@ -1,10 +1,13 @@
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
from typing import TYPE_CHECKING, Optional, TypedDict, Callable

from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

from literalai.cache.prompt_helpers import put_prompt

if TYPE_CHECKING:
from literalai.api import LiteralAPI
from literalai.cache.shared_cache import SharedCache

from literalai.api.helpers import gql

@@ -36,9 +39,9 @@ def process_response(response):
def create_prompt_helper(
api: "LiteralAPI",
lineage_id: str,
template_messages: List[GenerationMessage],
template_messages: list[GenerationMessage],
settings: Optional[ProviderSettings] = None,
tools: Optional[List[Dict]] = None,
tools: Optional[list[dict]] = None,
):
variables = {
"lineageId": lineage_id,
@@ -56,28 +59,52 @@ def process_response(response):
return gql.CREATE_PROMPT_VERSION, description, variables, process_response


def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
if id:
return id
elif name and version:
return f"{name}-{version}"
elif name:
return name
else:
raise ValueError("Either the `id` or the `name` must be provided.")


def get_prompt_helper(
api: "LiteralAPI",
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = 0,
):
cache: Optional["SharedCache"] = None,
) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
"""Helper function for getting prompts with caching logic"""

cached_prompt = None
timeout = 10

if cache:
cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
timeout = 1 if cached_prompt else timeout

variables = {"id": id, "name": name, "version": version}

def process_response(response):
prompt = response["data"]["promptVersion"]
return Prompt.from_dict(api, prompt) if prompt else None
prompt_version = response["data"]["promptVersion"]
prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
if cache and prompt:
put_prompt(cache, prompt)
return prompt

description = "get prompt"

return gql.GET_PROMPT_VERSION, description, variables, process_response
return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout, cached_prompt


def create_prompt_variant_helper(
from_lineage_id: Optional[str] = None,
template_messages: List[GenerationMessage] = [],
template_messages: list[GenerationMessage] = [],
settings: Optional[ProviderSettings] = None,
tools: Optional[List[Dict]] = None,
tools: Optional[list[dict]] = None,
):
variables = {
"fromLineageId": from_lineage_id,
@@ -105,7 +132,7 @@ def get_prompt_ab_testing_helper(
):
variables = {"lineageName": name}

def process_response(response) -> List[PromptRollout]:
def process_response(response) -> list[PromptRollout]:
response_data = response["data"]["promptLineageRollout"]
return list(map(lambda x: x["node"], response_data["edges"]))

@@ -114,10 +141,10 @@ def process_response(response) -> List[PromptRollout]:
return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
def update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]):
variables = {"name": name, "rollouts": rollouts}

def process_response(response) -> Dict:
def process_response(response) -> dict:
return response["data"]["updatePromptLineageRollout"]

description = "update prompt A/B testing"
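prompt_helpers.py now depends on literalai/cache/shared_cache.py, which is not included in the hunks shown here. A minimal sketch of the interface these helpers rely on, a process-wide string-keyed get/put store, could look like the following (illustrative only; per the commit history the real class also gained memory management, which is omitted here):

from typing import Any, Dict, Optional


class SharedCache:
    """Illustrative sketch: one in-memory mapping shared by every API
    instance in the process, so sync and async clients see the same prompts."""

    _instance: Optional["SharedCache"] = None
    _cache: Dict[str, Any]

    def __new__(cls) -> "SharedCache":
        # Singleton-style construction is one way to keep the cache "shared"
        # even though BaseLiteralAPI constructs SharedCache() per instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._cache = {}
        return cls._instance

    def get(self, key: str) -> Optional[Any]:
        return self._cache.get(key)

    def put(self, key: str, value: Any) -> None:
        self._cache[key] = value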
45 changes: 28 additions & 17 deletions literalai/api/synchronous.py
@@ -3,7 +3,6 @@

from typing_extensions import deprecated
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -105,9 +104,6 @@
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import PaginatedResponse, User
@@ -144,8 +140,8 @@ class LiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
) -> Dict:
self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10
) -> dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
raise Exception(error)
@@ -156,7 +152,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
)

try:
@@ -177,7 +173,7 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for _, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
@@ -186,7 +182,6 @@ def raise_error(error):

return json


def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
with httpx.Client(follow_redirects=True) as client:
response = client.post(
@@ -217,8 +212,9 @@ def gql_helper(
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = None,
) -> R:
response = self.make_gql_call(description, query, variables)
response = self.make_gql_call(description, query, variables, timeout)
return process_response(response)

##################################################################################
@@ -441,7 +437,7 @@ def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -805,12 +801,27 @@ def get_prompt(
name: Optional[str] = None,
version: Optional[int] = None,
) -> "Prompt":
if id:
return self.gql_helper(*get_prompt_helper(self, id=id))
elif name:
return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
else:
raise ValueError("Either the `id` or the `name` must be provided.")
if not (id or name):
raise ValueError("At least the `id` or the `name` must be provided.")

get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
api=self,id=id, name=name, version=version, cache=self.cache
)

try:
if id:
prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
elif name:
prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)

return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt

raise e

def create_prompt_variant(
self,
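The synchronous path mirrors the asynchronous one: a cached prompt shrinks the GraphQL timeout from 10s to 1s, and any exception during the request falls back to the cached copy. A test-style sketch of that fallback, using unittest.mock to simulate an outage (the API key, URL, and prompt name are placeholders; the constructor arguments follow the positional usage shown in the async wrapper above):

from unittest.mock import patch

from literalai.api import LiteralAPI

api = LiteralAPI(api_key="lsk_...", url="https://your-literal-instance")

# Warm the shared cache with one successful fetch.
prompt = api.get_prompt(name="support-agent")

# Simulate an API outage: every GraphQL call now raises.
with patch.object(LiteralAPI, "make_gql_call", side_effect=Exception("API down")):
    fallback = api.get_prompt(name="support-agent")  # logs a warning, does not raise

assert fallback.id == prompt.id  # served from the cache rather than the API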
Empty file added literalai/cache/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions literalai/cache/prompt_helpers.py
@@ -0,0 +1,8 @@
from literalai.prompt_engineering.prompt import Prompt
from literalai.cache.shared_cache import SharedCache


def put_prompt(cache: SharedCache, prompt: Prompt):
cache.put(prompt.id, prompt)
cache.put(prompt.name, prompt)
cache.put(f"{prompt.name}-{prompt.version}", prompt)
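put_prompt writes the same Prompt under three keys, which is exactly what get_prompt_cache_key reconstructs on the read side depending on whether get_prompt was called with an id, a name, or a name plus version. A quick sketch of how the two sides line up (the id, name, and version values are hypothetical):

from literalai.api.helpers.prompt_helpers import get_prompt_cache_key

# For a prompt with id="prompt-123", name="support-agent", version=2,
# put_prompt stores it under "prompt-123", "support-agent", and "support-agent-2".

# The lookups performed by get_prompt_helper resolve to those same keys:
assert get_prompt_cache_key(id="prompt-123", name=None, version=None) == "prompt-123"
assert get_prompt_cache_key(id=None, name="support-agent", version=2) == "support-agent-2"
assert get_prompt_cache_key(id=None, name="support-agent", version=None) == "support-agent"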