Skip to content

Commit

Permalink
fix: caching SQLiteMetadataStore.get_run_ids() (#1205)
Browse files Browse the repository at this point in the history
* fixed .get_run_ids() and standardized .get_run() + tests

* fixed docstrings formatting errors

---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
  • Loading branch information
zilto and zilto authored Oct 24, 2024
1 parent 09f0712 commit 2ff4de5
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 28 deletions.
18 changes: 11 additions & 7 deletions hamilton/caching/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def set(self, cache_key: str, data_version: str, **kwargs) -> Optional[Any]:
"""

@abc.abstractmethod
def get(self, cache_key: str) -> Optional[str]:
def get(self, cache_key: str, **kwargs) -> Optional[str]:
"""Try to retrieve ``data_version`` keyed by ``cache_key``.
If retrieval misses return ``None``.
"""
Expand All @@ -118,15 +118,19 @@ def exists(self, cache_key: str) -> bool:
def get_run_ids(self) -> Sequence[str]:
"""Return a list of run ids, sorted from oldest to newest start time.
A ``run_id`` is registered when the metadata_store ``.initialize()`` is called.
NOTE because of race conditions, the order could theoretically differ from the
order stored on the SmartCacheAdapter `._run_ids` attribute.
"""

@abc.abstractmethod
def get_run(self, run_id: str) -> Any:
"""Return all the metadata associated with a run.
The metadata content may differ across MetadataStore implementations
def get_run(self, run_id: str) -> Sequence[dict]:
"""Return a list of node metadata associated with a run.
For each node, the metadata should include ``cache_key`` (created or used)
and ``data_version``. These values allow to manually query the MetadataStore
or ResultStore.
Decoding the ``cache_key`` gives the ``node_name``, ``code_version``, and
``dependencies_data_versions``. Individual implementations may add more
information or decode the ``cache_key`` before returning metadata.
"""

@property
Expand Down
82 changes: 61 additions & 21 deletions hamilton/caching/stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import threading
from typing import List, Optional

from hamilton.caching.cache_key import decode_key
from hamilton.caching.stores.base import MetadataStore


Expand All @@ -19,14 +20,14 @@ def __init__(

self._thread_local = threading.local()

def _get_connection(self):
def _get_connection(self) -> sqlite3.Connection:
if not hasattr(self._thread_local, "connection"):
self._thread_local.connection = sqlite3.connect(
str(self._path), check_same_thread=False, **self.connection_kwargs
)
return self._thread_local.connection

def _close_connection(self):
def _close_connection(self) -> None:
if hasattr(self._thread_local, "connection"):
self._thread_local.connection.close()
del self._thread_local.connection
Expand Down Expand Up @@ -76,9 +77,9 @@ def _create_tables_if_not_exists(self):
"""\
CREATE TABLE IF NOT EXISTS cache_metadata (
cache_key TEXT PRIMARY KEY,
data_version TEXT NOT NULL,
node_name TEXT NOT NULL,
code_version TEXT NOT NULL,
data_version TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (cache_key) REFERENCES history(cache_key)
Expand Down Expand Up @@ -106,13 +107,21 @@ def set(
self,
*,
cache_key: str,
node_name: str,
code_version: str,
data_version: str,
run_id: str,
node_name: str = None,
code_version: str = None,
**kwargs,
) -> None:
cur = self.connection.cursor()

# if the caller of ``.set()`` directly provides the ``node_name`` and ``code_version``,
# we can skip the decoding step.
if (node_name is None) or (code_version is None):
decoded_key = decode_key(cache_key)
node_name = decoded_key["node_name"]
code_version = decoded_key["code_version"]

cur.execute("INSERT INTO history (cache_key, run_id) VALUES (?, ?)", (cache_key, run_id))
cur.execute(
"""\
Expand Down Expand Up @@ -150,7 +159,7 @@ def delete(self, cache_key: str) -> None:
cur.execute("DELETE FROM cache_metadata WHERE cache_key = ?", (cache_key,))
self.connection.commit()

def delete_all(self):
def delete_all(self) -> None:
"""Delete all existing tables from the database"""
cur = self.connection.cursor()

Expand All @@ -170,35 +179,66 @@ def exists(self, cache_key: str) -> bool:
return result is not None

def get_run_ids(self) -> List[str]:
"""Return a list of run ids, sorted from oldest to newest start time."""
cur = self.connection.cursor()
cur.execute("SELECT run_id FROM history ORDER BY id")
cur.execute("SELECT run_id FROM run_ids ORDER BY id")
result = cur.fetchall()

if result is None:
raise IndexError("No `run_id` found. Table `history` is empty.")
return [r[0] for r in result]

return result[0]
def _run_exists(self, run_id: str) -> bool:
"""Returns True if a run was initialized with ``run_id``, even
if the run recorded no node executions.
"""
cur = self.connection.cursor()
cur.execute(
"""\
SELECT EXISTS(
SELECT 1
FROM run_ids
WHERE run_id = ?
)
""",
(run_id,),
)
result = cur.fetchone()
# SELECT EXISTS returns 1 for True, i.e., `run_id` is found
return result[0] == 1

def get_run(self, run_id: str) -> List[dict]:
"""Return all the metadata associated with a run."""
"""Return a list of node metadata associated with a run.
:param run_id: ID of the run to retrieve
:return: List of node metadata which includes ``cache_key``, ``data_version``,
``node_name``, and ``code_version``. The list can be empty if a run was initialized
but no nodes were executed.
:raises IndexError: if the ``run_id`` is not found in metadata store.
"""
cur = self.connection.cursor()
if self._run_exists(run_id) is False:
raise IndexError(f"`run_id` not found in table `run_ids`: {run_id}")

cur.execute(
"""\
SELECT
cache_metadata.cache_key,
cache_metadata.data_version,
cache_metadata.node_name,
cache_metadata.code_version,
cache_metadata.data_version
FROM (SELECT * FROM history WHERE history.run_id = ?) AS run_history
JOIN cache_metadata ON run_history.cache_key = cache_metadata.cache_key
cache_metadata.code_version
FROM history
JOIN cache_metadata ON history.cache_key = cache_metadata.cache_key
WHERE history.run_id = ?
""",
(run_id,),
)
results = cur.fetchall()

if results is None:
raise IndexError(f"`run_id` not found in table `history`: {run_id}")

return [
dict(node_name=node_name, code_version=code_version, data_version=data_version)
for node_name, code_version, data_version in results
dict(
cache_key=cache_key,
data_version=data_version,
node_name=node_name,
code_version=code_version,
)
for cache_key, data_version, node_name, code_version in results
]
55 changes: 55 additions & 0 deletions tests/caching/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,58 @@ def test_set_get_without_dependencies(metadata_store):
retrieved_data_version = metadata_store.get(cache_key=cache_key)

assert retrieved_data_version == data_version


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_ids_returns_ordered_list(metadata_store):
pre_run_ids = metadata_store.get_run_ids()
assert pre_run_ids == ["test-run-id"] # this is from the fixture

metadata_store.initialize(run_id="foo")
metadata_store.initialize(run_id="bar")
metadata_store.initialize(run_id="baz")

post_run_ids = metadata_store.get_run_ids()
assert post_run_ids == ["test-run-id", "foo", "bar", "baz"]


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_results_include_cache_key_and_data_version(metadata_store):
run_id = "test-run-id"
metadata_store.set(
cache_key="foo",
data_version="1",
run_id=run_id,
node_name="a", # kwarg specific to SQLiteMetadataStore
code_version="b", # kwarg specific to SQLiteMetadataStore
)
metadata_store.set(
cache_key="bar",
data_version="2",
run_id=run_id,
node_name="a", # kwarg specific to SQLiteMetadataStore
code_version="b", # kwarg specific to SQLiteMetadataStore
)

run_info = metadata_store.get_run(run_id=run_id)

assert isinstance(run_info, list)
assert len(run_info) == 2
assert isinstance(run_info[1], dict)
assert run_info[0]["cache_key"] == "foo"
assert run_info[0]["data_version"] == "1"
assert run_info[1]["cache_key"] == "bar"
assert run_info[1]["data_version"] == "2"


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store):
metadata_store.initialize(run_id="foo")
run_info = metadata_store.get_run(run_id="foo")
assert run_info == []


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_raises_error_if_run_id_not_found(metadata_store):
with pytest.raises(IndexError):
metadata_store.get_run(run_id="foo")

0 comments on commit 2ff4de5

Please sign in to comment.