diff --git a/hamilton/caching/stores/base.py b/hamilton/caching/stores/base.py index 901edb412..98eda2f5a 100644 --- a/hamilton/caching/stores/base.py +++ b/hamilton/caching/stores/base.py @@ -95,7 +95,7 @@ def set(self, cache_key: str, data_version: str, **kwargs) -> Optional[Any]: """ @abc.abstractmethod - def get(self, cache_key: str) -> Optional[str]: + def get(self, cache_key: str, **kwargs) -> Optional[str]: """Try to retrieve ``data_version`` keyed by ``cache_key``. If retrieval misses return ``None``. """ @@ -118,15 +118,19 @@ def exists(self, cache_key: str) -> bool: def get_run_ids(self) -> Sequence[str]: """Return a list of run ids, sorted from oldest to newest start time. A ``run_id`` is registered when the metadata_store ``.initialize()`` is called. - - NOTE because of race conditions, the order could theoretically differ from the - order stored on the SmartCacheAdapter `._run_ids` attribute. """ @abc.abstractmethod - def get_run(self, run_id: str) -> Any: - """Return all the metadata associated with a run. - The metadata content may differ across MetadataStore implementations + def get_run(self, run_id: str) -> Sequence[dict]: + """Return a list of node metadata associated with a run. + + For each node, the metadata should include ``cache_key`` (created or used) + and ``data_version``. These values allow to manually query the MetadataStore + or ResultStore. + + Decoding the ``cache_key`` gives the ``node_name``, ``code_version``, and + ``dependencies_data_versions``. Individual implementations may add more + information or decode the ``cache_key`` before returning metadata. """ @property diff --git a/hamilton/caching/stores/sqlite.py b/hamilton/caching/stores/sqlite.py index c234024bd..4ad0eb480 100644 --- a/hamilton/caching/stores/sqlite.py +++ b/hamilton/caching/stores/sqlite.py @@ -3,6 +3,7 @@ import threading from typing import List, Optional +from hamilton.caching.cache_key import decode_key from hamilton.caching.stores.base import MetadataStore @@ -19,14 +20,14 @@ def __init__( self._thread_local = threading.local() - def _get_connection(self): + def _get_connection(self) -> sqlite3.Connection: if not hasattr(self._thread_local, "connection"): self._thread_local.connection = sqlite3.connect( str(self._path), check_same_thread=False, **self.connection_kwargs ) return self._thread_local.connection - def _close_connection(self): + def _close_connection(self) -> None: if hasattr(self._thread_local, "connection"): self._thread_local.connection.close() del self._thread_local.connection @@ -76,9 +77,9 @@ def _create_tables_if_not_exists(self): """\ CREATE TABLE IF NOT EXISTS cache_metadata ( cache_key TEXT PRIMARY KEY, + data_version TEXT NOT NULL, node_name TEXT NOT NULL, code_version TEXT NOT NULL, - data_version TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (cache_key) REFERENCES history(cache_key) @@ -106,13 +107,21 @@ def set( self, *, cache_key: str, - node_name: str, - code_version: str, data_version: str, run_id: str, + node_name: str = None, + code_version: str = None, + **kwargs, ) -> None: cur = self.connection.cursor() + # if the caller of ``.set()`` directly provides the ``node_name`` and ``code_version``, + # we can skip the decoding step. + if (node_name is None) or (code_version is None): + decoded_key = decode_key(cache_key) + node_name = decoded_key["node_name"] + code_version = decoded_key["code_version"] + cur.execute("INSERT INTO history (cache_key, run_id) VALUES (?, ?)", (cache_key, run_id)) cur.execute( """\ @@ -150,7 +159,7 @@ def delete(self, cache_key: str) -> None: cur.execute("DELETE FROM cache_metadata WHERE cache_key = ?", (cache_key,)) self.connection.commit() - def delete_all(self): + def delete_all(self) -> None: """Delete all existing tables from the database""" cur = self.connection.cursor() @@ -170,35 +179,66 @@ def exists(self, cache_key: str) -> bool: return result is not None def get_run_ids(self) -> List[str]: + """Return a list of run ids, sorted from oldest to newest start time.""" cur = self.connection.cursor() - cur.execute("SELECT run_id FROM history ORDER BY id") + cur.execute("SELECT run_id FROM run_ids ORDER BY id") result = cur.fetchall() - if result is None: - raise IndexError("No `run_id` found. Table `history` is empty.") + return [r[0] for r in result] - return result[0] + def _run_exists(self, run_id: str) -> bool: + """Returns True if a run was initialized with ``run_id``, even + if the run recorded no node executions. + """ + cur = self.connection.cursor() + cur.execute( + """\ + SELECT EXISTS( + SELECT 1 + FROM run_ids + WHERE run_id = ? + ) + """, + (run_id,), + ) + result = cur.fetchone() + # SELECT EXISTS returns 1 for True, i.e., `run_id` is found + return result[0] == 1 def get_run(self, run_id: str) -> List[dict]: - """Return all the metadata associated with a run.""" + """Return a list of node metadata associated with a run. + + :param run_id: ID of the run to retrieve + :return: List of node metadata which includes ``cache_key``, ``data_version``, + ``node_name``, and ``code_version``. The list can be empty if a run was initialized + but no nodes were executed. + + :raises IndexError: if the ``run_id`` is not found in metadata store. + """ cur = self.connection.cursor() + if self._run_exists(run_id) is False: + raise IndexError(f"`run_id` not found in table `run_ids`: {run_id}") + cur.execute( """\ SELECT + cache_metadata.cache_key, + cache_metadata.data_version, cache_metadata.node_name, - cache_metadata.code_version, - cache_metadata.data_version - FROM (SELECT * FROM history WHERE history.run_id = ?) AS run_history - JOIN cache_metadata ON run_history.cache_key = cache_metadata.cache_key + cache_metadata.code_version + FROM history + JOIN cache_metadata ON history.cache_key = cache_metadata.cache_key + WHERE history.run_id = ? """, (run_id,), ) results = cur.fetchall() - - if results is None: - raise IndexError(f"`run_id` not found in table `history`: {run_id}") - return [ - dict(node_name=node_name, code_version=code_version, data_version=data_version) - for node_name, code_version, data_version in results + dict( + cache_key=cache_key, + data_version=data_version, + node_name=node_name, + code_version=code_version, + ) + for cache_key, data_version, node_name, code_version in results ] diff --git a/tests/caching/test_metadata_store.py b/tests/caching/test_metadata_store.py index ab3d61751..156e44b3e 100644 --- a/tests/caching/test_metadata_store.py +++ b/tests/caching/test_metadata_store.py @@ -98,3 +98,58 @@ def test_set_get_without_dependencies(metadata_store): retrieved_data_version = metadata_store.get(cache_key=cache_key) assert retrieved_data_version == data_version + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_get_run_ids_returns_ordered_list(metadata_store): + pre_run_ids = metadata_store.get_run_ids() + assert pre_run_ids == ["test-run-id"] # this is from the fixture + + metadata_store.initialize(run_id="foo") + metadata_store.initialize(run_id="bar") + metadata_store.initialize(run_id="baz") + + post_run_ids = metadata_store.get_run_ids() + assert post_run_ids == ["test-run-id", "foo", "bar", "baz"] + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_get_run_results_include_cache_key_and_data_version(metadata_store): + run_id = "test-run-id" + metadata_store.set( + cache_key="foo", + data_version="1", + run_id=run_id, + node_name="a", # kwarg specific to SQLiteMetadataStore + code_version="b", # kwarg specific to SQLiteMetadataStore + ) + metadata_store.set( + cache_key="bar", + data_version="2", + run_id=run_id, + node_name="a", # kwarg specific to SQLiteMetadataStore + code_version="b", # kwarg specific to SQLiteMetadataStore + ) + + run_info = metadata_store.get_run(run_id=run_id) + + assert isinstance(run_info, list) + assert len(run_info) == 2 + assert isinstance(run_info[1], dict) + assert run_info[0]["cache_key"] == "foo" + assert run_info[0]["data_version"] == "1" + assert run_info[1]["cache_key"] == "bar" + assert run_info[1]["data_version"] == "2" + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store): + metadata_store.initialize(run_id="foo") + run_info = metadata_store.get_run(run_id="foo") + assert run_info == [] + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_get_run_raises_error_if_run_id_not_found(metadata_store): + with pytest.raises(IndexError): + metadata_store.get_run(run_id="foo")