From ad7f7413e0ec64dcb5b7381682316e1ec37f098c Mon Sep 17 00:00:00 2001 From: Patrick Avery Date: Tue, 15 Oct 2024 16:17:37 -0500 Subject: [PATCH] Clean up relative methods into classes The classes keep track of the current values of the relative parameters (before they were modified by lmfit). This is necessary for modifying all detectors by the diff of the change. Signed-off-by: Patrick Avery --- hexrd/fitting/calibration/__init__.py | 2 - hexrd/fitting/calibration/instrument.py | 30 ++++- .../calibration/lmfit_param_handling.py | 112 +++++++++++------- .../calibration/relative_constraints.py | 101 ++++++++++++++++ hexrd/fitting/calibration/structureless.py | 35 ++++-- hexrd/instrument/hedm_instrument.py | 22 ---- .../test_2xrs_calibration.py | 0 tests/{ => calibration}/test_calibration.py | 0 .../calibration/test_relative_constraints.py | 42 +++---- 9 files changed, 245 insertions(+), 99 deletions(-) create mode 100644 hexrd/fitting/calibration/relative_constraints.py rename tests/{ => calibration}/test_2xrs_calibration.py (100%) rename tests/{ => calibration}/test_calibration.py (100%) diff --git a/hexrd/fitting/calibration/__init__.py b/hexrd/fitting/calibration/__init__.py index 77aa739b2..c14026f3a 100644 --- a/hexrd/fitting/calibration/__init__.py +++ b/hexrd/fitting/calibration/__init__.py @@ -1,6 +1,5 @@ from .instrument import InstrumentCalibrator from .laue import LaueCalibrator -from .lmfit_param_handling import RelativeConstraints from .multigrain import calibrate_instrument_from_sx, generate_parameter_names from .powder import PowderCalibrator from .structureless import StructurelessCalibrator @@ -14,7 +13,6 @@ 'InstrumentCalibrator', 'LaueCalibrator', 'PowderCalibrator', - 'RelativeConstraints', 'StructurelessCalibrator', 'StructureLessCalibrator', ] diff --git a/hexrd/fitting/calibration/instrument.py b/hexrd/fitting/calibration/instrument.py index 8efec3ac1..06c521782 100644 --- a/hexrd/fitting/calibration/instrument.py +++ b/hexrd/fitting/calibration/instrument.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import lmfit import numpy as np @@ -9,7 +10,11 @@ DEFAULT_EULER_CONVENTION, update_instrument_from_params, validate_params_list, +) +from .relative_constraints import ( + create_relative_constraints, RelativeConstraints, + RelativeConstraintsType, ) logger = logging.getLogger() @@ -24,7 +29,7 @@ class InstrumentCalibrator: def __init__(self, *args, engineering_constraints=None, set_refinements_from_instrument_flags=True, euler_convention=DEFAULT_EULER_CONVENTION, - relative_constraints=RelativeConstraints.none): + relative_constraints_type=RelativeConstraintsType.none): """ Model for instrument calibration class as a function of @@ -47,7 +52,8 @@ def __init__(self, *args, engineering_constraints=None, assert calib.instr is self.instr, \ "all calibrators must refer to the same instrument" self._engineering_constraints = engineering_constraints - self._relative_constraints = relative_constraints + self._relative_constraints = create_relative_constraints( + relative_constraints_type, self.instr) self.euler_convention = euler_convention self.params = self.make_lmfit_params() @@ -164,18 +170,32 @@ def engineering_constraints(self, v): self._engineering_constraints = v self.params = self.make_lmfit_params() + @property + def relative_constraints_type(self): + return self._relative_constraints.type + + @relative_constraints_type.setter + def relative_constraints_type(self, v: Optional[RelativeConstraintsType]): + v = v if v is not None else RelativeConstraintsType.none + + current = getattr(self, '_relative_constraints', None) + if current is None or current.type != v: + self.relative_constraints = create_relative_constraints( + v, self.instr) + @property def relative_constraints(self) -> RelativeConstraints: return self._relative_constraints @relative_constraints.setter def relative_constraints(self, v: RelativeConstraints): - if v == self._relative_constraints: - return - self._relative_constraints = v self.params = self.make_lmfit_params() + def reset_relative_constraint_params(self): + # Set them back to zero. + self.relative_constraints.reset() + def run_calibration(self, odict): resd0 = self.residual() nrm_ssr_0 = _normalized_ssqr(resd0) diff --git a/hexrd/fitting/calibration/lmfit_param_handling.py b/hexrd/fitting/calibration/lmfit_param_handling.py index 5ea113c9d..1a667f522 100644 --- a/hexrd/fitting/calibration/lmfit_param_handling.py +++ b/hexrd/fitting/calibration/lmfit_param_handling.py @@ -1,4 +1,4 @@ -from enum import Enum +from typing import Optional import lmfit import numpy as np @@ -17,24 +17,18 @@ rotMatOfExpMap, ) from hexrd.material.unitcell import _lpname +from .relative_constraints import ( + RelativeConstraints, + RelativeConstraintsType, +) # First is the axes_order, second is extrinsic DEFAULT_EULER_CONVENTION = ('zxz', False) -class RelativeConstraints(Enum): - """These are relative constraints between the detectors""" - # 'none' means no relative constraints - none = 'None' - # 'group' means constrain tilts/translations within a group - group = 'Group' - # 'system' means constrain tilts/translations within the whole system - system = 'System' - - def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION, - relative_constraints=RelativeConstraints.none): + relative_constraints=None): # add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP) parms_list = [] @@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION, parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf)) parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf)) - if relative_constraints == RelativeConstraints.none: + if ( + relative_constraints is None or + relative_constraints.type == RelativeConstraintsType.none + ): add_unconstrained_detector_parameters( instr, euler_convention, parms_list, ) - elif relative_constraints == RelativeConstraints.group: + elif relative_constraints.type == RelativeConstraintsType.group: # This should be implemented soon - raise NotImplementedError(relative_constraints) - elif relative_constraints == RelativeConstraints.system: + raise NotImplementedError(relative_constraints.type) + elif relative_constraints.type == RelativeConstraintsType.system: add_system_constrained_detector_parameters( instr, euler_convention, parms_list, + relative_constraints, ) else: - raise NotImplementedError(relative_constraints) + raise NotImplementedError(relative_constraints.type) return parms_list @@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list): -np.inf, np.inf)) -def add_system_constrained_detector_parameters(instr, euler_convention, - parms_list): - mean_center = instr.mean_detector_center - mean_tilt = instr.mean_detector_tilt +def add_system_constrained_detector_parameters( + instr, euler_convention, + parms_list, relative_constraints: RelativeConstraints): + system_params = relative_constraints.params + system_tvec = system_params['translation'] + system_tilt = system_params['tilt'] + + if euler_convention is not None: + # Convert the tilt to the specified Euler convention + normalized = normalize_euler_convention(euler_convention) + rme = RotMatEuler( + np.zeros(3,), + axes_order=normalized[0], + extrinsic=normalized[1], + ) + + rme.rmat = _tilt_to_rmat(system_tilt, None) + system_tilt = np.degrees(rme.angles) tvec_names = [ 'system_tvec_x', @@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention, tilt_deltas = [2, 2, 2] for i, name in enumerate(tvec_names): - value = mean_center[i] + value = system_tvec[i] delta = tvec_deltas[i] parms_list.append((name, value, True, value - delta, value + delta)) for i, name in enumerate(tilt_names): - value = mean_tilt[i] + value = system_tilt[i] delta = tilt_deltas[i] parms_list.append((name, value, True, value - delta, value + delta)) @@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]: return param_names -def update_instrument_from_params(instr, params, euler_convention, - relative_constraints): +def update_instrument_from_params( + instr, params, + euler_convention=DEFAULT_EULER_CONVENTION, + relative_constraints: Optional[RelativeConstraints] = None): """ this function updates the instrument from the lmfit parameter list. we don't have to keep track @@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention, params['instr_tvec_z'].value] instr.tvec = np.r_[instr_tvec] - if relative_constraints == RelativeConstraints.none: + if ( + relative_constraints is None or + relative_constraints.type == RelativeConstraintsType.none + ): update_unconstrained_detector_parameters( instr, params, euler_convention, ) - elif relative_constraints == RelativeConstraints.group: + elif relative_constraints.type == RelativeConstraintsType.group: # This should be implemented soon - raise NotImplementedError(relative_constraints) - elif relative_constraints == RelativeConstraints.system: + raise NotImplementedError(relative_constraints.type) + elif relative_constraints.type == RelativeConstraintsType.system: update_system_constrained_detector_parameters( instr, params, euler_convention, + relative_constraints, ) else: - raise NotImplementedError(relative_constraints) + raise NotImplementedError(relative_constraints.type) def update_unconstrained_detector_parameters(instr, params, euler_convention): @@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention): ) -def update_system_constrained_detector_parameters(instr, params, euler_convention): - # We will always rotate/translate about the center of the group +def update_system_constrained_detector_parameters( + instr, params, euler_convention, + relative_constraints: RelativeConstraints): + # We will always rotate about the center of the detectors mean_center = instr.mean_detector_center - mean_tilt = instr.mean_detector_tilt + + system_params = relative_constraints.params + system_tvec = system_params['translation'] + system_tilt = system_params['tilt'] tvec_names = [ 'system_tvec_x', @@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio if any(params[x].vary for x in tilt_names): # Find the change in tilt, create an rmat, then apply to detector tilts # and translations. - new_mean_tilt = np.array([params[x].value for x in tilt_names]) + new_system_tilt = np.array([params[x].value for x in tilt_names]) - # The old mean tilt was in the None convention - old_rmat = _tilt_to_rmat(mean_tilt, None) - new_rmat = _tilt_to_rmat(new_mean_tilt, euler_convention) + # The old system tilt was in the None convention + old_rmat = _tilt_to_rmat(system_tilt, None) + new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention) # Compute the rmat used to convert from old to new rmat_diff = new_rmat @ old_rmat.T @@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio for panel in instr.detectors.values(): panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat) - # Also rotate the detectors about the center + # Also rotate the detectors about the mean center panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center + # Update the system tilt + system_tilt[:] = _rmat_to_tilt(new_rmat) + if any(params[x].vary for x in tvec_names): # Find the change in center and shift all tvecs - new_mean_center = np.array([params[x].value for x in tvec_names]) + new_system_tvec = np.array([params[x].value for x in tvec_names]) - diff = new_mean_center - mean_center + diff = new_system_tvec - system_tvec for panel in instr.detectors.values(): panel.tvec += diff + # Update the system tvec + system_tvec[:] = new_system_tvec + -def _tilt_to_rmat(tilt: np.ndarray, euler_convention: dict | tuple) -> np.ndarray: +def _tilt_to_rmat(tilt: np.ndarray, + euler_convention: dict | tuple) -> np.ndarray: # Convert the tilt to exponential map parameters, and then # to the rotation matrix, and return. if euler_convention is None: diff --git a/hexrd/fitting/calibration/relative_constraints.py b/hexrd/fitting/calibration/relative_constraints.py new file mode 100644 index 000000000..524d14d04 --- /dev/null +++ b/hexrd/fitting/calibration/relative_constraints.py @@ -0,0 +1,101 @@ +from abc import ABC, abstractmethod +from enum import Enum + +import numpy as np + +from hexrd.instrument import HEDMInstrument + + +class RelativeConstraintsType(Enum): + """These are relative constraints between the detectors""" + # 'none' means no relative constraints + none = 'None' + # 'group' means constrain tilts/translations within a group + group = 'Group' + # 'system' means constrain tilts/translations within the whole system + system = 'System' + + +class RelativeConstraints(ABC): + @property + @abstractmethod + def type(self) -> RelativeConstraintsType: + pass + + @property + @abstractmethod + def params(self) -> dict: + pass + + @abstractmethod + def reset(self): + # Reset the parameters + pass + + +class RelativeConstraintsNone(RelativeConstraints): + type = RelativeConstraintsType.none + + @property + def params(self) -> dict: + return {} + + def reset(self): + pass + + +class RelativeConstraintsGroup(RelativeConstraints): + type = RelativeConstraintsType.group + + def __init__(self, instr: HEDMInstrument): + self._groups = [] + for panel in instr.detectors.values(): + if panel.group is not None and panel.group not in self._groups: + self._groups.append(panel.group) + + self.reset() + + def reset(self): + self.group_params = {} + + for group in self._groups: + self.group_params[group] = { + 'tilt': np.array([0, 0, 0], dtype=float), + 'translation': np.array([0, 0, 0], dtype=float), + } + + @property + def params(self) -> dict: + return self.group_params + + +class RelativeConstraintsSystem(RelativeConstraints): + type = RelativeConstraintsType.system + + def __init__(self): + self.reset() + + @property + def params(self) -> dict: + return self._params + + def reset(self): + self._params = { + 'tilt': np.array([0, 0, 0], dtype=float), + 'translation': np.array([0, 0, 0], dtype=float), + } + + +def create_relative_constraints(type: RelativeConstraintsType, + instr: HEDMInstrument): + types = { + 'None': RelativeConstraintsNone, + 'Group': RelativeConstraintsGroup, + 'System': RelativeConstraintsSystem, + } + + kwargs = {} + if type == 'System': + kwargs['instr'] = instr + + return types[type.value](**kwargs) diff --git a/hexrd/fitting/calibration/structureless.py b/hexrd/fitting/calibration/structureless.py index c82837b3c..4bc743594 100644 --- a/hexrd/fitting/calibration/structureless.py +++ b/hexrd/fitting/calibration/structureless.py @@ -1,4 +1,6 @@ import copy +from typing import Optional + import lmfit import numpy as np @@ -9,10 +11,14 @@ create_instr_params, create_tth_parameters, DEFAULT_EULER_CONVENTION, - RelativeConstraints, tth_parameter_prefixes, update_instrument_from_params, ) +from .relative_constraints import ( + create_relative_constraints, + RelativeConstraints, + RelativeConstraintsType, +) class StructurelessCalibrator: @@ -39,14 +45,15 @@ def __init__(self, data, tth_distortion=None, engineering_constraints=None, - relative_constraints=RelativeConstraints.none, + relative_constraints_type=RelativeConstraintsType.none, euler_convention=DEFAULT_EULER_CONVENTION): self._instr = instr self._data = data self._tth_distortion = tth_distortion self._engineering_constraints = engineering_constraints - self._relative_constraints = relative_constraints + self._relative_constraints = create_relative_constraints( + relative_constraints_type, self.instr) self.euler_convention = euler_convention self._update_tth_distortion_panels() self.make_lmfit_params() @@ -163,16 +170,26 @@ def _update_tth_distortion_panels(self): obj.panel = self.instr.detectors[det_key] @property - def relative_constraints(self): + def relative_constraints_type(self): + return self._relative_constraints.type + + @relative_constraints_type.setter + def relative_constraints_type(self, v: Optional[RelativeConstraintsType]): + v = v if v is not None else RelativeConstraintsType.none + + current = getattr(self, '_relative_constraints', None) + if current is None or current.type != v: + self.relative_constraints = create_relative_constraints( + v, self.instr) + + @property + def relative_constraints(self) -> RelativeConstraints: return self._relative_constraints @relative_constraints.setter - def relative_constraints(self, v): - if v == self._relative_constraints: - return - + def relative_constraints(self, v: RelativeConstraints): self._relative_constraints = v - self.make_lmfit_params() + self.params = self.make_lmfit_params() @property def engineering_constraints(self): diff --git a/hexrd/instrument/hedm_instrument.py b/hexrd/instrument/hedm_instrument.py index 53cc64bc2..5d38e1a8d 100644 --- a/hexrd/instrument/hedm_instrument.py +++ b/hexrd/instrument/hedm_instrument.py @@ -813,28 +813,6 @@ def mean_group_centers(self) -> dict[str, np.ndarray]: return {k: v.sum(axis=0) / len(v) for k, v in centers.items()} - @property - def mean_detector_tilt(self) -> np.ndarray: - """Return the mean tilt for all detectors""" - tilts = np.array([panel.tilt for panel in self.detectors.values()]) - return tilts.sum(axis=0) / len(tilts) - - @property - def mean_group_tilts(self) -> dict[str, np.ndarray]: - """Return the mean tilt for every group of detectors""" - tilts = {} - for panel in self.detectors.values(): - if panel.group is None: - # Skip over panels without groups - continue - - if panel.group not in tilts: - tilts[panel.group] = [] - - tilts[panel.group].append(panel.tilt) - - return {k: v.sum(axis=0) / len(v) for k, v in tilts.items()} - # properties for physical size of rectangular detector @property def id(self): diff --git a/tests/test_2xrs_calibration.py b/tests/calibration/test_2xrs_calibration.py similarity index 100% rename from tests/test_2xrs_calibration.py rename to tests/calibration/test_2xrs_calibration.py diff --git a/tests/test_calibration.py b/tests/calibration/test_calibration.py similarity index 100% rename from tests/test_calibration.py rename to tests/calibration/test_calibration.py diff --git a/tests/calibration/test_relative_constraints.py b/tests/calibration/test_relative_constraints.py index 2cf6b71d9..6b331e82b 100644 --- a/tests/calibration/test_relative_constraints.py +++ b/tests/calibration/test_relative_constraints.py @@ -9,7 +9,9 @@ from hexrd.fitting.calibration import ( InstrumentCalibrator, PowderCalibrator, - RelativeConstraints, +) +from hexrd.fitting.calibration.relative_constraints import ( + RelativeConstraintsType, ) from hexrd.imageseries.process import ProcessedImageSeries from hexrd.instrument import HEDMInstrument @@ -200,10 +202,9 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center - orig_tilt = instr.mean_detector_tilt orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} orig_rmats = {k: v.rmat for k, v in instr.detectors.items()} @@ -228,8 +229,6 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: # The new center should not match assert not np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should match - assert np.allclose(orig_tilt, instr.mean_detector_tilt) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) @@ -253,7 +252,7 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} @@ -267,22 +266,24 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: for tilt_name in system_tilt_names: calibrator.params[tilt_name].vary = True + orig_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + # Run the calibration calibrator.run_calibration(calibration_options) # The new center should match assert np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should not match - assert not np.allclose(orig_tilt, instr.mean_detector_tilt) - - tilt_rmat_diff = ( - rotMatOfExpMap(instr.mean_detector_tilt) @ rotMatOfExpMap(orig_tilt).T - ) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) new_relative_rmats = compute_relative_rmats(instr) + new_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + + tilt_rmat_diff = ( + rotMatOfExpMap(new_system_tilt) @ rotMatOfExpMap(orig_system_tilt).T + ) + # absolute and relative tvecs should not match # absolute rmat should not match, but relative rmat should for key, panel in instr.detectors.items(): @@ -308,10 +309,9 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: instr = copy.deepcopy(orig_instr) calibrator = make_calibrator(instr) - calibrator.relative_constraints = RelativeConstraints.system + calibrator.relative_constraints_type = RelativeConstraintsType.system orig_center = instr.mean_detector_center - orig_tilt = instr.mean_detector_tilt orig_tvecs = {k: v.tvec for k, v in instr.detectors.items()} orig_rmats = {k: v.rmat for k, v in instr.detectors.items()} @@ -322,22 +322,24 @@ def compute_relative_rmats(instr: HEDMInstrument) -> dict: for tilt_name in system_tilt_names: calibrator.params[tilt_name].vary = True + orig_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + # Run the calibration calibrator.run_calibration(calibration_options) # The new center should be different assert not np.allclose(orig_center, instr.mean_detector_center) - # The new tilt should not match - assert not np.allclose(orig_tilt, instr.mean_detector_tilt) - - tilt_rmat_diff = ( - rotMatOfExpMap(instr.mean_detector_tilt) @ rotMatOfExpMap(orig_tilt).T - ) # Find new translations and rmats new_relative_translations = compute_relative_translations(instr) new_relative_rmats = compute_relative_rmats(instr) + new_system_tilt = calibrator.relative_constraints.params['tilt'].copy() + + tilt_rmat_diff = ( + rotMatOfExpMap(new_system_tilt) @ rotMatOfExpMap(orig_system_tilt).T + ) + # absolute and relative tvecs should not match # absolute rmat should not match, but relative rmat should for key, panel in instr.detectors.items():