Skip to content

Commit

Permalink
refactor usage of StrEnum (python >3.11) to Enum, since we want to ke…
Browse files Browse the repository at this point in the history
…ep backward compatibility
  • Loading branch information
CPrescher committed Oct 27, 2023
1 parent 53ba2f7 commit 450b88a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 32 deletions.
12 changes: 7 additions & 5 deletions glassure/core/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .utility import calculate_incoherent_scattering, calculate_f_squared_mean, calculate_f_mean_squared, \
convert_density_to_atoms_per_cubic_angstrom

from .methods import SqMethod, NormalizationMethod, FourierTransformMethod

__all__ = ['calculate_normalization_factor_raw', 'calculate_normalization_factor', 'fit_normalization_factor',
'calculate_sq', 'calculate_sq_raw', 'calculate_sq_from_fr', 'calculate_sq_from_gr',
'calculate_fr', 'calculate_gr_raw', 'calculate_gr']
Expand Down Expand Up @@ -139,10 +141,10 @@ def calculate_sq_raw(sample_pattern: Pattern, f_squared_mean: np.ndarray, f_mean
if incoherent_scattering is None:
incoherent_scattering = np.zeros_like(q)

if method == 'FZ':
if method == 'FZ' or method == SqMethod.FZ:
sq = (normalization_factor * intensity - incoherent_scattering - f_squared_mean + f_mean_squared) / \
f_mean_squared
elif method == 'AL':
elif method == 'AL' or method == SqMethod.AL:
sq = (normalization_factor * intensity - incoherent_scattering) / f_squared_mean
else:
raise NotImplementedError('{} method is not implemented'.format(method))
Expand Down Expand Up @@ -188,7 +190,7 @@ def calculate_sq(sample_pattern: Pattern, density: float, composition: dict[str,
incoherent_scattering = None

atomic_density = convert_density_to_atoms_per_cubic_angstrom(composition, density)
if normalization_method == 'fit':
if normalization_method == 'fit' or normalization_method == NormalizationMethod.FIT:
normalization_factor = fit_normalization_factor(sample_pattern,
composition,
use_incoherent_scattering,
Expand Down Expand Up @@ -238,9 +240,9 @@ def calculate_fr(sq_pattern: Pattern, r: Optional[np.ndarray] = None, use_modifi
else:
modification = 1

if method == 'integral':
if method == 'integral' or method == FourierTransformMethod.INTEGRAL:
fr = 2.0 / np.pi * np.trapz(modification * q * (sq - 1) * np.array(np.sin(np.outer(q.T, r))).T, q)
elif method == 'fft':
elif method == 'fft' or method == FourierTransformMethod.FFT:
q_step = q[1] - q[0]
r_step = r[1] - r[0]

Expand Down
25 changes: 25 additions & 0 deletions glassure/core/methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum


class SqMethod(Enum):
"""
Enum class for the different methods to calculate the structure factor.
"""
FZ = 'FZ'
AL = 'AL'


class NormalizationMethod(Enum):
"""
Enum class for the different methods to perform an intensity normalization.
"""
INTEGRAL = 'integral'
FIT = 'fit'


class FourierTransformMethod(Enum):
"""
Enum class for the different methods to perform a Fourier transform.
"""
FFT = 'fft'
INTEGRAL = 'integral'
27 changes: 2 additions & 25 deletions glassure/gui/model/configuration.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*/:
from __future__ import annotations
from enum import StrEnum
from enum import Enum
from colorsys import hsv_to_rgb
from copy import deepcopy

import numpy as np

from ...core.pattern import Pattern
from ...core.methods import SqMethod, NormalizationMethod, FourierTransformMethod


class TransformConfiguration(object):
Expand Down Expand Up @@ -273,30 +274,6 @@ def get_pattern_or_none(pattern_dict):
return config


class SqMethod(StrEnum):
"""
Enum class for the different methods to calculate the structure factor.
"""
FZ = 'FZ'
AL = 'AL'


class NormalizationMethod(StrEnum):
"""
Enum class for the different methods to perform an intensity normalization.
"""
INTEGRAL = 'integral'
FIT = 'fit'


class FourierTransformMethod(StrEnum):
"""
Enum class for the different methods to perform a Fourier transform.
"""
FFT = 'fft'
INTEGRAL = 'integral'


def calculate_color(ind):
s = 0.8
v = 0.8
Expand Down
5 changes: 3 additions & 2 deletions glassure/gui/widgets/control/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from qtpy import QtCore, QtWidgets
from ..custom import HorizontalLine, FloatLineEdit, DragSlider
import numpy as np
from ...model.configuration import NormalizationMethod, SqMethod, TransformConfiguration, FourierTransformMethod
from ...model.configuration import TransformConfiguration
from ....core.methods import SqMethod, NormalizationMethod, FourierTransformMethod


class OptionsWidget(QtWidgets.QWidget):
Expand Down Expand Up @@ -165,7 +166,7 @@ def get_fourier_transform_method(self):
return FourierTransformMethod.INTEGRAL

def set_fourier_transform_method(self, method):
if method == 'fft':
if method == 'fft' or method == FourierTransformMethod.FFT:
self.fft_cb.setChecked(True)
else:
self.fft_cb.setChecked(False)
Expand Down

0 comments on commit 450b88a

Please sign in to comment.