Skip to content

Commit

Permalink
drop import guard for numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Oct 20, 2024
1 parent 78f16e5 commit fe58998
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 32 deletions.
5 changes: 1 addition & 4 deletions src/monty/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import itertools
from typing import TYPE_CHECKING

try:
import numpy as np
except ImportError:
np = None
import numpy as np

if TYPE_CHECKING:
from typing import Iterable
Expand Down
38 changes: 18 additions & 20 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,9 @@
from typing import Any
from uuid import UUID, uuid4

import numpy as np
from ruamel.yaml import YAML

try:
import numpy as np
except ImportError:
np = None

try:
import pydantic
except ImportError:
Expand Down Expand Up @@ -595,23 +591,22 @@ def default(self, o) -> dict:
d["data"] = o.numpy().tolist()
return d

if np is not None:
if isinstance(o, np.ndarray):
if str(o.dtype).startswith("complex"):
return {
"@module": "numpy",
"@class": "array",
"dtype": str(o.dtype),
"data": [o.real.tolist(), o.imag.tolist()],
}
if isinstance(o, np.ndarray):
if str(o.dtype).startswith("complex"):
return {
"@module": "numpy",
"@class": "array",
"dtype": str(o.dtype),
"data": o.tolist(),
"data": [o.real.tolist(), o.imag.tolist()],
}
if isinstance(o, np.generic):
return o.item()
return {
"@module": "numpy",
"@class": "array",
"dtype": str(o.dtype),
"data": o.tolist(),
}
if isinstance(o, np.generic):
return o.item()

if _check_type(o, "pandas.core.frame.DataFrame"):
return {
Expand Down Expand Up @@ -809,7 +804,7 @@ def process_decoded(self, d):
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101

elif np is not None and modname == "numpy" and classname == "array":
elif modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
return np.array(
[
Expand Down Expand Up @@ -932,7 +927,8 @@ def jsanitize(
)
for i in obj
]
if np is not None and isinstance(obj, np.ndarray):

if isinstance(obj, np.ndarray):
try:
return [
jsanitize(
Expand All @@ -946,8 +942,10 @@ def jsanitize(
]
except TypeError:
return obj.tolist()
if np is not None and isinstance(obj, np.generic):

if isinstance(obj, np.generic):
return obj.item()

if _check_type(
obj,
(
Expand Down
10 changes: 2 additions & 8 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from enum import Enum
from typing import Union

try:
import numpy as np
except ImportError:
np = None
import numpy as np

try:
import pandas as pd
Expand Down Expand Up @@ -564,7 +561,6 @@ def test_nan(self):
d = json.loads(djson)
assert isinstance(d[0], float)

@pytest.mark.skipif(np is None, reason="numpy not present")
def test_numpy(self):
x = np.array([1, 2, 3], dtype="int64")
with pytest.raises(TypeError):
Expand Down Expand Up @@ -872,9 +868,7 @@ def test_jsanitize_pandas(self):
clean = jsanitize(s)
assert clean == s.to_dict()

@pytest.mark.skipif(
np is None or ObjectId is None, reason="numpy and bson not present"
)
@pytest.mark.skipif(ObjectId is None, reason="bson not present")
def test_jsanitize_numpy_bson(self):
d = {
"a": ["b", np.array([1, 2, 3])],
Expand Down

0 comments on commit fe58998

Please sign in to comment.