diff --git a/hexrd/fitting/calibration/__init__.py b/hexrd/fitting/calibration/__init__.py index 77aa739b..c14026f3 100644 --- a/hexrd/fitting/calibration/__init__.py +++ b/hexrd/fitting/calibration/__init__.py @@ -1,6 +1,5 @@ from .instrument import InstrumentCalibrator from .laue import LaueCalibrator -from .lmfit_param_handling import RelativeConstraints from .multigrain import calibrate_instrument_from_sx, generate_parameter_names from .powder import PowderCalibrator from .structureless import StructurelessCalibrator @@ -14,7 +13,6 @@ 'InstrumentCalibrator', 'LaueCalibrator', 'PowderCalibrator', - 'RelativeConstraints', 'StructurelessCalibrator', 'StructureLessCalibrator', ] diff --git a/hexrd/fitting/calibration/instrument.py b/hexrd/fitting/calibration/instrument.py index 8efec3ac..06c52178 100644 --- a/hexrd/fitting/calibration/instrument.py +++ b/hexrd/fitting/calibration/instrument.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import lmfit import numpy as np @@ -9,7 +10,11 @@ DEFAULT_EULER_CONVENTION, update_instrument_from_params, validate_params_list, +) +from .relative_constraints import ( + create_relative_constraints, RelativeConstraints, + RelativeConstraintsType, ) logger = logging.getLogger() @@ -24,7 +29,7 @@ class InstrumentCalibrator: def __init__(self, *args, engineering_constraints=None, set_refinements_from_instrument_flags=True, euler_convention=DEFAULT_EULER_CONVENTION, - relative_constraints=RelativeConstraints.none): + relative_constraints_type=RelativeConstraintsType.none): """ Model for instrument calibration class as a function of @@ -47,7 +52,8 @@ def __init__(self, *args, engineering_constraints=None, assert calib.instr is self.instr, \ "all calibrators must refer to the same instrument" self._engineering_constraints = engineering_constraints - self._relative_constraints = relative_constraints + self._relative_constraints = create_relative_constraints( + relative_constraints_type, self.instr) self.euler_convention = euler_convention self.params = self.make_lmfit_params() @@ -164,18 +170,32 @@ def engineering_constraints(self, v): self._engineering_constraints = v self.params = self.make_lmfit_params() + @property + def relative_constraints_type(self): + return self._relative_constraints.type + + @relative_constraints_type.setter + def relative_constraints_type(self, v: Optional[RelativeConstraintsType]): + v = v if v is not None else RelativeConstraintsType.none + + current = getattr(self, '_relative_constraints', None) + if current is None or current.type != v: + self.relative_constraints = create_relative_constraints( + v, self.instr) + @property def relative_constraints(self) -> RelativeConstraints: return self._relative_constraints @relative_constraints.setter def relative_constraints(self, v: RelativeConstraints): - if v == self._relative_constraints: - return - self._relative_constraints = v self.params = self.make_lmfit_params() + def reset_relative_constraint_params(self): + # Set them back to zero. + self.relative_constraints.reset() + def run_calibration(self, odict): resd0 = self.residual() nrm_ssr_0 = _normalized_ssqr(resd0) diff --git a/hexrd/fitting/calibration/lmfit_param_handling.py b/hexrd/fitting/calibration/lmfit_param_handling.py index 5ea113c9..1a667f52 100644 --- a/hexrd/fitting/calibration/lmfit_param_handling.py +++ b/hexrd/fitting/calibration/lmfit_param_handling.py @@ -1,4 +1,4 @@ -from enum import Enum +from typing import Optional import lmfit import numpy as np @@ -17,24 +17,18 @@ rotMatOfExpMap, ) from hexrd.material.unitcell import _lpname +from .relative_constraints import ( + RelativeConstraints, + RelativeConstraintsType, +) # First is the axes_order, second is extrinsic DEFAULT_EULER_CONVENTION = ('zxz', False) -class RelativeConstraints(Enum): - """These are relative constraints between the detectors""" - # 'none' means no relative constraints - none = 'None' - # 'group' means constrain tilts/translations within a group - group = 'Group' - # 'system' means constrain tilts/translations within the whole system - system = 'System' - - def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION, - relative_constraints=RelativeConstraints.none): + relative_constraints=None): # add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP) parms_list = [] @@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION, parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf)) parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf)) - if relative_constraints == RelativeConstraints.none: + if ( + relative_constraints is None or + relative_constraints.type == RelativeConstraintsType.none + ): add_unconstrained_detector_parameters( instr, euler_convention, parms_list, ) - elif relative_constraints == RelativeConstraints.group: + elif relative_constraints.type == RelativeConstraintsType.group: # This should be implemented soon - raise NotImplementedError(relative_constraints) - elif relative_constraints == RelativeConstraints.system: + raise NotImplementedError(relative_constraints.type) + elif relative_constraints.type == RelativeConstraintsType.system: add_system_constrained_detector_parameters( instr, euler_convention, parms_list, + relative_constraints, ) else: - raise NotImplementedError(relative_constraints) + raise NotImplementedError(relative_constraints.type) return parms_list @@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list): -np.inf, np.inf)) -def add_system_constrained_detector_parameters(instr, euler_convention, - parms_list): - mean_center = instr.mean_detector_center - mean_tilt = instr.mean_detector_tilt +def add_system_constrained_detector_parameters( + instr, euler_convention, + parms_list, relative_constraints: RelativeConstraints): + system_params = relative_constraints.params + system_tvec = system_params['translation'] + system_tilt = system_params['tilt'] + + if euler_convention is not None: + # Convert the tilt to the specified Euler convention + normalized = normalize_euler_convention(euler_convention) + rme = RotMatEuler( + np.zeros(3,), + axes_order=normalized[0], + extrinsic=normalized[1], + ) + + rme.rmat = _tilt_to_rmat(system_tilt, None) + system_tilt = np.degrees(rme.angles) tvec_names = [ 'system_tvec_x', @@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention, tilt_deltas = [2, 2, 2] for i, name in enumerate(tvec_names): - value = mean_center[i] + value = system_tvec[i] delta = tvec_deltas[i] parms_list.append((name, value, True, value - delta, value + delta)) for i, name in enumerate(tilt_names): - value = mean_tilt[i] + value = system_tilt[i] delta = tilt_deltas[i] parms_list.append((name, value, True, value - delta, value + delta)) @@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]: return param_names -def update_instrument_from_params(instr, params, euler_convention, - relative_constraints): +def update_instrument_from_params( + instr, params, + euler_convention=DEFAULT_EULER_CONVENTION, + relative_constraints: Optional[RelativeConstraints] = None): """ this function updates the instrument from the lmfit parameter list. we don't have to keep track @@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention, params['instr_tvec_z'].value] instr.tvec = np.r_[instr_tvec] - if relative_constraints == RelativeConstraints.none: + if ( + relative_constraints is None or + relative_constraints.type == RelativeConstraintsType.none + ): update_unconstrained_detector_parameters( instr, params, euler_convention, ) - elif relative_constraints == RelativeConstraints.group: + elif relative_constraints.type == RelativeConstraintsType.group: # This should be implemented soon - raise NotImplementedError(relative_constraints) - elif relative_constraints == RelativeConstraints.system: + raise NotImplementedError(relative_constraints.type) + elif relative_constraints.type == RelativeConstraintsType.system: update_system_constrained_detector_parameters( instr, params, euler_convention, + relative_constraints, ) else: - raise NotImplementedError(relative_constraints) + raise NotImplementedError(relative_constraints.type) def update_unconstrained_detector_parameters(instr, params, euler_convention): @@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention): ) -def update_system_constrained_detector_parameters(instr, params, euler_convention): - # We will always rotate/translate about the center of the group +def update_system_constrained_detector_parameters( + instr, params, euler_convention, + relative_constraints: RelativeConstraints): + # We will always rotate about the center of the detectors mean_center = instr.mean_detector_center - mean_tilt = instr.mean_detector_tilt + + system_params = relative_constraints.params + system_tvec = system_params['translation'] + system_tilt = system_params['tilt'] tvec_names = [ 'system_tvec_x', @@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio if any(params[x].vary for x in tilt_names): # Find the change in tilt, create an rmat, then apply to detector tilts # and translations. - new_mean_tilt = np.array([params[x].value for x in tilt_names]) + new_system_tilt = np.array([params[x].value for x in tilt_names]) - # The old mean tilt was in the None convention - old_rmat = _tilt_to_rmat(mean_tilt, None) - new_rmat = _tilt_to_rmat(new_mean_tilt, euler_convention) + # The old system tilt was in the None convention + old_rmat = _tilt_to_rmat(system_tilt, None) + new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention) # Compute the rmat used to convert from old to new rmat_diff = new_rmat @ old_rmat.T @@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio for panel in instr.detectors.values(): panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat) - # Also rotate the detectors about the center + # Also rotate the detectors about the mean center panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center + # Update the system tilt + system_tilt[:] = _rmat_to_tilt(new_rmat) + if any(params[x].vary for x in tvec_names): # Find the change in center and shift all tvecs - new_mean_center = np.array([params[x].value for x in tvec_names]) + new_system_tvec = np.array([params[x].value for x in tvec_names]) - diff = new_mean_center - mean_center + diff = new_system_tvec - system_tvec for panel in instr.detectors.values(): panel.tvec += diff + # Update the system tvec + system_tvec[:] = new_system_tvec + -def _tilt_to_rmat(tilt: np.ndarray, euler_convention: dict | tuple) -> np.ndarray: +def _tilt_to_rmat(tilt: np.ndarray, + euler_convention: dict | tuple) -> np.ndarray: # Convert the tilt to exponential map parameters, and then # to the rotation matrix, and return. if euler_convention is None: diff --git a/hexrd/fitting/calibration/relative_constraints.py b/hexrd/fitting/calibration/relative_constraints.py new file mode 100644 index 00000000..524d14d0 --- /dev/null +++ b/hexrd/fitting/calibration/relative_constraints.py @@ -0,0 +1,101 @@ +from abc import ABC, abstractmethod +from enum import Enum + +import numpy as np + +from hexrd.instrument import HEDMInstrument + + +class RelativeConstraintsType(Enum): + """These are relative constraints between the detectors""" + # 'none' means no relative constraints + none = 'None' + # 'group' means constrain tilts/translations within a group + group = 'Group' + # 'system' means constrain tilts/translations within the whole system + system = 'System' + + +class RelativeConstraints(ABC): + @property + @abstractmethod + def type(self) -> RelativeConstraintsType: + pass + + @property + @abstractmethod + def params(self) -> dict: + pass + + @abstractmethod + def reset(self): + # Reset the parameters + pass + + +class RelativeConstraintsNone(RelativeConstraints): + type = RelativeConstraintsType.none + + @property + def params(self) -> dict: + return {} + + def reset(self): + pass + + +class RelativeConstraintsGroup(RelativeConstraints): + type = RelativeConstraintsType.group + + def __init__(self, instr: HEDMInstrument): + self._groups = [] + for panel in instr.detectors.values(): + if panel.group is not None and panel.group not in self._groups: + self._groups.append(panel.group) + + self.reset() + + def reset(self): + self.group_params = {} + + for group in self._groups: + self.group_params[group] = { + 'tilt': np.array([0, 0, 0], dtype=float), + 'translation': np.array([0, 0, 0], dtype=float), + } + + @property + def params(self) -> dict: + return self.group_params + + +class RelativeConstraintsSystem(RelativeConstraints): + type = RelativeConstraintsType.system + + def __init__(self): + self.reset() + + @property + def params(self) -> dict: + return self._params + + def reset(self): + self._params = { + 'tilt': np.array([0, 0, 0], dtype=float), + 'translation': np.array([0, 0, 0], dtype=float), + } + + +def create_relative_constraints(type: RelativeConstraintsType, + instr: HEDMInstrument): + types = { + 'None': RelativeConstraintsNone, + 'Group': RelativeConstraintsGroup, + 'System': RelativeConstraintsSystem, + } + + kwargs = {} + if type == 'System': + kwargs['instr'] = instr + + return types[type.value](**kwargs) diff --git a/hexrd/fitting/calibration/structureless.py b/hexrd/fitting/calibration/structureless.py index c82837b3..4bc74359 100644 --- a/hexrd/fitting/calibration/structureless.py +++ b/hexrd/fitting/calibration/structureless.py @@ -1,4 +1,6 @@ import copy +from typing import Optional + import lmfit import numpy as np @@ -9,10 +11,14 @@ create_instr_params, create_tth_parameters, DEFAULT_EULER_CONVENTION, - RelativeConstraints, tth_parameter_prefixes, update_instrument_from_params, ) +from .relative_constraints import ( + create_relative_constraints, + RelativeConstraints, + RelativeConstraintsType, +) class StructurelessCalibrator: @@ -39,14 +45,15 @@ def __init__(self, data, tth_distortion=None, engineering_constraints=None, - relative_constraints=RelativeConstraints.none, + relative_constraints_type=RelativeConstraintsType.none, euler_convention=DEFAULT_EULER_CONVENTION): self._instr = instr self._data = data self._tth_distortion = tth_distortion self._engineering_constraints = engineering_constraints - self._relative_constraints = relative_constraints + self._relative_constraints = create_relative_constraints( + relative_constraints_type, self.instr) self.euler_convention = euler_convention self._update_tth_distortion_panels() self.make_lmfit_params() @@ -163,16 +170,26 @@ def _update_tth_distortion_panels(self): obj.panel = self.instr.detectors[det_key] @property - def relative_constraints(self): + def relative_constraints_type(self): + return self._relative_constraints.type + + @relative_constraints_type.setter + def relative_constraints_type(self, v: Optional[RelativeConstraintsType]): + v = v if v is not None else RelativeConstraintsType.none + + current = getattr(self, '_relative_constraints', None) + if current is None or current.type != v: + self.relative_constraints = create_relative_constraints( + v, self.instr) + + @property + def relative_constraints(self) -> RelativeConstraints: return self._relative_constraints @relative_constraints.setter - def relative_constraints(self, v): - if v == self._relative_constraints: - return - + def relative_constraints(self, v: RelativeConstraints): self._relative_constraints = v - self.make_lmfit_params() + self.params = self.make_lmfit_params() @property def engineering_constraints(self): diff --git a/hexrd/instrument/hedm_instrument.py b/hexrd/instrument/hedm_instrument.py index 53cc64bc..5d38e1a8 100644 --- a/hexrd/instrument/hedm_instrument.py +++ b/hexrd/instrument/hedm_instrument.py @@ -813,28 +813,6 @@ def mean_group_centers(self) -> dict[str, np.ndarray]: return {k: v.sum(axis=0) / len(v) for k, v in centers.items()} - @property - def mean_detector_tilt(self) -> np.ndarray: - """Return the mean tilt for all detectors""" - tilts = np.array([panel.tilt for panel in self.detectors.values()]) - return tilts.sum(axis=0) / len(tilts) - - @property - def mean_group_tilts(self) -> dict[str, np.ndarray]: - """Return the mean tilt for every group of detectors""" - tilts = {} - for panel in self.detectors.values(): - if panel.group is None: - # Skip over panels without groups - continue - - if panel.group not in tilts: - tilts[panel.group] = [] - - tilts[panel.group].append(panel.tilt) - - return {k: v.sum(axis=0) / len(v) for k, v in tilts.items()} - # properties for physical size of rectangular detector @property def id(self): diff --git a/tests/test_2xrs_calibration.py b/tests/calibration/test_2xrs_calibration.py similarity index 100% rename from tests/test_2xrs_calibration.py rename to tests/calibration/test_2xrs_calibration.py diff --git a/tests/test_calibration.py b/tests/calibration/test_calibration.py similarity index 100% rename from tests/test_calibration.py rename to tests/calibration/test_calibration.py diff --git a/tests/calibration/test_relative_constraints.py b/tests/calibration/test_relative_constraints.py index 2cf6b71d..6b331e82 100644 --- a/tests/calibration/test_relative_constraints.py +++ b/tests/calibration/test_relative_constraints.py @@ -9,7 +9,9 @@ from hexrd.fitting.calibration import ( InstrumentCalibrator, PowderCalibrator, - RelativeConstraints, +) +from hexrd.fitting.calibration.relative_constraints import ( + RelativeConstraintsType, ) from hexrd.imageseries.process import ProcessedImageSeries from hexrd.instrument import HEDMInstrument @@ -200,10 +202,9 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center - orig_tilt = instr.mean_detector_tilt orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} orig_rmats = {k: v.rmat for k, v in instr.detectors.items()} @@ -228,8 +229,6 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: # The new center should not match assert not np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should match - assert np.allclose(orig_tilt, instr.mean_detector_tilt) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) @@ -253,7 +252,7 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} @@ -267,22 +266,24 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: for tilt_name in system_tilt_names: calibrator.params[tilt_name].vary = True + orig_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + # Run the calibration calibrator.run_calibration(calibration_options) # The new center should match assert np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should not match - assert not np.allclose(orig_tilt, instr.mean_detector_tilt) - - tilt_rmat_diff = ( - rotMatOfExpMap(instr.mean_detector_tilt) @ rotMatOfExpMap(orig_tilt).T - ) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) new_relative_rmats = compute_relative_rmats(instr) + new_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + + tilt_rmat_diff = ( + rotMatOfExpMap(new_system_tilt) @ rotMatOfExpMap(orig_system_tilt).T + ) + # absolute and relative tvecs should not match # absolute rmat should not match, but relative rmat should for key, panel in instr.detectors.items(): @@ -308,10 +309,9 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center - orig_tilt = instr.mean_detector_tilt orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} orig_rmats = {k: v.rmat for k, v in instr.detectors.items()} @@ -322,22 +322,24 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: for tilt_name in system_tilt_names: calibrator.params[tilt_name].vary = True + orig_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + # Run the calibration calibrator.run_calibration(calibration_options) # The new center should be different assert not np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should not match - assert not np.allclose(orig_tilt, instr.mean_detector_tilt) - - tilt_rmat_diff = ( - rotMatOfExpMap(instr.mean_detector_tilt) @ rotMatOfExpMap(orig_tilt).T - ) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) new_relative_rmats = compute_relative_rmats(instr) + new_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + + tilt_rmat_diff = ( + rotMatOfExpMap(new_system_tilt) @ rotMatOfExpMap(orig_system_tilt).T + ) + # absolute and relative tvecs should not match # absolute rmat should not match, but relative rmat should for key, panel in instr.detectors.items():