Skip to content

Commit

Permalink
Abstract out data interaction using setter and getter; allows to use …
Browse files Browse the repository at this point in the history
…same methods for classes with hdf5 refs
  • Loading branch information
ka-sarthak committed Sep 5, 2024
1 parent a9a8d00 commit 278e592
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 20 deletions.
32 changes: 32 additions & 0 deletions src/nomad_measurements/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os.path
from typing import (
TYPE_CHECKING,
Any,
)

import numpy as np
Expand Down Expand Up @@ -151,3 +152,34 @@ def get_bounding_range_2d(ax1, ax2):
]

return ax1_range, ax2_range


def get_data(obj, key: str) -> Any:
"""
Get the data for the quantity. If the quantity is a HDF5Reference, read the dataset
and corresponding units if available, and return a pint.Quantity.
Args:
obj (Any): The object to get the data from.
key (str): The key of the quantity.
Returns:
Any: The data for the quantity.
"""
return getattr(obj, key, None)


def set_data(obj, **kwargs):
"""
Set the data for the quantity. If the quantity is a HDF5Reference, the new value is
set in the HDF5 file at the corresponding path.
Args:
obj (Any): The object to set the data for.
"""
if not kwargs:
raise ValueError('At least one keyword argument must be provided.')

for key, value in kwargs.items():
if hasattr(obj, key):
setattr(obj, key, value)
58 changes: 38 additions & 20 deletions src/nomad_measurements/xrd/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@
from nomad_measurements.general import (
NOMADMeasurementsCategory,
)
from nomad_measurements.utils import get_bounding_range_2d, merge_sections
from nomad_measurements.xrd.nx import write_nx_section_and_create_file
from nomad_measurements.utils import (
get_bounding_range_2d,
get_data,
merge_sections,
set_data,
)

if TYPE_CHECKING:
import pint
Expand Down Expand Up @@ -358,12 +362,13 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
(dict, dict): line_linear, line_log
"""
plots = []
if self.two_theta is None or self.intensity is None:
two_theta = get_data(self, 'two_theta')
intensity = get_data(self, 'intensity')
if any([two_theta, intensity] == [None, None]):
return plots

x = self.two_theta.to('degree').magnitude
y = self.intensity.magnitude

x = two_theta.to('degree').magnitude
y = intensity.magnitude
fig_line_linear = px.line(
x=x,
y=y,
Expand Down Expand Up @@ -515,12 +520,15 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger'):
self.name = f'{self.scan_axis} Scan Result'
else:
self.name = 'XRD Scan Result'
q_norm = get_data(self, 'q_norm')
two_theta = get_data(self, 'two_theta')
if self.source_peak_wavelength is not None:
self.q_norm, self.two_theta = calculate_two_theta_or_q(
q_norm, two_theta = calculate_two_theta_or_q(
wavelength=self.source_peak_wavelength,
two_theta=self.two_theta,
q=self.q_norm,
two_theta=two_theta,
q=q_norm,
)
set_data(self, q_norm=q_norm, two_theta=two_theta)


class XRDResultRSM(XRDResult):
Expand Down Expand Up @@ -561,14 +569,17 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
(dict, dict): json_2theta_omega, json_q_vector
"""
plots = []
if self.two_theta is None or self.intensity is None or self.omega is None:
two_theta = get_data(self, 'two_theta')
intensity = get_data(self, 'intensity')
omega = get_data(self, 'omega')
if two_theta is None or intensity is None or omega is None:
return plots

# Plot for 2theta-omega RSM
# Zero values in intensity become -inf in log scale and are not plotted
x = self.omega.to('degree').magnitude
y = self.two_theta.to('degree').magnitude
z = self.intensity.magnitude
x = omega.to('degree').magnitude
y = two_theta.to('degree').magnitude
z = intensity.magnitude
log_z = np.log10(z)
x_range, y_range = get_bounding_range_2d(x, y)

Expand Down Expand Up @@ -637,8 +648,8 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):

# Plot for RSM in Q-vectors
if self.q_parallel is not None and self.q_perpendicular is not None:
x = self.q_parallel.to('1/angstrom').magnitude.flatten()
y = self.q_perpendicular.to('1/angstrom').magnitude.flatten()
x = get_data(self, 'q_parallel').to('1/angstrom').magnitude.flatten()
y = get_data(self, 'q_perpendicular').to('1/angstrom').magnitude.flatten()
# q_vectors lead to irregular grid
# generate a regular grid using interpolation
x_regular = np.linspace(x.min(), x.max(), z.shape[0])
Expand Down Expand Up @@ -721,19 +732,26 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger'):
super().normalize(archive, logger)

if self.name is None:
self.name = 'RSM Scan Result'
var_axis = 'omega'

if self.source_peak_wavelength is not None:
for var_axis in ['omega', 'chi', 'phi']:
var_axis_value = get_data(self, var_axis)
if (
self[var_axis] is not None
and len(np.unique(self[var_axis].magnitude)) > 1
var_axis_value is not None
and len(np.unique(var_axis_value.magnitude)) > 1
):
self.q_parallel, self.q_perpendicular = calculate_q_vectors_RSM(
q_parallel, q_perpendicular = calculate_q_vectors_RSM(
wavelength=self.source_peak_wavelength,
two_theta=self.two_theta * np.ones_like(self.intensity),
omega=self[var_axis],
omega=var_axis_value,
)
set_data(
self,
q_parallel=q_parallel,
q_perpendicular=q_perpendicular,
)
break

Expand Down

0 comments on commit 278e592

Please sign in to comment.