diff --git a/hexrd/config/__init__.py b/hexrd/config/__init__.py index 940c8cf37..075f33989 100644 --- a/hexrd/config/__init__.py +++ b/hexrd/config/__init__.py @@ -45,6 +45,6 @@ def save(config_list, file_name): with open_file(file_name, 'w') as f: if len(res) > 1: - yaml.dump_all(res, f) + yaml.safe_dump_all(res, f) else: - yaml.dump(res, f) + yaml.safe_dump(res, f) diff --git a/hexrd/config/config.py b/hexrd/config/config.py index 74c620629..5e7aeaf47 100644 --- a/hexrd/config/config.py +++ b/hexrd/config/config.py @@ -56,7 +56,7 @@ def dump(self, filename): import yaml with open(filename, 'w') as f: - yaml.dump(self._cfg, f) + yaml.safe_dump(self._cfg, f) self._dirty = False @staticmethod diff --git a/hexrd/imageseries/save.py b/hexrd/imageseries/save.py index 79ec909dc..106c68b4b 100644 --- a/hexrd/imageseries/save.py +++ b/hexrd/imageseries/save.py @@ -212,7 +212,7 @@ def _write_yml(self): 'nframes': len(self._ims), 'shape': list(self._ims.shape)} info = {'data': datad, 'meta': self._process_meta(save_omegas=True)} with open(self._fname, "w") as f: - yaml.dump(info, f) + yaml.safe_dump(info, f) def _write_frames(self): """also save shape array as originally done (before yaml)""" diff --git a/hexrd/instrument/hedm_instrument.py b/hexrd/instrument/hedm_instrument.py index 58bd955d1..889591549 100644 --- a/hexrd/instrument/hedm_instrument.py +++ b/hexrd/instrument/hedm_instrument.py @@ -81,6 +81,7 @@ from hexrd import distortion as distortion_pkg from hexrd.utils.compatibility import h5py_read_string from hexrd.utils.concurrent import distribute_tasks +from hexrd.utils.yaml import NumpyToNativeDumper from hexrd.valunits import valWUnit from hexrd.wppf import LeBail @@ -1008,7 +1009,7 @@ def write_config(self, file=None, style='yaml', calibration_dict={}): if file is not None: if style.lower() == 'yaml': with open(file, 'w') as f: - yaml.dump(par_dict, stream=f) + yaml.dump(par_dict, stream=f, Dumper=NumpyToNativeDumper) else: def _write_group(file): instr_grp = file.create_group('instrument') diff --git a/hexrd/utils/yaml.py b/hexrd/utils/yaml.py new file mode 100644 index 000000000..88555ec41 --- /dev/null +++ b/hexrd/utils/yaml.py @@ -0,0 +1,30 @@ +import numpy as np +import yaml + + +class NumpyToNativeDumper(yaml.SafeDumper): + """Change Numpy types to native types during YAML encoding + + This inherits from yaml.SafeDumper so that anything that is not + converted to a basic type will raise an error. + + For instance, np.float128 will raise an error, since it cannot be + converted to a basic type. + """ + def represent_data(self, data): + if isinstance(data, np.ndarray): + return self.represent_list(data.tolist()) + elif isinstance(data, (np.generic, np.number)): + item = data.item() + if isinstance(item, (np.generic, np.number)): + # This means it was not converted successfully. + # It is probably np.float128. + msg = ( + f'Failed to convert {item} with type {type(item)} to ' + 'a native type' + ) + raise yaml.representer.RepresenterError(msg) + + return self.represent_data(item) + + return super().represent_data(data) diff --git a/hexrd/wppf/parameters.py b/hexrd/wppf/parameters.py index 041591c8c..e8b76c996 100644 --- a/hexrd/wppf/parameters.py +++ b/hexrd/wppf/parameters.py @@ -93,7 +93,7 @@ def dump(self, fname): dic[k] = [self[k].value, self[k].lb, self[k].ub, self[k].vary] with open(fname, 'w') as f: - data = yaml.dump(dic, f, sort_keys=False) + data = yaml.safe_dump(dic, f, sort_keys=False) def dump_hdf5(self, file): """ diff --git a/hexrd/wppf/phase.py b/hexrd/wppf/phase.py index 785480d92..8ef597208 100644 --- a/hexrd/wppf/phase.py +++ b/hexrd/wppf/phase.py @@ -542,7 +542,7 @@ def dump(self, fname): dic[k] = [m for m in self] with open(fname, 'w') as f: - data = yaml.dump(dic, f, sort_keys=False) + data = yaml.safe_dump(dic, f, sort_keys=False) def dump_hdf5(self, file): """ @@ -1430,7 +1430,7 @@ def dump(self, fname): dic[k] = [m for m in self] with open(fname, 'w') as f: - data = yaml.dump(dic, f, sort_keys=False) + data = yaml.safe_dump(dic, f, sort_keys=False) @property def phase_fraction(self): diff --git a/tests/test_utils_yaml.py b/tests/test_utils_yaml.py new file mode 100644 index 000000000..cc8ff1c1f --- /dev/null +++ b/tests/test_utils_yaml.py @@ -0,0 +1,45 @@ +import numpy as np +import yaml + +from hexrd.utils.yaml import NumpyToNativeDumper + + +def test_numpy_to_native(): + to_test = { + 'inside': np.arange(27, dtype=np.int8).reshape((3, 3, 3)), + 'nested': { + 'float16': np.arange(4, dtype=np.float16).reshape((2, 2)), + }, + 'float32': np.float32(32.5), + 'float64': np.float64(8.3), + 'int64': np.int64(3), + 'str': 'string', + } + + encoded = yaml.dump(to_test, Dumper=NumpyToNativeDumper) + output = yaml.safe_load(encoded) + + assert ( + isinstance(output['inside'], list) and + output['inside'] == to_test['inside'].tolist() + ) + assert ( + isinstance(output['nested']['float16'], list) and + output['nested']['float16'] == to_test['nested']['float16'].tolist() + ) + assert ( + isinstance(output['float32'], float) and + output['float32'] == to_test['float32'].item() + ) + assert ( + isinstance(output['float64'], float) and + output['float64'] == to_test['float64'].item() + ) + assert ( + isinstance(output['int64'], int) and + output['int64'] == to_test['int64'].item() + ) + assert ( + isinstance(output['str'], str) and + output['str'] == to_test['str'] + )