From 7382ef7c9a7b435f45587afe36a10f4373b58eeb Mon Sep 17 00:00:00 2001 From: m-rauen Date: Wed, 4 Dec 2024 18:05:28 -0300 Subject: [PATCH] change ignore to copy strategy for cache + add docstring --- overreact/_misc.py | 39 +++++++++++++++++++++++---------------- overreact/coords.py | 4 +--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/overreact/_misc.py b/overreact/_misc.py index 2cf3ba8..83f5e62 100644 --- a/overreact/_misc.py +++ b/overreact/_misc.py @@ -6,8 +6,8 @@ from __future__ import annotations import contextlib -import functools from functools import lru_cache as cache +from copy import deepcopy import numpy as np from scipy.stats import cauchy, norm @@ -15,22 +15,29 @@ import overreact as rx from overreact import _constants as constants -def ignore_unhashable(func): - """ +def copy_unhashable(maxsize=100000, typed=False): + """Creates a copy of the arrays received by lru_cache and make them hashable, therefore maintaining the arrays to be passed and caching prototypes of those arrays. + + Insipired by: + + + Parameters + ---------- + maxsize : int + typed : bool + If true, function arguments of different types will be cached separately. + + Returns + -------- + function """ - uncached = func.__wrapped__ - attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear') - @functools.wraps(func, assigned=attributes) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except TypeError as error: - if 'unhashable type' in str(error): - return uncached(*args, **kwargs) - raise - wrapper.__uncached__ = uncached - return wrapper - + def decorator(func): + cached_func = cache(maxsize=maxsize, typed=typed)(func) + def wrapper(*args, **kwargs): + return deepcopy(cached_func(*args, **kwargs)) + return wrapper + return decorator + def _find_package(package): """Check if a package exists without importing it. diff --git a/overreact/coords.py b/overreact/coords.py index 6a27dc4..3e2f545 100644 --- a/overreact/coords.py +++ b/overreact/coords.py @@ -2,7 +2,6 @@ # TODO(schneiderfelipe): add types to this module from __future__ import annotations -from functools import lru_cache as cache __all__ = ["find_point_group", "symmetry_number"] @@ -1683,8 +1682,7 @@ def gyradius(atommasses, atomcoords, method="iupac"): raise ValueError(msg) -@rx._misc.ignore_unhashable -@cache() +@rx._misc.copy_unhashable def inertia(atommasses, atomcoords, align=True): r"""Calculate primary moments and axes from the inertia tensor.