Skip to content

Commit

Permalink
Merge pull request #589 from donald-e-boyce/graindata_update
Browse files Browse the repository at this point in the history
fixed bug in GrainData.load and added rotation_matrices attribute
  • Loading branch information
donald-e-boyce authored Dec 7, 2023
2 parents de3149a + 9054350 commit f1c3bc4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
34 changes: 30 additions & 4 deletions hexrd/cli/fit_grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,28 @@ class GrainData(_BaseGrainData):
"""

def save(self, fname):
"""Save grain data to an np file"""
"""Save grain data to an np file
Parameters
----------
fname: path | string
name of the file to save to
"""
np.savez(fname, **self._asdict())

@classmethod
def load(cls, fname):
"""Return GrainData instance from npz file"""
return cls(np.load(fname))
"""Return GrainData instance from npz file
Parameters
----------
fname: path | string
name of the file to load
"""
return cls(**np.load(fname))

@classmethod
def from_array(cls, a):
"""Return GrainData instance from numpy array"""
"""Return GrainData instance from numpy datatype array"""
return cls(
id=a[:,0].astype(int),
completeness=a[:, 1],
Expand All @@ -55,6 +66,21 @@ def from_array(cls, a):
ln_Vs=a[:, 15:21],
)

@property
def rotation_matrices(self):
""""Return rotation matrices from exponential maps"""
#
# Compute the rotation matrices only once, the first time this is
# called, and save the results.
#
if not hasattr(self, "_rotation_matrices"):
n = len(self.expmap)
rmats = np.zeros((n, 3, 3))
for i in range(n):
rmats[i] = xfcapi.makeRotMatOfExpMap(self.expmap[i])
self._rotation_matrices = rmats
return self._rotation_matrices


def configure_parser(sub_parsers):
p = sub_parsers.add_parser('fit-grains', description=descr, help=descr)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_graindata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Testing GrainData class"""
from pathlib import Path

import pytest
import numpy as np

from hexrd.cli.fit_grains import GrainData


save_file = "save.npz"

exp90 = (np.pi/2) * np.identity(3)
rmats90 = np.array([
[[1, 0, 0], [0, 0,- 1], [0, 1, 0]],
[[0, 0, 1], [0, 1, 0], [-1, 0, 0]],
[[0, -1, 0], [1, 0, 0], [0, 0, 1]]
])


@pytest.fixture
def graindata_0():
args = dict(
id=[0, 1, 2], completeness=[0, 0.5, 1.0], chisq=[2.1, 1.2, 0.1],
expmap=exp90, centroid=np.identity(3), inv_Vs=np.zeros(6),
ln_Vs=np.zeros(6),
)
return GrainData(**args)


def test_load_save(tmp_path, graindata_0):

gdata = graindata_0
save_path = tmp_path / save_file
gdata.save(save_path)
gdata_cmp = GrainData.load(save_path)

assert np.allclose(gdata.centroid, gdata_cmp.centroid)
assert np.allclose(gdata.completeness, gdata_cmp.completeness)
assert np.allclose(gdata.chisq, gdata_cmp.chisq)
assert np.allclose(gdata.inv_Vs, gdata_cmp.inv_Vs)
assert np.allclose(gdata.ln_Vs, gdata_cmp.ln_Vs)


def test_rotation_matrices(graindata_0):

gdata = graindata_0
assert np.allclose(gdata.rotation_matrices, rmats90)

0 comments on commit f1c3bc4

Please sign in to comment.