Skip to content

Commit

Permalink
feat: get_prompt add caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthieu-OD committed Nov 14, 2024
1 parent 9f20152 commit 6136af7
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,31 +195,41 @@ def _get_prompt_cache_key(
) -> str:
key = ""
if id:
key = f"id:{id}"
return f"id:{id}"
elif name:
key = f"name:{name}"
if version:
key += f":version:{version}"
return key
else:
raise ValueError("Either the `id` or the `name` must be provided.")

if version:
key += f":version:{version}"
def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None) -> Optional[Prompt]:
"""Returns the cached prompt, using key in this order: id, name-version, name
"""
key_id = self._get_prompt_cache_key(id=id)
if key_id in self._prompt_cache:
return self._prompt_cache.get(key_id)

return key
key_name_version = self._get_prompt_cache_key(name=name, version=version)
if key_name_version in self._prompt_cache:
return self._prompt_cache.get(key_name_version)

def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None):
key = self._get_prompt_cache_key(id, name, version)
# handle the case where I have a version but it's not found in this case look without the version
if key in self._prompt_cache:
return self._prompt_cache.get(key)
else:
key_without_version = self._get_prompt_cache_key(id, name)
return self._prompt_cache.get(key_without_version)
key_name = self._get_prompt_cache_key(name=name)
if key_name in self._prompt_cache:
return self._prompt_cache.get(key_name)

def _create_prompt_cache(self, prompt: Prompt):
key = self._get_prompt_cache_key(id=prompt.id, name=prompt.name, version=prompt.version)
key_without_version = self._get_prompt_cache_key(id=prompt.id, name=prompt.name)
self._prompt_cache[key] = prompt
self._prompt_cache[key_without_version] = prompt
"""Creates cache for prompt. 3 entries are created/updated: id, name, name:version
"""
key_id = self._get_prompt_cache_key(id=prompt.id)
self._prompt_cache[key_id] = prompt

key_name = self._get_prompt_cache_key(name=prompt.name)
self._prompt_cache[key_name] = prompt

key_name_version = self._get_prompt_cache_key(name=prompt.name, version=prompt.version)
self._prompt_cache[key_name_version] = prompt


class LiteralAPI(BaseLiteralAPI):
Expand Down Expand Up @@ -1398,13 +1408,28 @@ def get_prompt(
Returns:
Prompt: The prompt with the given identifier or name.
"""
if id:
return self.gql_helper(*get_prompt_helper(self, id=id))
elif name:
return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
else:
if not (id or name):
raise ValueError("Either the `id` or the `name` must be provided.")

cached_prompt = self._get_prompt_cache(id, name)

try:
if id:
prompt = self.gql_helper(*get_prompt_helper(self, id=id))
elif name:
prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version))

self._create_prompt_cache(prompt)
return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
logger.error(f"Error: {e}")
return cached_prompt

raise e

def create_prompt_variant(
self,
name: str,
Expand Down Expand Up @@ -2629,16 +2654,30 @@ async def get_prompt(
name: Optional[str] = None,
version: Optional[int] = None,
) -> Prompt:
sync_api = LiteralAPI(self.api_key, self.url)
if id:
return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
elif name:
return await self.gql_helper(
*get_prompt_helper(sync_api, name=name, version=version)
)
else:
if not (id or name):
raise ValueError("Either the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
cached_prompt = self._get_prompt_cache(id, name)

try:
if id:
prompt = await self.gql_helper(*get_prompt_helper(sync_api, id=id))
elif name:
prompt = await self.gql_helper(
*get_prompt_helper(sync_api, name=name, version=version)
)

self._create_prompt_cache(prompt)
return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
logger.error(f"Error: {e}")
return cached_prompt
raise e

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def update_prompt_ab_testing(
Expand Down

0 comments on commit 6136af7

Please sign in to comment.