diff --git a/hexrd/utils/json.py b/hexrd/utils/json.py index 88941a472..46a3e9ff9 100644 --- a/hexrd/utils/json.py +++ b/hexrd/utils/json.py @@ -7,6 +7,17 @@ ndarray_key = '!__hexrd_ndarray__' +class NumpyToNativeEncoder(json.JSONEncoder): + # Change all Numpy arrays to native types during JSON encoding + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, (np.generic, np.number)): + return obj.item() + + return super().default(obj) + + class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): @@ -20,7 +31,6 @@ def default(self, obj): ndarray_key: base64.b64encode(data).decode('ascii') } - # Let the base class default method raise the TypeError return super().default(obj) diff --git a/tests/test_utils_json.py b/tests/test_utils_json.py index 1bd557565..3824c9640 100644 --- a/tests/test_utils_json.py +++ b/tests/test_utils_json.py @@ -2,7 +2,7 @@ import numpy as np -from hexrd.utils.json import NumpyDecoder, NumpyEncoder +from hexrd.utils.json import NumpyDecoder, NumpyEncoder, NumpyToNativeEncoder def test_decode_encode(): @@ -29,6 +29,36 @@ def test_decode_encode(): assert _json_equal(to_test, converted_back) +def test_numpy_to_native(): + to_test = { + 'inside': np.arange(27, dtype=np.int8).reshape((3, 3, 3)), + 'nested': { + 'float': np.arange(4, dtype=np.float32).reshape((2, 2)), + }, + 'float': np.float64(8.3), + } + + inside_to_list = to_test['inside'].tolist() + nested_float_to_value = to_test['nested']['float'].tolist() + float_to_value = to_test['float'].item() + + encoded = json.dumps(to_test, cls=NumpyToNativeEncoder) + output = json.loads(encoded) + + assert ( + isinstance(output['inside'], list) and + output['inside'] == to_test['inside'].tolist() + ) + assert ( + isinstance(output['float'], float) and + output['float'] == to_test['float'].item() + ) + assert ( + isinstance(output['nested']['float'], list) and + output['nested']['float'] == to_test['nested']['float'].tolist() + ) + + def _json_equal(a, b): if isinstance(a, np.ndarray): return np.array_equal(a, b) and a.dtype == b.dtype