Skip to content

Commit

Permalink
Add json numpy encoder/decoder
Browse files Browse the repository at this point in the history
This also adds a quick test to verify that it works properly.

Signed-off-by: Patrick Avery <[email protected]>
  • Loading branch information
psavery committed Nov 23, 2023
1 parent a7d696e commit 1ade6aa
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
40 changes: 40 additions & 0 deletions hexrd/utils/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import base64
import io
import json

import numpy as np

ndarray_key = '!__hexrd_ndarray__'

class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
# Write it as an npy file
with io.BytesIO() as bytes_io:
np.save(bytes_io, obj, allow_pickle=False)
data = bytes_io.getvalue()

return {
# Need to base64 encode it so it is json-valid
ndarray_key: base64.b64encode(data).decode('ascii')
}

# Let the base class default method raise the TypeError
return super().default(obj)


class NumpyDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
kwargs = {
'object_hook': self.object_hook,
**kwargs,
}
super().__init__(*args, **kwargs)

def object_hook(self, obj):
if ndarray_key in obj:
data = base64.b64decode(obj[ndarray_key])
with io.BytesIO(data) as bytes_io:
return np.load(bytes_io)

return obj
50 changes: 50 additions & 0 deletions tests/test_utils_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import json

import numpy as np

from hexrd.utils.json import NumpyDecoder, NumpyEncoder


def test_decode_encode():
to_test = [
{
'floating': np.arange(50, dtype=np.float16),
'complex': np.arange(20, dtype=np.complex128),
},
{
'nested': {
'int8': np.arange(27, dtype=np.int8).reshape((3, 3, 3)),
'uint8': np.arange(8, dtype=np.uint8).reshape((2, 4)),
'not_numpy': 3,
}
},
np.array([0, 5, 4]),
5,
'string',
]

output = json.dumps(to_test, cls=NumpyEncoder)
converted_back = json.loads(output, cls=NumpyDecoder)

assert _json_equal(to_test, converted_back)


def _json_equal(a, b):
if isinstance(a, np.ndarray):
return np.array_equal(a, b) and a.dtype == b.dtype

if isinstance(a, dict):
if list(a) != list(b):
return False

for k in a:
return _json_equal(a[k], b[k])

if isinstance(a, list):
if len(a) != len(b):
return False

for i in range(len(a)):
return _json_equal(a[i], b[i])

return a == b

0 comments on commit 1ade6aa

Please sign in to comment.