diff --git a/overreact/_misc.py b/overreact/_misc.py index cacaa5e..82a3ae4 100644 --- a/overreact/_misc.py +++ b/overreact/_misc.py @@ -6,6 +6,7 @@ from __future__ import annotations import contextlib +import functools from functools import lru_cache as cache import numpy as np @@ -14,6 +15,19 @@ import overreact as rx from overreact import _constants as constants +def ignore_unhashable(func): + uncached = func.__wrapped__ + attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear') + @functools.wraps(func, assigned=attributes) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except TypeError as error: + if 'unhashable type' in str(error): + return uncached(*args, **kwargs) + raise + wrapper.__uncached__ = uncached + return wrapper def _find_package(package): """Check if a package exists without importing it. diff --git a/overreact/coords.py b/overreact/coords.py index cb54e84..9603bb4 100644 --- a/overreact/coords.py +++ b/overreact/coords.py @@ -2,7 +2,7 @@ # TODO(schneiderfelipe): add types to this module from __future__ import annotations -import functools +from functools import lru_cache as cache __all__ = ["find_point_group", "symmetry_number"] @@ -1681,21 +1681,9 @@ def gyradius(atommasses, atomcoords, method="iupac"): msg = f"unavailable method: '{method}'" raise ValueError(msg) -def ignore_unhashable(func): - uncached = func.__wrapped__ - attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear') - @functools.wraps(func, assigned=attributes) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except TypeError as error: - if 'unhashable type' in str(error): - return uncached(*args, **kwargs) - raise - wrapper.__uncached__ = uncached - return wrapper -@ignore_unhashable -@functools.lru_cache() + +@rx._misc.ignore_unhashable +@cache() def inertia(atommasses, atomcoords, align=True): r"""Calculate primary moments and axes from the inertia tensor.