Skip to content

Commit

Permalink
move ignore_unhashable to _misc module + trying to resolve circular i…
Browse files Browse the repository at this point in the history
…mport
  • Loading branch information
m-rauen committed Dec 3, 2024
1 parent 6696fc2 commit ae6ab1c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
14 changes: 14 additions & 0 deletions overreact/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import contextlib
import functools
from functools import lru_cache as cache

import numpy as np
Expand All @@ -14,6 +15,19 @@
import overreact as rx
from overreact import _constants as constants

def ignore_unhashable(func):
uncached = func.__wrapped__
attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear')
@functools.wraps(func, assigned=attributes)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TypeError as error:
if 'unhashable type' in str(error):
return uncached(*args, **kwargs)
raise
wrapper.__uncached__ = uncached
return wrapper

def _find_package(package):
"""Check if a package exists without importing it.
Expand Down
20 changes: 4 additions & 16 deletions overreact/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# TODO(schneiderfelipe): add types to this module
from __future__ import annotations
import functools
from functools import lru_cache as cache

__all__ = ["find_point_group", "symmetry_number"]

Expand Down Expand Up @@ -1681,21 +1681,9 @@ def gyradius(atommasses, atomcoords, method="iupac"):
msg = f"unavailable method: '{method}'"
raise ValueError(msg)

def ignore_unhashable(func):
uncached = func.__wrapped__
attributes = functools.WRAPPER_ASSIGNMENTS + ('cache_info', 'cache_clear')
@functools.wraps(func, assigned=attributes)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TypeError as error:
if 'unhashable type' in str(error):
return uncached(*args, **kwargs)
raise
wrapper.__uncached__ = uncached
return wrapper
@ignore_unhashable
@functools.lru_cache()

@rx._misc.ignore_unhashable
@cache()
def inertia(atommasses, atomcoords, align=True):
r"""Calculate primary moments and axes from the inertia tensor.
Expand Down

0 comments on commit ae6ab1c

Please sign in to comment.