Skip to content

ENH: Cache refactoring #1634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions astroquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging

from .logger import _init_log
from astropy import config as _config

__all__ = ["__version__", "__githash__", "__citation__", "__bibtex__", "test", "log"]

Expand All @@ -38,3 +39,23 @@ def _get_bibtex():
logging.addLevelName(5, "TRACE")
log = logging.getLogger()
log = _init_log()


# Set up cache configuration
class Cache_Conf(_config.ConfigNamespace):

cache_timeout = _config.ConfigItem(
604800,
('Astroquery-wide cache timeout (seconds). Default is 1 week (604800). '
'Setting to None prevents the cache from expiring (not recommended).'),
cfgtype='integer'
)

cache_active = _config.ConfigItem(
True,
"Astroquery global cache usage, False turns off all caching.",
cfgtype='boolean'
)


cache_conf = Cache_Conf()
92 changes: 66 additions & 26 deletions astroquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,31 @@
import requests
import textwrap

from datetime import datetime, timedelta
from pathlib import Path

from astropy.config import paths
from astroquery import log
import astropy.units as u
from astropy.utils.console import ProgressBarOrSpinner
import astropy.utils.data
from astropy.utils import deprecated

from astroquery import version, log, cache_conf
from astroquery.utils import system_tools

from . import version
from .utils import system_tools

__all__ = ['BaseQuery', 'QueryWithLogin']


def to_cache(response, cache_file):
log.debug("Caching data to {0}".format(cache_file))

response = copy.deepcopy(response)
if hasattr(response, 'request'):
for key in tuple(response.request.hooks.keys()):
del response.request.hooks[key]
with open(cache_file, "wb") as f:
pickle.dump(response, f)
pickle.dump(response, f, protocol=4)


def _replace_none_iterable(iterable):
Expand Down Expand Up @@ -102,20 +107,30 @@ def hash(self):
return self._hash

def request_file(self, cache_location):
fn = os.path.join(cache_location, self.hash() + ".pickle")
fn = cache_location.joinpath(self.hash() + ".pickle")
return fn

def from_cache(self, cache_location):
def from_cache(self, cache_location, cache_timeout):
request_file = self.request_file(cache_location)
try:
with open(request_file, "rb") as f:
response = pickle.load(f)
if not isinstance(response, requests.Response):
if cache_timeout is None:
expired = False
else:
current_time = datetime.utcnow()
cache_time = datetime.utcfromtimestamp(request_file.stat().st_mtime)
expired = current_time-cache_time > timedelta(seconds=cache_timeout)
if not expired:
with open(request_file, "rb") as f:
response = pickle.load(f)
if not isinstance(response, requests.Response):
response = None
else:
log.debug(f"Cache expired for {request_file}...")
response = None
except FileNotFoundError:
response = None
if response:
log.debug("Retrieving data from {0}".format(request_file))
log.debug("Retrieved data from {0}".format(request_file))
return response

def remove_cache_file(self, cache_location):
Expand All @@ -125,8 +140,8 @@ def remove_cache_file(self, cache_location):
"""
request_file = self.request_file(cache_location)

if os.path.exists(request_file):
os.remove(request_file)
if request_file.exists:
request_file.unlink()
else:
raise FileNotFoundError(f"Tried to remove cache file {request_file} but "
"it does not exist")
Expand Down Expand Up @@ -173,11 +188,8 @@ def __init__(self):
.format(vers=version.version,
olduseragent=S.headers['User-Agent']))

self.cache_location = os.path.join(
paths.get_cache_dir(), 'astroquery',
self.__class__.__name__.split("Class")[0])
os.makedirs(self.cache_location, exist_ok=True)
self._cache_active = True
self.name = self.__class__.__name__.split("Class")[0]
self._cache_location = None

def __call__(self, *args, **kwargs):
""" init a fresh copy of self """
Expand Down Expand Up @@ -217,9 +229,28 @@ def _response_hook(self, response, *args, **kwargs):
f"-----------------------------------------", '\t')
log.log(5, f"HTTP response\n{response_log}")

@property
def cache_location(self):
cl = self._cache_location or Path(paths.get_cache_dir(), 'astroquery', self.name)
cl.mkdir(parents=True, exist_ok=True)
return cl

@cache_location.setter
def cache_location(self, loc):
self._cache_location = Path(loc)

def reset_cache_location(self):
"""Resets the cache location to the default astropy cache"""
self._cache_location = None

def clear_cache(self):
"""Removes all cache files."""
for fle in self.cache_location.glob("*.pickle"):
fle.unlink()

def _request(self, method, url,
params=None, data=None, headers=None,
files=None, save=False, savedir='', timeout=None, cache=True,
files=None, save=False, savedir='', timeout=None, cache=None,
stream=False, auth=None, continuation=True, verify=True,
allow_redirects=True,
json=None, return_response_on_save=False):
Expand Down Expand Up @@ -253,6 +284,7 @@ def _request(self, method, url,
somewhere other than `BaseQuery.cache_location`
timeout : int
cache : bool
Optional, if specified, overrides global cache settings.
verify : bool
Verify the server's TLS certificate?
(see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
Expand All @@ -278,12 +310,16 @@ def _request(self, method, url,
is True.
"""

if cache is None: # Global caching not overridden
cache = cache_conf.cache_active

if save:
local_filename = url.split('/')[-1]
if os.name == 'nt':
# Windows doesn't allow special characters in filenames like
# ":" so replace them with an underscore
local_filename = local_filename.replace(':', '_')

local_filepath = os.path.join(savedir or self.cache_location or '.', local_filename)

response = self._download_file(url, local_filepath, cache=cache, timeout=timeout,
Expand All @@ -298,14 +334,14 @@ def _request(self, method, url,
else:
query = AstroQuery(method, url, params=params, data=data, headers=headers,
files=files, timeout=timeout, json=json)
if ((self.cache_location is None) or (not self._cache_active) or (not cache)):
with suspend_cache(self):
if not cache:
with cache_conf.set_temp("cache_active", False):
response = query.request(self._session, stream=stream,
auth=auth, verify=verify,
allow_redirects=allow_redirects,
json=json)
else:
response = query.from_cache(self.cache_location)
response = query.from_cache(self.cache_location, cache_conf.cache_timeout)
if not response:
response = query.request(self._session,
self.cache_location,
Expand All @@ -315,6 +351,7 @@ def _request(self, method, url,
verify=verify,
json=json)
to_cache(response, query.request_file(self.cache_location))

self._last_query = query
return response

Expand All @@ -336,6 +373,7 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
supports HTTP "range" requests, the download will be continued
where it left off.
cache : bool
Cache downloaded file. Defaults to False.
method : "GET" or "POST"
head_safe : bool
"""
Expand Down Expand Up @@ -439,19 +477,21 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
return response


@deprecated(since="v0.4.7", message=("The suspend_cache function is deprecated,"
"Use the conf set_temp function instead."))
class suspend_cache:
"""
A context manager that suspends caching.
"""

def __init__(self, obj):
self.obj = obj
def __init__(self, obj=None):
self.original_cache_setting = cache_conf.cache_active

def __enter__(self):
self.obj._cache_active = False
cache_conf.cache_active = False

def __exit__(self, exc_type, exc_value, traceback):
self.obj._cache_active = True
cache_conf.cache_active = self.original_cache_setting
return False


Expand Down Expand Up @@ -507,7 +547,7 @@ def _login(self, *args, **kwargs):
pass

def login(self, *args, **kwargs):
with suspend_cache(self):
with cache_conf.set_temp("cache_active", False):
self._authenticated = self._login(*args, **kwargs)
return self._authenticated

Expand Down
Loading