Skip to content

Commit

Permalink
Add NumpyToNativeEncoder with a test
Browse files Browse the repository at this point in the history
NumpyToNativeEncoder is a custom JSON encoder that encodes numpy types
into native types so it can be saved in JSON.

There is no corresponding decoder, since the JSON just contains native
types and does not need special decoding.

Signed-off-by: Patrick Avery <[email protected]>
  • Loading branch information
psavery committed Nov 24, 2023
1 parent 44593cb commit 988e5cc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
12 changes: 11 additions & 1 deletion hexrd/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
ndarray_key = '!__hexrd_ndarray__'


class NumpyToNativeEncoder(json.JSONEncoder):
# Change all Numpy arrays to native types during JSON encoding
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.generic, np.number)):
return obj.item()

return super().default(obj)


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
Expand All @@ -20,7 +31,6 @@ def default(self, obj):
ndarray_key: base64.b64encode(data).decode('ascii')
}

# Let the base class default method raise the TypeError
return super().default(obj)


Expand Down
32 changes: 31 additions & 1 deletion tests/test_utils_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from hexrd.utils.json import NumpyDecoder, NumpyEncoder
from hexrd.utils.json import NumpyDecoder, NumpyEncoder, NumpyToNativeEncoder


def test_decode_encode():
Expand All @@ -29,6 +29,36 @@ def test_decode_encode():
assert _json_equal(to_test, converted_back)


def test_numpy_to_native():
to_test = {
'inside': np.arange(27, dtype=np.int8).reshape((3, 3, 3)),
'nested': {
'float': np.arange(4, dtype=np.float32).reshape((2, 2)),
},
'float': np.float64(8.3),
}

inside_to_list = to_test['inside'].tolist()
nested_float_to_value = to_test['nested']['float'].tolist()
float_to_value = to_test['float'].item()

encoded = json.dumps(to_test, cls=NumpyToNativeEncoder)
output = json.loads(encoded)

assert (
isinstance(output['inside'], list) and
output['inside'] == to_test['inside'].tolist()
)
assert (
isinstance(output['float'], float) and
output['float'] == to_test['float'].item()
)
assert (
isinstance(output['nested']['float'], list) and
output['nested']['float'] == to_test['nested']['float'].tolist()
)


def _json_equal(a, b):
if isinstance(a, np.ndarray):
return np.array_equal(a, b) and a.dtype == b.dtype
Expand Down

0 comments on commit 988e5cc

Please sign in to comment.