Skip to content

Commit

Permalink
change ignore to copy strategy for cache + add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
m-rauen committed Dec 4, 2024
1 parent 443317f commit 7382ef7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
39 changes: 23 additions & 16 deletions overreact/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,38 @@
from __future__ import annotations

import contextlib
import functools
from functools import lru_cache as cache
from copy import deepcopy

import numpy as np
from scipy.stats import cauchy, norm

import overreact as rx
from overreact import _constants as constants

def ignore_unhashable(func):
"""
def copy_unhashable(maxsize=100000, typed=False):
"""Creates a copy of the arrays received by lru_cache and make them hashable, therefore maintaining the arrays to be passed and caching prototypes of those arrays.
Insipired by:
<https://stackoverflow.com/a/54909677/21189559>
Parameters
----------
maxsize : int
typed : bool
If true, function arguments of different types will be cached separately.
Returns
--------
function
"""
uncached = func.__wrapped__
attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear')
@functools.wraps(func, assigned=attributes)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TypeError as error:
if 'unhashable type' in str(error):
return uncached(*args, **kwargs)
raise
wrapper.__uncached__ = uncached
return wrapper

def decorator(func):
cached_func = cache(maxsize=maxsize, typed=typed)(func)
def wrapper(*args, **kwargs):
return deepcopy(cached_func(*args, **kwargs))
return wrapper
return decorator

def _find_package(package):
"""Check if a package exists without importing it.
Expand Down
4 changes: 1 addition & 3 deletions overreact/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# TODO(schneiderfelipe): add types to this module
from __future__ import annotations
from functools import lru_cache as cache

__all__ = ["find_point_group", "symmetry_number"]

Expand Down Expand Up @@ -1683,8 +1682,7 @@ def gyradius(atommasses, atomcoords, method="iupac"):
raise ValueError(msg)


@rx._misc.ignore_unhashable
@cache()
@rx._misc.copy_unhashable
def inertia(atommasses, atomcoords, align=True):
r"""Calculate primary moments and axes from the inertia tensor.
Expand Down

0 comments on commit 7382ef7

Please sign in to comment.