Skip to content

Commit

Permalink
Merge pull request #588 from HEXRD/yaml-dumper
Browse files Browse the repository at this point in the history
Add yaml dumper that converts numpy to native
  • Loading branch information
psavery authored Nov 29, 2023
2 parents 1dc509b + 674aba7 commit ff96462
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 8 deletions.
4 changes: 2 additions & 2 deletions hexrd/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ def save(config_list, file_name):

with open_file(file_name, 'w') as f:
if len(res) > 1:
yaml.dump_all(res, f)
yaml.safe_dump_all(res, f)
else:
yaml.dump(res, f)
yaml.safe_dump(res, f)
2 changes: 1 addition & 1 deletion hexrd/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def dump(self, filename):
import yaml

with open(filename, 'w') as f:
yaml.dump(self._cfg, f)
yaml.safe_dump(self._cfg, f)
self._dirty = False

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion hexrd/imageseries/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _write_yml(self):
'nframes': len(self._ims), 'shape': list(self._ims.shape)}
info = {'data': datad, 'meta': self._process_meta(save_omegas=True)}
with open(self._fname, "w") as f:
yaml.dump(info, f)
yaml.safe_dump(info, f)

def _write_frames(self):
"""also save shape array as originally done (before yaml)"""
Expand Down
3 changes: 2 additions & 1 deletion hexrd/instrument/hedm_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from hexrd import distortion as distortion_pkg
from hexrd.utils.compatibility import h5py_read_string
from hexrd.utils.concurrent import distribute_tasks
from hexrd.utils.yaml import NumpyToNativeDumper
from hexrd.valunits import valWUnit
from hexrd.wppf import LeBail

Expand Down Expand Up @@ -1008,7 +1009,7 @@ def write_config(self, file=None, style='yaml', calibration_dict={}):
if file is not None:
if style.lower() == 'yaml':
with open(file, 'w') as f:
yaml.dump(par_dict, stream=f)
yaml.dump(par_dict, stream=f, Dumper=NumpyToNativeDumper)
else:
def _write_group(file):
instr_grp = file.create_group('instrument')
Expand Down
30 changes: 30 additions & 0 deletions hexrd/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import yaml


class NumpyToNativeDumper(yaml.SafeDumper):
    """Change Numpy types to native types during YAML encoding.

    This inherits from yaml.SafeDumper so that anything that is not
    converted to a basic type will raise an error.

    For instance, np.float128 will raise an error, since it cannot be
    converted to a basic type.
    """

    def represent_data(self, data):
        """Build the YAML node for `data`, converting numpy values first.

        Numpy arrays are converted with `tolist()` and numpy scalars with
        `item()`; the converted value is then passed back through this
        method. Everything else is handled by `yaml.SafeDumper`.

        Raises
        ------
        yaml.representer.RepresenterError
            If a numpy scalar is still a numpy scalar after `item()`, or
            if the parent dumper cannot represent the (converted) value.
        """
        if isinstance(data, np.ndarray):
            # Go back through represent_data() rather than calling
            # represent_list() directly:
            #   * tolist() on a 0-d array returns a scalar, not a list,
            #     and represent_list() would fail trying to iterate it;
            #   * represent_data() resets the dumper's alias bookkeeping
            #     for the new list, so the node of the enclosing
            #     container is not overwritten by this array's node.
            return self.represent_data(data.tolist())

        # np.number is a subclass of np.generic, so np.generic covers it.
        if isinstance(data, np.generic):
            item = data.item()
            if isinstance(item, np.generic):
                # This means it was not converted successfully.
                # It is probably np.float128.
                msg = (
                    f'Failed to convert {item} with type {type(item)} to '
                    'a native type'
                )
                raise yaml.representer.RepresenterError(msg)

            return self.represent_data(item)

        return super().represent_data(data)
2 changes: 1 addition & 1 deletion hexrd/wppf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def dump(self, fname):
dic[k] = [self[k].value, self[k].lb, self[k].ub, self[k].vary]

with open(fname, 'w') as f:
data = yaml.dump(dic, f, sort_keys=False)
data = yaml.safe_dump(dic, f, sort_keys=False)

def dump_hdf5(self, file):
"""
Expand Down
4 changes: 2 additions & 2 deletions hexrd/wppf/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def dump(self, fname):
dic[k] = [m for m in self]

with open(fname, 'w') as f:
data = yaml.dump(dic, f, sort_keys=False)
data = yaml.safe_dump(dic, f, sort_keys=False)

def dump_hdf5(self, file):
"""
Expand Down Expand Up @@ -1430,7 +1430,7 @@ def dump(self, fname):
dic[k] = [m for m in self]

with open(fname, 'w') as f:
data = yaml.dump(dic, f, sort_keys=False)
data = yaml.safe_dump(dic, f, sort_keys=False)

@property
def phase_fraction(self):
Expand Down
45 changes: 45 additions & 0 deletions tests/test_utils_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
import yaml

from hexrd.utils.yaml import NumpyToNativeDumper


def test_numpy_to_native():
    """Round-trip numpy data through NumpyToNativeDumper.

    Arrays must come back as (nested) lists and numpy scalars as the
    matching native Python type, with unchanged values.
    """
    int8_cube = np.arange(27, dtype=np.int8).reshape((3, 3, 3))
    half_grid = np.arange(4, dtype=np.float16).reshape((2, 2))
    to_test = {
        'inside': int8_cube,
        'nested': {
            'float16': half_grid,
        },
        'float32': np.float32(32.5),
        'float64': np.float64(8.3),
        'int64': np.int64(3),
        'str': 'string',
    }

    output = yaml.safe_load(yaml.dump(to_test, Dumper=NumpyToNativeDumper))

    # Arrays, top-level and nested, become plain lists.
    inside = output['inside']
    assert isinstance(inside, list)
    assert inside == int8_cube.tolist()

    nested = output['nested']['float16']
    assert isinstance(nested, list)
    assert nested == half_grid.tolist()

    # Numpy scalars become the corresponding native scalar type.
    scalar_cases = (
        ('float32', float),
        ('float64', float),
        ('int64', int),
    )
    for key, native_type in scalar_cases:
        decoded = output[key]
        assert isinstance(decoded, native_type)
        assert decoded == to_test[key].item()

    # Native values pass through untouched.
    text = output['str']
    assert isinstance(text, str)
    assert text == to_test['str']

0 comments on commit ff96462

Please sign in to comment.