From 8c2ccd901fc0b756c2349a7c862a6fc1dde74471 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 25 Apr 2024 11:05:51 -0700 Subject: [PATCH 1/8] Added diskcache as a dependency to the client --- devtools/conda-envs/alchemiscale-client.yml | 1 + devtools/conda-envs/test.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/devtools/conda-envs/alchemiscale-client.yml b/devtools/conda-envs/alchemiscale-client.yml index f2ed56e1..734f0472 100644 --- a/devtools/conda-envs/alchemiscale-client.yml +++ b/devtools/conda-envs/alchemiscale-client.yml @@ -15,6 +15,7 @@ dependencies: - click - httpx - pydantic<2.0 + - diskcache ## user client printing - rich diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index 14b41d89..88eb8dec 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -10,6 +10,7 @@ dependencies: - openfe>=0.14.0 - openmmforcefields>=0.12.0 - pydantic<2.0 + - diskcache ## state store - neo4j-python-driver From 651a3faed3050e69fc7865b30a3974044a8ea980 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 25 Apr 2024 14:27:57 -0700 Subject: [PATCH 2/8] Use a diskcache.Cache to hold ProtocolDAGResults --- alchemiscale/base/client.py | 29 ++++++++++++++------ alchemiscale/interface/client.py | 45 ++++++++++++++++---------------- 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/alchemiscale/base/client.py b/alchemiscale/base/client.py index b3fd1fc7..7324f0e4 100644 --- a/alchemiscale/base/client.py +++ b/alchemiscale/base/client.py @@ -8,19 +8,19 @@ import time import random from itertools import islice -from typing import List import json from urllib.parse import urljoin from functools import wraps import gzip +from pathlib import Path +from diskcache import Cache import requests import httpx from gufe.tokenization import GufeTokenizable, JSON_HANDLER -from ..models import Scope, ScopedKey -from ..storage.models import TaskHub, Task +from ..models import ScopedKey def json_to_gufe(jsondata): @@ 
-61,6 +61,8 @@ def __init__( api_url: str, identifier: str, key: str, + cache_directory=Path.home() / ".cache" / "alchemiscale", + cache_size_limit: int = 1073741824, max_retries: int = 5, retry_base_seconds: float = 2.0, retry_max_seconds: float = 60.0, @@ -76,6 +78,10 @@ def __init__( Identifier for the identity used for authentication. key Credential for the identity used for authentication. + cache_directory + Location of the cache directory. Defaults to `${HOME}/.cache/alchemiscale`. + cache_size_limit + Maximum size of the client cache. Defaults to 1 GB. max_retries Maximum number of times to retry a request. In the case the API service is unresponsive an exponential backoff is applied with @@ -111,9 +117,17 @@ def __init__( self._session = None self._lock = None + self._cache = Cache( + cache_directory, + size_limit=cache_size_limit, + eviction_policy="least-recently-used", + ) + def _settings(self): return dict( api_url=self.api_url, + cache_directory=self._cache.directory, + cache_size_limit=self._cache.size_limit, identifier=self.identifier, key=self.key, max_retries=self.max_retries, @@ -357,7 +371,7 @@ def _get_resource(self, resource, params=None, compress=False): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason} : {detail}", @@ -392,7 +406,7 @@ async def _get_resource_async(self, resource, params=None, compress=False): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}", @@ -438,7 +452,7 @@ def _post(self, url, headers, data): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason} : {detail}", @@ -462,7 +476,7 @@ 
async def _post_resource_async(self, resource, data): if not 200 <= resp.status_code < 300: try: detail = resp.json()["detail"] - except: + except Exception: detail = resp.text raise self._exception( f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}", @@ -494,7 +508,6 @@ def _rich_waiting_columns(): @staticmethod def _rich_progress_columns(): from rich.progress import ( - Progress, SpinnerColumn, MofNCompleteColumn, TextColumn, diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index d6d68e9d..c202379e 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -8,32 +8,26 @@ from typing import Union, List, Dict, Optional, Tuple, Any, Iterable import json from itertools import chain -from collections import Counter from functools import lru_cache -import httpx from async_lru import alru_cache import networkx as nx from gufe import AlchemicalNetwork, Transformation, ChemicalSystem -from gufe.tokenization import GufeTokenizable, JSON_HANDLER, GufeKey +from gufe.tokenization import GufeTokenizable, JSON_HANDLER from gufe.protocols import ProtocolResult, ProtocolDAGResult from ..base.client import ( AlchemiscaleBaseClient, AlchemiscaleBaseClientError, - json_to_gufe, use_session, ) from ..models import Scope, ScopedKey from ..storage.models import ( - Task, - ProtocolDAGResultRef, TaskStatusEnum, NetworkStateEnum, ) from ..strategies import Strategy -from ..security.models import CredentialedUserIdentity from ..validators import validate_network_nonself from ..keyedchain import KeyedChain @@ -531,12 +525,7 @@ def _get_network(): return KeyedChain(content).to_gufe() if visualize: - from rich.progress import ( - Progress, - SpinnerColumn, - TimeElapsedColumn, - TextColumn, - ) + from rich.progress import Progress with Progress(*self._rich_waiting_columns(), transient=False) as progress: task = progress.add_task( @@ -587,7 +576,7 @@ def _get_transformation(): return KeyedChain(content).to_gufe() if 
visualize: - from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn + from rich.progress import Progress with Progress(*self._rich_waiting_columns(), transient=False) as progress: task = progress.add_task( @@ -1214,7 +1203,7 @@ async def _set_task_status( """Set the statuses for many Tasks""" data = dict(tasks=[t.dict() for t in tasks], status=status.value) tasks_updated = await self._post_resource_async( - f"/bulk/tasks/status/set", data=data + "/bulk/tasks/status/set", data=data ) return [ ScopedKey.from_str(task_sk) if task_sk is not None else None @@ -1287,7 +1276,7 @@ async def _set_task_priority( ) -> List[Optional[ScopedKey]]: data = dict(tasks=[t.dict() for t in tasks], priority=priority) tasks_updated = await self._post_resource_async( - f"/bulk/tasks/priority/set", data=data + "/bulk/tasks/priority/set", data=data ) return [ ScopedKey.from_str(task_sk) if task_sk is not None else None @@ -1328,7 +1317,7 @@ async def _get_task_priority(self, tasks: List[ScopedKey]) -> List[int]: """Get the priority for many Tasks""" data = dict(tasks=[t.dict() for t in tasks]) priorities = await self._post_resource_async( - f"/bulk/tasks/priority/get", data=data + "/bulk/tasks/priority/get", data=data ) return priorities @@ -1364,10 +1353,22 @@ def get_tasks_priority( async def _async_get_protocoldagresult( self, protocoldagresultref, transformation, route, compress ): - pdr_json = await self._get_resource_async( - f"/transformations/{transformation}/{route}/{protocoldagresultref}", - compress=compress, - ) + # check the disk cache for the PDR + if not ( + pdr_json := self._cache.get( + [str(transformation), route, str(protocoldagresultref)] + ) + ): + # query the alchemiscale server for the PDR + pdr_json = await self._get_resource_async( + f"/transformations/{transformation}/{route}/{protocoldagresultref}", + compress=compress, + ) + + # add the resulting PDR to the cache + self._cache.add( + [str(transformation), route, str(protocoldagresultref)], pdr_json 
+ ) pdr = GufeTokenizable.from_dict( json.loads(pdr_json[0], cls=JSON_HANDLER.decoder) @@ -1397,7 +1398,7 @@ async def async_request(self): *self._rich_progress_columns(), transient=False ) as progress: task = progress.add_task( - f"Retrieving [bold]ProtocolDAGResult[/bold]s", + "Retrieving [bold]ProtocolDAGResult[/bold]s", total=len(protocoldagresultrefs), ) From 3761de534258612648822649d1e7d6200e3d740b Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 29 Apr 2024 09:48:17 -0700 Subject: [PATCH 3/8] Store raw bytes rather than pickled objects * The default Disk used by diskcache uses pickle when storing python objects. Instead, we are now storing byte arrays. Depending on the size of the byte array, this is either stored in the SQLite3 DB or as a separate file if it's too large (>32 kb by default). * A test has been added that checks the hits and misses when pulling PDRs using the get_transformation_results method. The in-memory LRU cache is cleared manually for accurate stats. --- alchemiscale/interface/client.py | 17 ++++---- .../integration/interface/client/conftest.py | 19 ++++++--- .../interface/client/test_client.py | 41 +++++++++++++++++++ 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index c202379e..d806e537 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -1354,25 +1354,24 @@ async def _async_get_protocoldagresult( self, protocoldagresultref, transformation, route, compress ): # check the disk cache for the PDR - if not ( - pdr_json := self._cache.get( - [str(transformation), route, str(protocoldagresultref)] - ) - ): + if pdr_json := self._cache.get(str(protocoldagresultref)): + pdr_json = pdr_json.decode("utf-8") + else: # query the alchemiscale server for the PDR pdr_json = await self._get_resource_async( f"/transformations/{transformation}/{route}/{protocoldagresultref}", compress=compress, ) + pdr_json = pdr_json[0] + #
add the resulting PDR to the cache self._cache.add( - [str(transformation), route, str(protocoldagresultref)], pdr_json + str(protocoldagresultref), + pdr_json.encode("utf-8"), ) - pdr = GufeTokenizable.from_dict( - json.loads(pdr_json[0], cls=JSON_HANDLER.decoder) - ) + pdr = GufeTokenizable.from_dict(json.loads(pdr_json, cls=JSON_HANDLER.decoder)) return pdr diff --git a/alchemiscale/tests/integration/interface/client/conftest.py b/alchemiscale/tests/integration/interface/client/conftest.py index 7364b5a6..99784595 100644 --- a/alchemiscale/tests/integration/interface/client/conftest.py +++ b/alchemiscale/tests/integration/interface/client/conftest.py @@ -1,6 +1,8 @@ import pytest from copy import copy from time import sleep +import tempfile +from pathlib import Path import uvicorn import requests @@ -49,11 +51,18 @@ def uvicorn_server(user_api): @pytest.fixture(scope="module") def user_client(uvicorn_server, user_identity): - return client.AlchemiscaleClient( - api_url="http://127.0.0.1:8000/", - identifier=user_identity["identifier"], - key=user_identity["key"], - ) + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + test_client = client.AlchemiscaleClient( + api_url="http://127.0.0.1:8000/", + identifier=user_identity["identifier"], + key=user_identity["key"], + cache_directory=tmpdir, + ) + test_client._cache.stats(enable=True, reset=True) + + return test_client @pytest.fixture(scope="module") diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 48dd868f..bf229132 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -1665,6 +1665,47 @@ def _execute_tasks(tasks, n4js, s3os_server): return protocoldagresults + def test_cached_pdr( + self, scope_test, n4js_preloaded, s3os_server, user_client, network_tyk2, tmpdir + ): + network_sk = 
user_client.get_scoped_key(network_tyk2, scope_test) + + transformation = list(t for t in network_tyk2.edges if "_solvent" in t.name)[0] + transformation_sk = user_client.get_scoped_key(transformation, scope_test) + + user_client.create_tasks(transformation_sk, count=3) + + all_tasks = user_client.get_transformation_tasks(transformation_sk) + actioned_tasks = user_client.action_tasks(all_tasks, network_sk) + + # execute the actioned tasks and push results directly using statestore and object store + with tmpdir.as_cwd(): + protocoldagresults = self._execute_tasks( + actioned_tasks, n4js_preloaded, s3os_server + ) + + # make sure that we have reset all stats tracking before the intial pull + assert user_client._cache.stats(reset=True) == (0, 0) + + user_client.get_transformation_results(transformation_sk) + + # we expect three misses, but now the cache has length 3 + assert user_client._cache.stats() == (0, 3) and len(user_client._cache) == 3 + + # clear the in-memory lru cache, to ensure we check the on-disk cache + user_client._async_get_protocoldagresult.cache_clear() + + # running again should now pull results from the on-disk cache + user_client.get_transformation_results(transformation_sk) + + assert user_client._cache.stats() == (3, 3) and len(user_client._cache) == 3 + + # when the alru is not cleared, we should not see misses or hits on the disk cache + # since the alru should populate from the results found on disk + user_client.get_transformation_results(transformation_sk) + + assert user_client._cache.stats() == (3, 3) and len(user_client._cache) == 3 + def test_get_transformation_and_network_results( self, scope_test, From c9a0d78a4178fcc7ecc02b7f97860584699652de Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 29 Apr 2024 09:54:21 -0700 Subject: [PATCH 4/8] Clean up of the test_client.py file --- .../interface/client/test_client.py | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git 
a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index bf229132..7a87a7d2 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -3,17 +3,15 @@ from pathlib import Path from itertools import chain -from gufe import AlchemicalNetwork, ChemicalSystem, Transformation +from gufe import AlchemicalNetwork from gufe.tokenization import TOKENIZABLE_REGISTRY, GufeKey from gufe.protocols.protocoldag import execute_DAG -from gufe.tests.test_protocol import BrokenProtocol import networkx as nx from alchemiscale.models import ScopedKey, Scope from alchemiscale.storage.models import TaskStatusEnum, NetworkStateEnum from alchemiscale.storage.cypher import cypher_list_from_scoped_keys from alchemiscale.interface import client -from alchemiscale.utils import RegistryBackup from alchemiscale.tests.integration.interface.utils import ( get_user_settings_override, ) @@ -38,7 +36,7 @@ def test_refresh_credential( uvicorn_server, ): settings = get_user_settings_override() - assert user_client._jwtoken == None + assert user_client._jwtoken is None user_client._get_token() token = user_client._jwtoken @@ -835,8 +833,6 @@ def test_get_transformation_tasks( user_client: client.AlchemiscaleClient, network_tyk2, ): - n4js = n4js_preloaded - # select the transformation we want to compute an = network_tyk2 transformation = list(an.edges)[0] @@ -931,7 +927,7 @@ def test_get_scope_status( other_scope = Scope("other_org", "other_campaign", "other_project") n4js_preloaded.assemble_network(network_tyk2, other_scope) other_tf_sk = n4js_preloaded.query_transformations(scope=other_scope)[0] - task_sk = n4js_preloaded.create_task(other_tf_sk) + n4js_preloaded.create_task(other_tf_sk) # ask for the scope that we don't have access to status_counts = user_client.get_scope_status(other_scope) @@ -1109,8 +1105,6 @@ def 
test_get_network_actioned_tasks( network_tyk2, get_weights, ): - n4js = n4js_preloaded - an = network_tyk2 transformation = list(an.edges)[0] @@ -1226,7 +1220,6 @@ def test_get_task_actioned_networks( network_tyk2, actioned_tasks, ): - n4js = n4js_preloaded an = network_tyk2 transformation = list(an.edges)[0] @@ -1330,13 +1323,13 @@ def test_action_tasks_with_weights( if shouldfail: with pytest.raises(AlchemiscaleClientError): - actioned_sks = user_client.action_tasks( + user_client.action_tasks( task_sks, network_sk, weight, ) else: - actioned_sks = user_client.action_tasks( + user_client.action_tasks( task_sks, network_sk, weight, @@ -1457,7 +1450,6 @@ def test_set_tasks_status( an = network_tyk2 transformation = list(an.edges)[0] - network_sk = user_client.get_scoped_key(an, scope_test) transformation_sk = user_client.get_scoped_key(transformation, scope_test) all_tasks = user_client.create_tasks(transformation_sk, count=5) @@ -1498,7 +1490,6 @@ def test_get_tasks_status( an = network_tyk2 transformation = list(an.edges)[0] - network_sk = user_client.get_scoped_key(an, scope_test) transformation_sk = user_client.get_scoped_key(transformation, scope_test) all_tasks = user_client.create_tasks(transformation_sk, count=5) @@ -1551,7 +1542,6 @@ def test_set_tasks_priority( an = network_tyk2 transformation = list(an.edges)[0] - network_sk = user_client.get_scoped_key(an, scope_test) transformation_sk = user_client.get_scoped_key(transformation, scope_test) all_tasks = user_client.create_tasks(transformation_sk, count=5) @@ -1582,7 +1572,6 @@ def test_set_tasks_priority_missing_tasks( an = network_tyk2 transformation = list(an.edges)[0] - network_sk = user_client.get_scoped_key(an, scope_test) transformation_sk = user_client.get_scoped_key(transformation, scope_test) all_tasks = user_client.create_tasks(transformation_sk, count=5) @@ -1725,7 +1714,7 @@ def test_get_transformation_and_network_results( transformation_sk = user_client.get_scoped_key(transformation, 
scope_test) # user client : create three independent tasks for the transformation - tasks = user_client.create_tasks(transformation_sk, count=3) + user_client.create_tasks(transformation_sk, count=3) # user client : action the tasks for execution all_tasks = user_client.get_transformation_tasks(transformation_sk) @@ -1811,7 +1800,7 @@ def test_get_transformation_and_network_failures( raise Exception("Network out doesn't exactly match network in yet") else: break - except: + except Exception: sleep(0.1) tf_sks = user_client.get_network_transformations(network_sk) @@ -1950,7 +1939,7 @@ def test_get_task_failures( raise Exception("Network out doesn't exactly match network in yet") else: break - except: + except Exception: sleep(0.1) tf_sks = user_client.get_network_transformations(network_sk) From 471faf984fd6e8c982f01c578a2d9cf659a5d832 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 29 Apr 2024 12:50:18 -0700 Subject: [PATCH 5/8] Enable caching for other objects New objects supported: * Transformations * AlchemicalNetworks * ChemicalSystems * Generally anything that can be a KeyedChain --- alchemiscale/interface/client.py | 51 ++++++-- .../interface/client/test_client.py | 112 ++++++++++++++++-- 2 files changed, 147 insertions(+), 16 deletions(-) diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index d806e537..4690309e 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -491,6 +491,30 @@ def get_chemicalsystem_transformations( f"/chemicalsystems/{chemicalsystem}/transformations" ) + def _get_keyed_chain_resource(self, scopedkey: ScopedKey, get_content_function): + + content = None + + try: + cached_keyed_chain = self._cache.get(str(scopedkey), None).decode("utf-8") + content = json.loads(cached_keyed_chain, cls=JSON_HANDLER.decoder) + # JSON could not decode + except json.JSONDecodeError: + warn( + f"Error decoding cached {scopedkey.__qualname__} ({scopedkey}), deleting entry and retriving new 
content." + ) + self._cache.delete(str(scopedkey)) + # when trying to call the decode method with a None (i.e. cached entry not found) + except AttributeError: + pass + + if content is None: + content = get_content_function() + keyedchain_json = json.dumps(content, cls=JSON_HANDLER.encoder) + self._cache.add(str(scopedkey), keyedchain_json.encode("utf-8")) + + return KeyedChain(content).to_gufe() + @lru_cache(maxsize=100) def get_network( self, @@ -520,9 +544,12 @@ def get_network( """ + if isinstance(network, str): + network = ScopedKey.from_str(network) + def _get_network(): content = self._get_resource(f"/networks/{network}", compress=compress) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -532,12 +559,12 @@ def _get_network(): f"Retrieving [bold]'{network}'[/bold]...", total=None ) - an = _get_network() + an = self._get_keyed_chain_resource(network, _get_network) progress.start_task(task) progress.update(task, total=1, completed=1) else: - an = _get_network() + an = self._get_keyed_chain_resource(network, _get_network) return an @lru_cache(maxsize=10000) @@ -569,11 +596,14 @@ def get_transformation( """ + if isinstance(transformation, str): + transformation = ScopedKey.from_str(transformation) + def _get_transformation(): content = self._get_resource( f"/transformations/{transformation}", compress=compress ) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -583,11 +613,11 @@ def _get_transformation(): f"Retrieving [bold]'{transformation}'[/bold]...", total=None ) - tf = _get_transformation() + tf = self._get_keyed_chain_resource(transformation, _get_transformation) progress.start_task(task) progress.update(task, total=1, completed=1) else: - tf = _get_transformation() + tf = self._get_keyed_chain_resource(transformation, _get_transformation) return tf @@ -620,11 +650,14 @@ def get_chemicalsystem( """ + if isinstance(chemicalsystem, str): + 
chemicalsystem = ScopedKey.from_str(chemicalsystem) + def _get_chemicalsystem(): content = self._get_resource( f"/chemicalsystems/{chemicalsystem}", compress=compress ) - return KeyedChain(content).to_gufe() + return content if visualize: from rich.progress import Progress @@ -634,12 +667,12 @@ def _get_chemicalsystem(): f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None ) - cs = _get_chemicalsystem() + cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem) progress.start_task(task) progress.update(task, total=1, completed=1) else: - cs = _get_chemicalsystem() + cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem) return cs diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 7a87a7d2..fa2e8cc3 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -481,6 +481,37 @@ def test_get_network( assert an == network_tyk2 assert an is network_tyk2 + def test_cached_network( + self, + scope_test, + n4js_preloaded, + network_tyk2, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_network.cache_clear() + + an_sk = user_client.get_scoped_key(network_tyk2, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_network(an_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_network(an_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_network.cache_clear() + + # expect a hit + user_client.get_network(an_sk) + assert 
user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + def test_get_network_weight( self, scope_test, @@ -621,6 +652,37 @@ def test_get_transformation( assert tf == transformation assert tf is transformation + def test_cached_transformation( + self, + scope_test, + n4js_preloaded, + transformation, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_transformation.cache_clear() + + tf_sk = user_client.get_scoped_key(transformation, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_transformation.cache_clear() + + # expect a hit + user_client.get_transformation(tf_sk) + assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + def test_get_chemicalsystem( self, scope_test, @@ -634,6 +696,37 @@ def test_get_chemicalsystem( assert cs == chemicalsystem assert cs is chemicalsystem + def test_cached_chemicalsystem( + self, + scope_test, + n4js_preloaded, + chemicalsystem, + user_client: client.AlchemiscaleClient, + ): + # clear both the on-disk and in-memory cache + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client.get_chemicalsystem.cache_clear() + + cs_sk = user_client.get_scoped_key(chemicalsystem, scope_test) + + # reset stats of cache + assert user_client._cache.stats(enable=True, reset=True) == (0, 0) + + # expect a miss and entry in the cache + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (0, 1) and 
len(user_client._cache) == 1 + + # expect the in-memory lru cache to get the last result pulled + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (0, 1) and len(user_client._cache) == 1 + # clear in-memory cache + user_client.get_chemicalsystem.cache_clear() + + # expect a hit + user_client.get_chemicalsystem(cs_sk) + assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + ### compute def test_create_tasks( @@ -1657,6 +1750,11 @@ def _execute_tasks(tasks, n4js, s3os_server): def test_cached_pdr( self, scope_test, n4js_preloaded, s3os_server, user_client, network_tyk2, tmpdir ): + + user_client._cache.clear() + user_client._cache.stats(reset=True) + user_client._async_get_protocoldagresult.cache_clear() + network_sk = user_client.get_scoped_key(network_tyk2, scope_test) transformation = list(t for t in network_tyk2.edges if "_solvent" in t.name)[0] @@ -1669,31 +1767,31 @@ def test_cached_pdr( # execute the actioned tasks and push results directly using statestore and object store with tmpdir.as_cwd(): - protocoldagresults = self._execute_tasks( - actioned_tasks, n4js_preloaded, s3os_server - ) + self._execute_tasks(actioned_tasks, n4js_preloaded, s3os_server) # make sure that we have reset all stats tracking before the intial pull assert user_client._cache.stats(reset=True) == (0, 0) user_client.get_transformation_results(transformation_sk) - # we expect three misses, but now the cache has length 3 - assert user_client._cache.stats() == (0, 3) and len(user_client._cache) == 3 + # we expect four misses, but now the cache has length 4 + # this is because the cache also captures the transformation, not just the PDRs + assert user_client._cache.stats() == (0, 4) and len(user_client._cache) == 4 # clear the in-memory lru cache, to ensure we check the on-disk cache user_client._async_get_protocoldagresult.cache_clear() + user_client.get_transformation.cache_clear() # running again should now pull results from the on-disk 
cache user_client.get_transformation_results(transformation_sk) - assert user_client._cache.stats() == (3, 3) and len(user_client._cache) == 3 + assert user_client._cache.stats() == (4, 4) and len(user_client._cache) == 4 # when the alru is not cleared, we should not see misses or hits on the disk cache # since the alru should populate from the results found on disk user_client.get_transformation_results(transformation_sk) - assert user_client._cache.stats() == (3, 3) and len(user_client._cache) == 3 + assert user_client._cache.stats() == (4, 4) and len(user_client._cache) == 4 def test_get_transformation_and_network_results( self, From 18c624839c8cf87a1bda49742d7f6208f2eff9d3 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 29 Apr 2024 15:10:33 -0700 Subject: [PATCH 6/8] Test cache data corruption and change test params * With known cached results, corrupt the values and make sure the user is warned that there was a problem with deserialization and that a new result will be downloaded. * Lowered the cache size limit for tests to avoid running out of space --- alchemiscale/interface/client.py | 2 +- .../integration/interface/client/conftest.py | 31 +++++++++------- .../interface/client/test_client.py | 36 +++++++++++++++++++ 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index 4690309e..aaa5d831 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -501,7 +501,7 @@ def _get_keyed_chain_resource(self, scopedkey: ScopedKey, get_content_function): # JSON could not decode except json.JSONDecodeError: warn( - f"Error decoding cached {scopedkey.__qualname__} ({scopedkey}), deleting entry and retriving new content." + f"Error decoding cached {scopedkey.qualname} ({scopedkey}), deleting entry and retriving new content." ) self._cache.delete(str(scopedkey)) # when trying to call the decode method with a None (i.e. 
cached entry not found) diff --git a/alchemiscale/tests/integration/interface/client/conftest.py b/alchemiscale/tests/integration/interface/client/conftest.py index 99784595..fc1a3334 100644 --- a/alchemiscale/tests/integration/interface/client/conftest.py +++ b/alchemiscale/tests/integration/interface/client/conftest.py @@ -49,20 +49,25 @@ def uvicorn_server(user_api): yield +@pytest.fixture(scope="session") +def cache_dir(tmp_path_factory): + cache_dir = tmp_path_factory.mktemp("alchemiscale-cache") + return cache_dir + + @pytest.fixture(scope="module") -def user_client(uvicorn_server, user_identity): - - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - test_client = client.AlchemiscaleClient( - api_url="http://127.0.0.1:8000/", - identifier=user_identity["identifier"], - key=user_identity["key"], - cache_directory=tmpdir, - ) - test_client._cache.stats(enable=True, reset=True) - - return test_client +def user_client(uvicorn_server, user_identity, cache_dir): + + test_client = client.AlchemiscaleClient( + api_url="http://127.0.0.1:8000/", + identifier=user_identity["identifier"], + key=user_identity["key"], + cache_directory=cache_dir, + cache_size_limit=int(1073741824 / 4), + ) + test_client._cache.stats(enable=True, reset=True) + + return test_client @pytest.fixture(scope="module") diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index fa2e8cc3..cf5f1747 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -512,6 +512,18 @@ def test_cached_network( user_client.get_network(an_sk) assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + user_client.get_network.cache_clear() + + # manually invalidate the cached network so it won't deserialize + cached_bytes = user_client._cache.get(str(an_sk)) + corrupted_bytes = 
cached_bytes.replace(b":", b";") + user_client._cache.set(str(an_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {an_sk.qualname}"): + user_client.get_network(an_sk) + + new_cached_bytes = user_client._cache.get(str(an_sk)) + assert new_cached_bytes != corrupted_bytes + def test_get_network_weight( self, scope_test, @@ -683,6 +695,18 @@ def test_cached_transformation( user_client.get_transformation(tf_sk) assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + user_client.get_transformation.cache_clear() + + # manually invalidate the cached transformation so it won't deserialize + cached_bytes = user_client._cache.get(str(tf_sk)) + corrupted_bytes = cached_bytes.replace(b":", b";") + user_client._cache.set(str(tf_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {tf_sk.qualname}"): + user_client.get_transformation(tf_sk) + + new_cached_bytes = user_client._cache.get(str(tf_sk)) + assert new_cached_bytes != corrupted_bytes + def test_get_chemicalsystem( self, scope_test, @@ -727,6 +751,18 @@ def test_cached_chemicalsystem( user_client.get_chemicalsystem(cs_sk) assert user_client._cache.stats() == (1, 1) and len(user_client._cache) == 1 + user_client.get_chemicalsystem.cache_clear() + + # manually invalidate the cached ChemicalSystem so it won't deserialize + cached_bytes = user_client._cache.get(str(cs_sk)) + corrupted_bytes = cached_bytes.replace(b":", b";") + user_client._cache.set(str(cs_sk), corrupted_bytes) + with pytest.warns(UserWarning, match=f"Error decoding cached {cs_sk.qualname}"): + user_client.get_chemicalsystem(cs_sk) + + new_cached_bytes = user_client._cache.get(str(cs_sk)) + assert new_cached_bytes != corrupted_bytes + ### compute def test_create_tasks( From 42366c800ffc87de09df6f437b0c595900c57586 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 30 Apr 2024 11:51:46 -0700 Subject: [PATCH 7/8] Clean up of interface/client/conftest.py * Removed unused
imports --- alchemiscale/tests/integration/interface/client/conftest.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/alchemiscale/tests/integration/interface/client/conftest.py b/alchemiscale/tests/integration/interface/client/conftest.py index fc1a3334..d633855a 100644 --- a/alchemiscale/tests/integration/interface/client/conftest.py +++ b/alchemiscale/tests/integration/interface/client/conftest.py @@ -1,14 +1,10 @@ import pytest from copy import copy -from time import sleep -import tempfile -from pathlib import Path import uvicorn -import requests from alchemiscale.settings import get_base_api_settings -from alchemiscale.base.api import get_n4js_depends, get_s3os_depends +from alchemiscale.base.api import get_s3os_depends from alchemiscale.interface import api, client from alchemiscale.tests.integration.interface.utils import get_user_settings_override From 3975cd459d5e81851ccd3d0bcc25ad517d4eccad Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 30 Apr 2024 13:50:17 -0700 Subject: [PATCH 8/8] Updated cache parameter handling in AlchemiscaleBaseClient The AlchemiscaleBaseClient now determines the cache directory when one is not specified directly (i.e. a None is provided to the AlchemiscaleBaseClient constructor). When a path to this directory is provided, it must be a string or pathlib.Path object. The logic for this operation lies in the `AlchemiscaleBaseClient._determine_cache_dir` method, which can raise a TypeError on invalid input. The `cache_size_limit` is now verified within the constructor to be >= 0. If it is not, then a ValueError is raised. New tests have been added for the above changes: * Negative cache_size_limit: checks for constructor-raised ValueError with a meaningful message. * cache_directory is None: checks output of the underlying _determine_cache_dir method with and without the XDG_CACHE_HOME environment variable. 
If we test it with the client constructor, the directory is made automatically, which we don't want in the tests as it may touch real data. * cache_directory is not None, str, or Path: Check that the constructor raises a TypeError with a meaningful message. --- alchemiscale/base/client.py | 33 +++++++++++++++-- .../interface/client/test_client.py | 35 +++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/alchemiscale/base/client.py b/alchemiscale/base/client.py index d2b1bce7..be5179af 100644 --- a/alchemiscale/base/client.py +++ b/alchemiscale/base/client.py @@ -13,6 +13,8 @@ from functools import wraps import gzip from pathlib import Path +import os +from typing import Union, Optional from diskcache import Cache import requests @@ -61,7 +63,7 @@ def __init__( api_url: str, identifier: str, key: str, - cache_directory=Path.home() / ".cache" / "alchemiscale", + cache_directory: Optional[Union[Path, str]] = None, cache_size_limit: int = 1073741824, max_retries: int = 5, retry_base_seconds: float = 2.0, @@ -79,7 +81,10 @@ def __init__( key Credential for the identity used for authentication. cache_directory - Location of the cache directory. Defaults to `${HOME}/.cache/alchemiscale`. + Location of the cache directory as either a `pathlib.Path` or `str`. + If `None` is provided then the directory will be determined via the + `XDG_CACHE_HOME` environment variable or default to + `${HOME}/.cache/alchemiscale`. Defaults to `None`. cache_size_limit Maximum size of the client cache. Defaults to 1 GB. max_retries @@ -117,12 +122,34 @@ def __init__( self._session = None self._lock = None + if cache_size_limit < 0: + raise ValueError( + "`cache_size_limit` must be greater than or equal to zero." 
+ ) + self._cache = Cache( - cache_directory, + self._determine_cache_dir(cache_directory), size_limit=cache_size_limit, eviction_policy="least-recently-used", ) + @staticmethod + def _determine_cache_dir(cache_directory: Optional[Union[Path, str]]): + if not (isinstance(cache_directory, (Path, str)) or cache_directory is None): + raise TypeError( + "`cache_directory` must be a `str`, `pathlib.Path`, or `None`." + ) + + if cache_directory is None: + default_dir = Path().home() / ".cache" + cache_directory = ( + Path(os.getenv("XDG_CACHE_HOME", default_dir)) / "alchemiscale" + ) + else: + cache_directory = Path(cache_directory) + + return cache_directory.absolute() + def _settings(self): return dict( api_url=self.api_url, diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 3be303a6..11106ae7 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -1,5 +1,6 @@ import pytest from time import sleep +import os from pathlib import Path from itertools import chain @@ -19,6 +20,40 @@ class TestClient: + def test_cache_size_limit_negative( + self, user_client: client.AlchemiscaleBaseClient + ): + settings = user_client._settings() + settings["cache_size_limit"] = -1 + with pytest.raises( + ValueError, + match="`cache_size_limit` must be greater than or equal to zero.", + ): + client.AlchemiscaleClient(**settings) + + def test_cache_dir_not_path_str_none(self, user_client: client.AlchemiscaleClient): + settings = user_client._settings() + settings["cache_directory"] = 0 + with pytest.raises( + TypeError, + match="`cache_directory` must be a `str`, `pathlib.Path`, or `None`.", + ): + client.AlchemiscaleClient(**settings) + + # here we test the AlchemiscaleClient._determine_cache_dir + # so we don't create non-temporary files on the testing platform + def test_cache_dir_none(self): + # set 
custom XDG_CACHE_HOME + target_dir = Path().home() / ".other_cache" + os.environ["XDG_CACHE_HOME"] = str(target_dir) + cache_dir = client.AlchemiscaleClient._determine_cache_dir(None) + assert cache_dir == target_dir.absolute() / "alchemiscale" + + # remove the env variable to get the default directory location + os.environ.pop("XDG_CACHE_HOME", None) + cache_dir = client.AlchemiscaleClient._determine_cache_dir(None) + assert cache_dir == Path().home() / ".cache" / "alchemiscale" + def test_wrong_credential( self, scope_test,