-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #588 from HEXRD/yaml-dumper
Add yaml dumper that converts numpy to native
- Loading branch information
Showing
8 changed files
with
84 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import yaml | ||
|
||
|
||
class NumpyToNativeDumper(yaml.SafeDumper): | ||
"""Change Numpy types to native types during YAML encoding | ||
This inherits from yaml.SafeDumper so that anything that is not | ||
converted to a basic type will raise an error. | ||
For instance, np.float128 will raise an error, since it cannot be | ||
converted to a basic type. | ||
""" | ||
def represent_data(self, data): | ||
if isinstance(data, np.ndarray): | ||
return self.represent_list(data.tolist()) | ||
elif isinstance(data, (np.generic, np.number)): | ||
item = data.item() | ||
if isinstance(item, (np.generic, np.number)): | ||
# This means it was not converted successfully. | ||
# It is probably np.float128. | ||
msg = ( | ||
f'Failed to convert {item} with type {type(item)} to ' | ||
'a native type' | ||
) | ||
raise yaml.representer.RepresenterError(msg) | ||
|
||
return self.represent_data(item) | ||
|
||
return super().represent_data(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
import yaml | ||
|
||
from hexrd.utils.yaml import NumpyToNativeDumper | ||
|
||
|
||
def test_numpy_to_native(): | ||
to_test = { | ||
'inside': np.arange(27, dtype=np.int8).reshape((3, 3, 3)), | ||
'nested': { | ||
'float16': np.arange(4, dtype=np.float16).reshape((2, 2)), | ||
}, | ||
'float32': np.float32(32.5), | ||
'float64': np.float64(8.3), | ||
'int64': np.int64(3), | ||
'str': 'string', | ||
} | ||
|
||
encoded = yaml.dump(to_test, Dumper=NumpyToNativeDumper) | ||
output = yaml.safe_load(encoded) | ||
|
||
assert ( | ||
isinstance(output['inside'], list) and | ||
output['inside'] == to_test['inside'].tolist() | ||
) | ||
assert ( | ||
isinstance(output['nested']['float16'], list) and | ||
output['nested']['float16'] == to_test['nested']['float16'].tolist() | ||
) | ||
assert ( | ||
isinstance(output['float32'], float) and | ||
output['float32'] == to_test['float32'].item() | ||
) | ||
assert ( | ||
isinstance(output['float64'], float) and | ||
output['float64'] == to_test['float64'].item() | ||
) | ||
assert ( | ||
isinstance(output['int64'], int) and | ||
output['int64'] == to_test['int64'].item() | ||
) | ||
assert ( | ||
isinstance(output['str'], str) and | ||
output['str'] == to_test['str'] | ||
) |