Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use diskcache for caching ProtocolDAGResults in the Alchemiscale client #271

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
56 changes: 48 additions & 8 deletions alchemiscale/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@
import time
import random
from itertools import islice
from typing import List
import json
from urllib.parse import urljoin
from functools import wraps
import gzip
from pathlib import Path
import os
from typing import Union, Optional
from diskcache import Cache

import requests
import httpx

from gufe.tokenization import GufeTokenizable, JSON_HANDLER

from ..models import Scope, ScopedKey
from ..storage.models import TaskHub, Task
from ..models import ScopedKey


def json_to_gufe(jsondata):
Expand Down Expand Up @@ -61,6 +63,8 @@ def __init__(
api_url: str,
identifier: str,
key: str,
cache_directory: Optional[Union[Path, str]] = None,
cache_size_limit: int = 1073741824,
max_retries: int = 5,
retry_base_seconds: float = 2.0,
retry_max_seconds: float = 60.0,
Expand All @@ -76,6 +80,13 @@ def __init__(
Identifier for the identity used for authentication.
key
Credential for the identity used for authentication.
cache_directory
Location of the cache directory as either a `pathlib.Path` or `str`.
If `None` is provided then the directory will be determined via the
`XDG_CACHE_HOME` environment variable or default to
`${HOME}/.cache/alchemiscale`. Defaults to `None`.
cache_size_limit
Maximum size of the client cache. Defaults to 1 GB.
max_retries
Maximum number of times to retry a request. In the case the API
service is unresponsive an exponential backoff is applied with
Expand Down Expand Up @@ -111,9 +122,39 @@ def __init__(
self._session = None
self._lock = None

if cache_size_limit < 0:
raise ValueError(
"`cache_size_limit` must be greater than or equal to zero."
)

self._cache = Cache(
self._determine_cache_dir(cache_directory),
size_limit=cache_size_limit,
eviction_policy="least-recently-used",
)

@staticmethod
def _determine_cache_dir(cache_directory: Optional[Union[Path, str]]):
if not (isinstance(cache_directory, (Path, str)) or cache_directory is None):
raise TypeError(
"`cache_directory` must be a `str`, `pathlib.Path`, or `None`."
)

if cache_directory is None:
default_dir = Path().home() / ".cache"
cache_directory = (
Path(os.getenv("XDG_CACHE_HOME", default_dir)) / "alchemiscale"
)
else:
cache_directory = Path(cache_directory)

return cache_directory.absolute()

def _settings(self):
return dict(
api_url=self.api_url,
cache_directory=self._cache.directory,
cache_size_limit=self._cache.size_limit,
identifier=self.identifier,
key=self.key,
max_retries=self.max_retries,
Expand Down Expand Up @@ -361,7 +402,7 @@ def _get_resource(self, resource, params=None, compress=False):
if not 200 <= resp.status_code < 300:
try:
detail = resp.json()["detail"]
except:
except Exception:
detail = resp.text
raise self._exception(
f"Status Code {resp.status_code} : {resp.reason} : {detail}",
Expand Down Expand Up @@ -396,7 +437,7 @@ async def _get_resource_async(self, resource, params=None, compress=False):
if not 200 <= resp.status_code < 300:
try:
detail = resp.json()["detail"]
except:
except Exception:
detail = resp.text
raise self._exception(
f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}",
Expand Down Expand Up @@ -442,7 +483,7 @@ def _post(self, url, headers, data):
if not 200 <= resp.status_code < 300:
try:
detail = resp.json()["detail"]
except:
except Exception:
detail = resp.text
raise self._exception(
f"Status Code {resp.status_code} : {resp.reason} : {detail}",
Expand All @@ -466,7 +507,7 @@ async def _post_resource_async(self, resource, data):
if not 200 <= resp.status_code < 300:
try:
detail = resp.json()["detail"]
except:
except Exception:
detail = resp.text
raise self._exception(
f"Status Code {resp.status_code} : {resp.reason_phrase} : {detail}",
Expand Down Expand Up @@ -498,7 +539,6 @@ def _rich_waiting_columns():
@staticmethod
def _rich_progress_columns():
from rich.progress import (
Progress,
SpinnerColumn,
MofNCompleteColumn,
TextColumn,
Expand Down
76 changes: 60 additions & 16 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,30 @@ def get_chemicalsystem_transformations(
f"/chemicalsystems/{chemicalsystem}/transformations"
)

def _get_keyed_chain_resource(self, scopedkey: ScopedKey, get_content_function):
    """Return the GufeTokenizable for `scopedkey`, using the disk cache.

    The cache stores the KeyedChain JSON (UTF-8 encoded) keyed by the
    stringified ScopedKey. On a cache miss or a corrupt cached entry,
    `get_content_function` is called to fetch fresh content, which is
    then cached.

    Parameters
    ----------
    scopedkey
        Key identifying the resource to retrieve.
    get_content_function
        Zero-argument callable returning the KeyedChain content from
        the server.

    Returns
    -------
    GufeTokenizable
        The deserialized object.
    """
    content = None

    # explicit None check instead of catching AttributeError from
    # `None.decode(...)`, which could mask unrelated attribute errors
    cached_bytes = self._cache.get(str(scopedkey), None)
    if cached_bytes is not None:
        try:
            content = json.loads(
                cached_bytes.decode("utf-8"), cls=JSON_HANDLER.decoder
            )
        except json.JSONDecodeError:
            # corrupt cache entry: drop it and fall through to a fresh fetch
            warn(
                f"Error decoding cached {scopedkey.qualname} ({scopedkey}), "
                "deleting entry and retrieving new content."
            )
            self._cache.delete(str(scopedkey))

    if content is None:
        content = get_content_function()
        keyedchain_json = json.dumps(content, cls=JSON_HANDLER.encoder)
        self._cache.add(str(scopedkey), keyedchain_json.encode("utf-8"))

    return KeyedChain(content).to_gufe()

@lru_cache(maxsize=100)
def get_network(
self,
Expand Down Expand Up @@ -520,9 +544,12 @@ def get_network(

"""

if isinstance(network, str):
network = ScopedKey.from_str(network)

def _get_network():
content = self._get_resource(f"/networks/{network}", compress=compress)
return KeyedChain(content).to_gufe()
return content

if visualize:
from rich.progress import Progress
Expand All @@ -532,12 +559,12 @@ def _get_network():
f"Retrieving [bold]'{network}'[/bold]...", total=None
)

an = _get_network()
an = self._get_keyed_chain_resource(network, _get_network)

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
an = _get_network()
an = self._get_keyed_chain_resource(network, _get_network)
return an

@lru_cache(maxsize=10000)
Expand Down Expand Up @@ -569,11 +596,14 @@ def get_transformation(

"""

if isinstance(transformation, str):
transformation = ScopedKey.from_str(transformation)

def _get_transformation():
content = self._get_resource(
f"/transformations/{transformation}", compress=compress
)
return KeyedChain(content).to_gufe()
return content

if visualize:
from rich.progress import Progress
Expand All @@ -583,11 +613,11 @@ def _get_transformation():
f"Retrieving [bold]'{transformation}'[/bold]...", total=None
)

tf = _get_transformation()
tf = self._get_keyed_chain_resource(transformation, _get_transformation)
progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
tf = _get_transformation()
tf = self._get_keyed_chain_resource(transformation, _get_transformation)

return tf

Expand Down Expand Up @@ -620,11 +650,14 @@ def get_chemicalsystem(

"""

if isinstance(chemicalsystem, str):
chemicalsystem = ScopedKey.from_str(chemicalsystem)

def _get_chemicalsystem():
content = self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
return KeyedChain(content).to_gufe()
return content

if visualize:
from rich.progress import Progress
Expand All @@ -634,12 +667,12 @@ def _get_chemicalsystem():
f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None
)

cs = _get_chemicalsystem()
cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem)

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
cs = _get_chemicalsystem()
cs = self._get_keyed_chain_resource(chemicalsystem, _get_chemicalsystem)

return cs

Expand Down Expand Up @@ -1353,14 +1386,25 @@ def get_tasks_priority(
async def _async_get_protocoldagresult(
    self, protocoldagresultref, transformation, route, compress
):
    """Retrieve a single ProtocolDAGResult, using the disk cache.

    The serialized PDR JSON is cached (UTF-8 encoded) under the
    stringified `protocoldagresultref`; on a cache miss it is fetched
    from the alchemiscale server and then cached.

    Returns
    -------
    GufeTokenizable
        The deserialized ProtocolDAGResult.
    """
    # check the disk cache for the serialized PDR
    if pdr_json := self._cache.get(str(protocoldagresultref)):
        pdr_json = pdr_json.decode("utf-8")
    else:
        # query the alchemiscale server for the PDR
        pdr_json = await self._get_resource_async(
            f"/transformations/{transformation}/{route}/{protocoldagresultref}",
            compress=compress,
        )

        pdr_json = pdr_json[0]

        # cache only freshly fetched content; on a hit the entry is
        # already present, so re-adding would be a wasted no-op
        self._cache.add(
            str(protocoldagresultref),
            pdr_json.encode("utf-8"),
        )

    pdr = GufeTokenizable.from_dict(json.loads(pdr_json, cls=JSON_HANDLER.decoder))

    return pdr

Expand Down
20 changes: 15 additions & 5 deletions alchemiscale/tests/integration/interface/client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest
from copy import copy
from time import sleep

import uvicorn
import requests

from alchemiscale.settings import get_base_api_settings
from alchemiscale.base.api import get_n4js_depends, get_s3os_depends
from alchemiscale.base.api import get_s3os_depends
from alchemiscale.interface import api, client

from alchemiscale.tests.integration.interface.utils import get_user_settings_override
Expand Down Expand Up @@ -47,13 +45,25 @@ def uvicorn_server(user_api):
yield


@pytest.fixture(scope="session")
def cache_dir(tmp_path_factory):
    """Session-scoped temporary directory backing the client disk cache."""
    return tmp_path_factory.mktemp("alchemiscale-cache")


@pytest.fixture(scope="module")
def user_client(uvicorn_server, user_identity, cache_dir):
    """Module-scoped AlchemiscaleClient pointed at the test server,
    with a bounded disk cache and cache statistics enabled."""
    client_kwargs = dict(
        api_url="http://127.0.0.1:8000/",
        identifier=user_identity["identifier"],
        key=user_identity["key"],
        cache_directory=cache_dir,
        cache_size_limit=int(1073741824 / 4),
    )
    alchemiscale_client = client.AlchemiscaleClient(**client_kwargs)

    # enable hit/miss statistics so tests can assert on cache behavior
    alchemiscale_client._cache.stats(enable=True, reset=True)

    return alchemiscale_client


@pytest.fixture(scope="module")
Expand Down
Loading
Loading