Skip to content

Commit

Permalink
feat: adds tests and updates run-test.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthieu-OD committed Nov 18, 2024
1 parent 3e139f2 commit 896e303
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
21 changes: 10 additions & 11 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import time
import os
from threading import Lock
import uuid
Expand Down Expand Up @@ -147,20 +146,21 @@ class SharedPromptCache:
"""
Thread-safe singleton cache for storing prompts with memory leak prevention.
Only one instance will exist regardless of how many times it's instantiated.
Implements LRU eviction policy when cache reaches maximum size.
"""
_instance = None
_lock = Lock()
_prompts: dict[str, Prompt]
_name_index: dict[str, str]
_name_version_index: dict[tuple[str, int], str]

def __new__(cls, max_size: int = 1000):
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)

cls._instance._max_size = max_size
cls._instance._prompts: dict[str, Prompt] = {}
cls._instance._name_index: dict[str, str] = {}
cls._instance._name_version_index: dict[tuple[str, int], str] = {}
cls._instance._shared_cache = {}
cls._instance._name_index = {}
cls._instance._name_version_index = {}
return cls._instance

def get(
Expand All @@ -171,7 +171,6 @@ def get(
) -> Optional[Prompt]:
"""
Retrieves a prompt using the most specific criteria provided.
Updates access time for LRU tracking.
Lookup priority: id, name-version, name
"""
if id and not isinstance(id, str):
Expand All @@ -184,9 +183,9 @@ def get(
if id:
prompt_id = id
elif name and version:
prompt_id = self._name_version_index.get((name, version))
prompt_id = self._name_version_index.get((name, version)) or ""
elif name:
prompt_id = self._name_index.get(name)
prompt_id = self._name_index.get(name) or ""
else:
return None

Expand All @@ -196,7 +195,7 @@ def get(

def put(self, prompt: Prompt):
"""
Stores a prompt in the cache, managing size limits with LRU eviction.
Stores a prompt in the cache.
"""
if not isinstance(prompt, Prompt):
raise TypeError("Expected a Prompt object")
Expand Down
1 change: 1 addition & 0 deletions literalai/api/prompt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

if TYPE_CHECKING:
from literalai.api import LiteralAPI
from literalai.api import SharedPromptCache

from literalai.api import gql

Expand Down
2 changes: 1 addition & 1 deletion run-test.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v
LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v tests/e2e/ tests/unit/
14 changes: 14 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,20 @@ async def test_prompt(self, async_client: AsyncLiteralClient):

assert messages[0]["content"] == expected

@pytest.mark.timeout(5)
async def test_prompt_cache(self, async_client: AsyncLiteralClient):
prompt = await async_client.api.get_prompt(name="Default", version=0)
assert prompt is not None

original_key = async_client.api.api_key
async_client.api.api_key = "invalid-api-key"

cached_prompt = await async_client.api.get_prompt(name="Default", version=0)
assert cached_prompt is not None
assert cached_prompt.id == prompt.id

async_client.api.api_key = original_key

@pytest.mark.timeout(5)
async def test_prompt_ab_testing(self, client: LiteralClient):
prompt_name = "Python SDK E2E Tests"
Expand Down
16 changes: 6 additions & 10 deletions tests/unit/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
import random

from literalai.prompt_engineering.prompt import Prompt
from literalai.api import SharedPromptCache
from literalai.api import SharedPromptCache, LiteralAPI

def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt:
return Prompt(
api=None,
api=LiteralAPI(),
id=id,
name=name,
version=version,
created_at="",
updated_at="",
type="chat",
type="chat", # type: ignore
url="",
version_desc=None,
template_messages=[],
Expand All @@ -34,7 +34,7 @@ def test_singleton_instance():
def test_get_empty_cache():
"""Test getting from empty cache returns None"""
cache = SharedPromptCache()
cache.clear() # Ensure clean state
cache.clear()

assert cache._prompts == {}
assert cache._name_index == {}
Expand Down Expand Up @@ -90,12 +90,10 @@ def test_multiple_versions():
cache.put(prompt1)
cache.put(prompt2)

# Get specific versions
assert cache.get(name="test", version=1) is prompt1
assert cache.get(name="test", version=2) is prompt2

# Get by name should return latest version
assert cache.get(name="test") is prompt2 # Returns the last indexed version
assert cache.get(name="test") is prompt2

def test_clear_cache():
"""Test clearing the cache"""
Expand All @@ -114,7 +112,7 @@ def test_update_existing_prompt():
cache.clear()

prompt1 = default_prompt()
prompt2 = default_prompt(id="1", version=2) # Same ID, different version
prompt2 = default_prompt(id="1", version=2)

cache.put(prompt1)
cache.put(prompt2)
Expand All @@ -134,10 +132,8 @@ def test_lookup_priority():
cache.put(prompt1)
cache.put(prompt2)

# ID should take precedence
assert cache.get(id="1", name="test", version=2) is prompt1

# Name-version should take precedence over name
assert cache.get(name="test", version=2) is prompt2

def test_thread_safety():
Expand Down

0 comments on commit 896e303

Please sign in to comment.