Skip to content

Commit

Permalink
Merge pull request #723 from HEXRD/relative-constraints
Browse files Browse the repository at this point in the history
Add Relative Constraints to Calibration
  • Loading branch information
psavery authored Oct 22, 2024
2 parents 45847c9 + ad7f741 commit 8c669ab
Show file tree
Hide file tree
Showing 8 changed files with 738 additions and 6 deletions.
39 changes: 38 additions & 1 deletion hexrd/fitting/calibration/instrument.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional

import lmfit
import numpy as np
Expand All @@ -10,6 +11,11 @@
update_instrument_from_params,
validate_params_list,
)
from .relative_constraints import (
create_relative_constraints,
RelativeConstraints,
RelativeConstraintsType,
)

logger = logging.getLogger()
logger.setLevel('INFO')
Expand All @@ -22,7 +28,8 @@ def _normalized_ssqr(resd):
class InstrumentCalibrator:
def __init__(self, *args, engineering_constraints=None,
set_refinements_from_instrument_flags=True,
euler_convention=DEFAULT_EULER_CONVENTION):
euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints_type=RelativeConstraintsType.none):
"""
Model for instrument calibration class as a function of
Expand All @@ -45,6 +52,8 @@ def __init__(self, *args, engineering_constraints=None,
assert calib.instr is self.instr, \
"all calibrators must refer to the same instrument"
self._engineering_constraints = engineering_constraints
self._relative_constraints = create_relative_constraints(
relative_constraints_type, self.instr)
self.euler_convention = euler_convention

self.params = self.make_lmfit_params()
Expand All @@ -59,6 +68,7 @@ def make_lmfit_params(self):
params = create_instr_params(
self.instr,
euler_convention=self.euler_convention,
relative_constraints=self.relative_constraints,
)

for calibrator in self.calibrators:
Expand All @@ -82,6 +92,7 @@ def update_all_from_params(self, params):
self.instr,
params,
self.euler_convention,
self.relative_constraints,
)

for calibrator in self.calibrators:
Expand Down Expand Up @@ -159,6 +170,32 @@ def engineering_constraints(self, v):
self._engineering_constraints = v
self.params = self.make_lmfit_params()

@property
def relative_constraints_type(self):
return self._relative_constraints.type

@relative_constraints_type.setter
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
v = v if v is not None else RelativeConstraintsType.none

current = getattr(self, '_relative_constraints', None)
if current is None or current.type != v:
self.relative_constraints = create_relative_constraints(
v, self.instr)

@property
def relative_constraints(self) -> RelativeConstraints:
return self._relative_constraints

@relative_constraints.setter
def relative_constraints(self, v: RelativeConstraints):
self._relative_constraints = v
self.params = self.make_lmfit_params()

def reset_relative_constraint_params(self):
# Set them back to zero.
self.relative_constraints.reset()

def run_calibration(self, odict):
resd0 = self.residual()
nrm_ssr_0 = _normalized_ssqr(resd0)
Expand Down
182 changes: 178 additions & 4 deletions hexrd/fitting/calibration/lmfit_param_handling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import lmfit
import numpy as np

Expand All @@ -7,19 +9,26 @@
HEDMInstrument,
)
from hexrd.rotations import (
angleAxisOfRotMat,
expMapOfQuat,
make_rmat_euler,
quatOfRotMat,
RotMatEuler,
rotMatOfExpMap,
)
from hexrd.material.unitcell import _lpname
from .relative_constraints import (
RelativeConstraints,
RelativeConstraintsType,
)


# First is the axes_order, second is extrinsic
DEFAULT_EULER_CONVENTION = ('zxz', False)


def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION):
def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints=None):
# add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP)
parms_list = []

Expand All @@ -46,6 +55,33 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION):
parms_list.append(('instr_tvec_x', instr.tvec[0], False, -np.inf, np.inf))
parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf))
parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf))

if (
relative_constraints is None or
relative_constraints.type == RelativeConstraintsType.none
):
add_unconstrained_detector_parameters(
instr,
euler_convention,
parms_list,
)
elif relative_constraints.type == RelativeConstraintsType.group:
# This should be implemented soon
raise NotImplementedError(relative_constraints.type)
elif relative_constraints.type == RelativeConstraintsType.system:
add_system_constrained_detector_parameters(
instr,
euler_convention,
parms_list,
relative_constraints,
)
else:
raise NotImplementedError(relative_constraints.type)

return parms_list


def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
for det_name, panel in instr.detectors.items():
det = det_name.replace('-', '_')

Expand Down Expand Up @@ -83,7 +119,45 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION):
parms_list.append((f'{det}_radius', panel.radius, False,
-np.inf, np.inf))

return parms_list

def add_system_constrained_detector_parameters(
instr, euler_convention,
parms_list, relative_constraints: RelativeConstraints):
system_params = relative_constraints.params
system_tvec = system_params['translation']
system_tilt = system_params['tilt']

if euler_convention is not None:
# Convert the tilt to the specified Euler convention
normalized = normalize_euler_convention(euler_convention)
rme = RotMatEuler(
np.zeros(3,),
axes_order=normalized[0],
extrinsic=normalized[1],
)

rme.rmat = _tilt_to_rmat(system_tilt, None)
system_tilt = np.degrees(rme.angles)

tvec_names = [
'system_tvec_x',
'system_tvec_y',
'system_tvec_z',
]
tvec_deltas = [1, 1, 1]

tilt_names = param_names_euler_convention('system', euler_convention)
tilt_deltas = [2, 2, 2]

for i, name in enumerate(tvec_names):
value = system_tvec[i]
delta = tvec_deltas[i]
parms_list.append((name, value, True, value - delta, value + delta))

for i, name in enumerate(tilt_names):
value = system_tilt[i]
delta = tilt_deltas[i]
parms_list.append((name, value, True, value - delta, value + delta))


def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
Expand All @@ -98,7 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
return param_names


def update_instrument_from_params(instr, params, euler_convention):
def update_instrument_from_params(
instr, params,
euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints: Optional[RelativeConstraints] = None):
"""
this function updates the instrument from the
lmfit parameter list. we don't have to keep track
Expand Down Expand Up @@ -133,6 +210,30 @@ def update_instrument_from_params(instr, params, euler_convention):
params['instr_tvec_z'].value]
instr.tvec = np.r_[instr_tvec]

if (
relative_constraints is None or
relative_constraints.type == RelativeConstraintsType.none
):
update_unconstrained_detector_parameters(
instr,
params,
euler_convention,
)
elif relative_constraints.type == RelativeConstraintsType.group:
# This should be implemented soon
raise NotImplementedError(relative_constraints.type)
elif relative_constraints.type == RelativeConstraintsType.system:
update_system_constrained_detector_parameters(
instr,
params,
euler_convention,
relative_constraints,
)
else:
raise NotImplementedError(relative_constraints.type)


def update_unconstrained_detector_parameters(instr, params, euler_convention):
for det_name, detector in instr.detectors.items():
det = det_name.replace('-', '_')
set_detector_angles_euler(detector, det, params, euler_convention)
Expand Down Expand Up @@ -162,13 +263,86 @@ def update_instrument_from_params(instr, params, euler_convention):
)


def update_system_constrained_detector_parameters(
instr, params, euler_convention,
relative_constraints: RelativeConstraints):
# We will always rotate about the center of the detectors
mean_center = instr.mean_detector_center

system_params = relative_constraints.params
system_tvec = system_params['translation']
system_tilt = system_params['tilt']

tvec_names = [
'system_tvec_x',
'system_tvec_y',
'system_tvec_z',
]
tilt_names = param_names_euler_convention('system', euler_convention)

# Just like the detectors, we will apply tilt first and then translation
# Only apply these transforms if they were marked "Vary".

if any(params[x].vary for x in tilt_names):
# Find the change in tilt, create an rmat, then apply to detector tilts
# and translations.
new_system_tilt = np.array([params[x].value for x in tilt_names])

# The old system tilt was in the None convention
old_rmat = _tilt_to_rmat(system_tilt, None)
new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention)

# Compute the rmat used to convert from old to new
rmat_diff = new_rmat @ old_rmat.T

# Rotate each detector using the rmat_diff
for panel in instr.detectors.values():
panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat)

# Also rotate the detectors about the mean center
panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center

# Update the system tilt
system_tilt[:] = _rmat_to_tilt(new_rmat)

if any(params[x].vary for x in tvec_names):
# Find the change in center and shift all tvecs
new_system_tvec = np.array([params[x].value for x in tvec_names])

diff = new_system_tvec - system_tvec
for panel in instr.detectors.values():
panel.tvec += diff

# Update the system tvec
system_tvec[:] = new_system_tvec


def _tilt_to_rmat(tilt: np.ndarray,
euler_convention: dict | tuple) -> np.ndarray:
# Convert the tilt to exponential map parameters, and then
# to the rotation matrix, and return.
if euler_convention is None:
return rotMatOfExpMap(tilt)

normalized = normalize_euler_convention(euler_convention)
return make_rmat_euler(
np.radians(tilt),
axes_order=normalized[0],
extrinsic=normalized[1],
)


def _rmat_to_tilt(rmat: np.ndarray) -> np.ndarray:
phi, n = angleAxisOfRotMat(rmat)
return phi * n.flatten()


def create_tth_parameters(
instr: HEDMInstrument,
meas_angles: dict[str, np.ndarray],
) -> list[lmfit.Parameter]:

prefixes = tth_parameter_prefixes(instr)

parms_list = []
for xray_source, angles in meas_angles.items():
prefix = prefixes[xray_source]
Expand Down
Loading

0 comments on commit 8c669ab

Please sign in to comment.