diff --git a/hexrd/cli/fit_grains.py b/hexrd/cli/fit_grains.py index 77896712f..e4527a9a6 100644 --- a/hexrd/cli/fit_grains.py +++ b/hexrd/cli/fit_grains.py @@ -34,17 +34,28 @@ class GrainData(_BaseGrainData): """ def save(self, fname): - """Save grain data to an np file""" + """Save grain data to an np file + + Parameters + ---------- + fname: path | string + name of the file to save to + """ np.savez(fname, **self._asdict()) @classmethod def load(cls, fname): - """Return GrainData instance from npz file""" - return cls(np.load(fname)) + """Return GrainData instance from npz file + Parameters + ---------- + fname: path | string + name of the file to load + """ + return cls(**np.load(fname)) @classmethod def from_array(cls, a): - """Return GrainData instance from numpy array""" + """Return GrainData instance from numpy datatype array""" return cls( id=a[:,0].astype(int), completeness=a[:, 1], @@ -55,6 +66,21 @@ def from_array(cls, a): ln_Vs=a[:, 15:21], ) + @property + def rotation_matrices(self): + """"Return rotation matrices from exponential maps""" + # + # Compute the rotation matrices only once, the first time this is + # called, and save the results. + # + if not hasattr(self, "_rotation_matrices"): + n = len(self.expmap) + rmats = np.zeros((n, 3, 3)) + for i in range(n): + rmats[i] = xfcapi.makeRotMatOfExpMap(self.expmap[i]) + self._rotation_matrices = rmats + return self._rotation_matrices + def configure_parser(sub_parsers): p = sub_parsers.add_parser('fit-grains', description=descr, help=descr) diff --git a/tests/test_graindata.py b/tests/test_graindata.py new file mode 100644 index 000000000..3f6001e35 --- /dev/null +++ b/tests/test_graindata.py @@ -0,0 +1,47 @@ +"""Testing GrainData class""" +from pathlib import Path + +import pytest +import numpy as np + +from hexrd.cli.fit_grains import GrainData + + +save_file = "save.npz" + +exp90 = (np.pi/2) * np.identity(3) +rmats90 = np.array([ + [[1, 0, 0], [0, 0,- 1], [0, 1, 0]], + [[0, 0, 1], [0, 1, 0], [-1, 0, 0]], + [[0, -1, 0], [1, 0, 0], [0, 0, 1]] +]) + + +@pytest.fixture +def graindata_0(): + args = dict( + id=[0, 1, 2], completeness=[0, 0.5, 1.0], chisq=[2.1, 1.2, 0.1], + expmap=exp90, centroid=np.identity(3), inv_Vs=np.zeros(6), + ln_Vs=np.zeros(6), + ) + return GrainData(**args) + + +def test_load_save(tmp_path, graindata_0): + + gdata = graindata_0 + save_path = tmp_path / save_file + gdata.save(save_path) + gdata_cmp = GrainData.load(save_path) + + assert np.allclose(gdata.centroid, gdata_cmp.centroid) + assert np.allclose(gdata.completeness, gdata_cmp.completeness) + assert np.allclose(gdata.chisq, gdata_cmp.chisq) + assert np.allclose(gdata.inv_Vs, gdata_cmp.inv_Vs) + assert np.allclose(gdata.ln_Vs, gdata_cmp.ln_Vs) + + +def test_rotation_matrices(graindata_0): + + gdata = graindata_0 + assert np.allclose(gdata.rotation_matrices, rmats90)