diff --git a/glassure/core/calc.py b/glassure/core/calc.py index c6113e7..843f582 100644 --- a/glassure/core/calc.py +++ b/glassure/core/calc.py @@ -7,6 +7,8 @@ from .utility import calculate_incoherent_scattering, calculate_f_squared_mean, calculate_f_mean_squared, \ convert_density_to_atoms_per_cubic_angstrom +from .methods import SqMethod, NormalizationMethod, FourierTransformMethod + __all__ = ['calculate_normalization_factor_raw', 'calculate_normalization_factor', 'fit_normalization_factor', 'calculate_sq', 'calculate_sq_raw', 'calculate_sq_from_fr', 'calculate_sq_from_gr', 'calculate_fr', 'calculate_gr_raw', 'calculate_gr'] @@ -139,10 +141,10 @@ def calculate_sq_raw(sample_pattern: Pattern, f_squared_mean: np.ndarray, f_mean if incoherent_scattering is None: incoherent_scattering = np.zeros_like(q) - if method == 'FZ': + if method == 'FZ' or method == SqMethod.FZ: sq = (normalization_factor * intensity - incoherent_scattering - f_squared_mean + f_mean_squared) / \ f_mean_squared - elif method == 'AL': + elif method == 'AL' or method == SqMethod.AL: sq = (normalization_factor * intensity - incoherent_scattering) / f_squared_mean else: raise NotImplementedError('{} method is not implemented'.format(method)) @@ -188,7 +190,7 @@ def calculate_sq(sample_pattern: Pattern, density: float, composition: dict[str, incoherent_scattering = None atomic_density = convert_density_to_atoms_per_cubic_angstrom(composition, density) - if normalization_method == 'fit': + if normalization_method == 'fit' or normalization_method == NormalizationMethod.FIT: normalization_factor = fit_normalization_factor(sample_pattern, composition, use_incoherent_scattering, @@ -238,9 +240,9 @@ def calculate_fr(sq_pattern: Pattern, r: Optional[np.ndarray] = None, use_modifi else: modification = 1 - if method == 'integral': + if method == 'integral' or method == FourierTransformMethod.INTEGRAL: fr = 2.0 / np.pi * np.trapz(modification * q * (sq - 1) * np.array(np.sin(np.outer(q.T, r))).T, q) - elif method == 'fft': + elif method == 'fft' or method == FourierTransformMethod.FFT: q_step = q[1] - q[0] r_step = r[1] - r[0] diff --git a/glassure/core/methods.py b/glassure/core/methods.py new file mode 100644 index 0000000..f2345d2 --- /dev/null +++ b/glassure/core/methods.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class SqMethod(Enum): + """ + Enum class for the different methods to calculate the structure factor. + """ + FZ = 'FZ' + AL = 'AL' + + +class NormalizationMethod(Enum): + """ + Enum class for the different methods to perform an intensity normalization. + """ + INTEGRAL = 'integral' + FIT = 'fit' + + +class FourierTransformMethod(Enum): + """ + Enum class for the different methods to perform a Fourier transform. + """ + FFT = 'fft' + INTEGRAL = 'integral' diff --git a/glassure/gui/model/configuration.py b/glassure/gui/model/configuration.py index 54ae062..686906d 100644 --- a/glassure/gui/model/configuration.py +++ b/glassure/gui/model/configuration.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*/: from __future__ import annotations -from enum import StrEnum +from enum import Enum from colorsys import hsv_to_rgb from copy import deepcopy import numpy as np from ...core.pattern import Pattern +from ...core.methods import SqMethod, NormalizationMethod, FourierTransformMethod class TransformConfiguration(object): @@ -273,30 +274,6 @@ def get_pattern_or_none(pattern_dict): return config -class SqMethod(StrEnum): - """ - Enum class for the different methods to calculate the structure factor. - """ - FZ = 'FZ' - AL = 'AL' - - -class NormalizationMethod(StrEnum): - """ - Enum class for the different methods to perform an intensity normalization. - """ - INTEGRAL = 'integral' - FIT = 'fit' - - -class FourierTransformMethod(StrEnum): - """ - Enum class for the different methods to perform a Fourier transform. - """ - FFT = 'fft' - INTEGRAL = 'integral' - - def calculate_color(ind): s = 0.8 v = 0.8 diff --git a/glassure/gui/widgets/control/options.py b/glassure/gui/widgets/control/options.py index e716002..7ba4d7e 100644 --- a/glassure/gui/widgets/control/options.py +++ b/glassure/gui/widgets/control/options.py @@ -3,7 +3,8 @@ from qtpy import QtCore, QtWidgets from ..custom import HorizontalLine, FloatLineEdit, DragSlider import numpy as np -from ...model.configuration import NormalizationMethod, SqMethod, TransformConfiguration, FourierTransformMethod +from ...model.configuration import TransformConfiguration +from ....core.methods import SqMethod, NormalizationMethod, FourierTransformMethod class OptionsWidget(QtWidgets.QWidget): @@ -165,7 +166,7 @@ def get_fourier_transform_method(self): return FourierTransformMethod.INTEGRAL def set_fourier_transform_method(self, method): - if method == 'fft': + if method == 'fft' or method == FourierTransformMethod.FFT: self.fft_cb.setChecked(True) else: self.fft_cb.setChecked(False)