From 278e592361b17745a800555d950037760ba1fd47 Mon Sep 17 00:00:00 2001
From: Sarthak Kapoor
Date: Thu, 5 Sep 2024 17:59:28 +0200
Subject: [PATCH] Abstract out data interaction using setter and getter;
 allows to use same methods for classes with hdf5 refs

---
Review notes (this section is ignored by `git am`):

* schema.py, first generate_plots hunk: the guard originally added here was
  `if any([two_theta, intensity] == [None, None]):`. Comparing two lists
  yields a single bool, and any() on a bool raises TypeError; the check also
  changed meaning from "either value is missing" to "both are missing". It is
  replaced by `if two_theta is None or intensity is None:`, which matches the
  removed line and the equivalent guard in the RSM generate_plots hunk.
* utils.py: in this patch get_data is a plain getattr(obj, key, None) and
  set_data is a plain setattr; the HDF5Reference handling mentioned in their
  docstrings is not implemented here. set_data silently skips keys for which
  hasattr(obj, key) is False.
* NOTE(review): the import of write_nx_section_and_create_file is dropped from
  schema.py; confirm nothing in that module still uses it.
* NOTE(review): a few context lines still read self.two_theta,
  self.intensity, self.q_parallel and self.q_perpendicular directly instead of
  going through get_data; confirm whether that is intended.

 src/nomad_measurements/utils.py      | 32 +++++++++++++++
 src/nomad_measurements/xrd/schema.py | 58 ++++++++++++++++++----------
 2 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/src/nomad_measurements/utils.py b/src/nomad_measurements/utils.py
index ec60a06d..d4e6f401 100644
--- a/src/nomad_measurements/utils.py
+++ b/src/nomad_measurements/utils.py
@@ -18,6 +18,7 @@
 import os.path
 from typing import (
     TYPE_CHECKING,
+    Any,
 )
 
 import numpy as np
@@ -151,3 +152,34 @@ def get_bounding_range_2d(ax1, ax2):
     ]
 
     return ax1_range, ax2_range
+
+
+def get_data(obj, key: str) -> Any:
+    """
+    Get the data for the quantity. If the quantity is a HDF5Reference, read the dataset
+    and corresponding units if available, and return a pint.Quantity.
+
+    Args:
+        obj (Any): The object to get the data from.
+        key (str): The key of the quantity.
+
+    Returns:
+        Any: The data for the quantity.
+    """
+    return getattr(obj, key, None)
+
+
+def set_data(obj, **kwargs):
+    """
+    Set the data for the quantity. If the quantity is a HDF5Reference, the new value is
+    set in the HDF5 file at the corresponding path.
+
+    Args:
+        obj (Any): The object to set the data for.
+    """
+    if not kwargs:
+        raise ValueError('At least one keyword argument must be provided.')
+
+    for key, value in kwargs.items():
+        if hasattr(obj, key):
+            setattr(obj, key, value)
diff --git a/src/nomad_measurements/xrd/schema.py b/src/nomad_measurements/xrd/schema.py
index 0b280d25..744dbba0 100644
--- a/src/nomad_measurements/xrd/schema.py
+++ b/src/nomad_measurements/xrd/schema.py
@@ -67,8 +67,12 @@
 from nomad_measurements.general import (
     NOMADMeasurementsCategory,
 )
-from nomad_measurements.utils import get_bounding_range_2d, merge_sections
-from nomad_measurements.xrd.nx import write_nx_section_and_create_file
+from nomad_measurements.utils import (
+    get_bounding_range_2d,
+    get_data,
+    merge_sections,
+    set_data,
+)
 
 if TYPE_CHECKING:
     import pint
@@ -358,12 +362,13 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
             (dict, dict): line_linear, line_log
         """
         plots = []
-        if self.two_theta is None or self.intensity is None:
+        two_theta = get_data(self, 'two_theta')
+        intensity = get_data(self, 'intensity')
+        if two_theta is None or intensity is None:
             return plots
 
-        x = self.two_theta.to('degree').magnitude
-        y = self.intensity.magnitude
-
+        x = two_theta.to('degree').magnitude
+        y = intensity.magnitude
         fig_line_linear = px.line(
             x=x,
             y=y,
@@ -515,12 +520,15 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger'):
                 self.name = f'{self.scan_axis} Scan Result'
             else:
                 self.name = 'XRD Scan Result'
+        q_norm = get_data(self, 'q_norm')
+        two_theta = get_data(self, 'two_theta')
         if self.source_peak_wavelength is not None:
-            self.q_norm, self.two_theta = calculate_two_theta_or_q(
+            q_norm, two_theta = calculate_two_theta_or_q(
                 wavelength=self.source_peak_wavelength,
-                two_theta=self.two_theta,
-                q=self.q_norm,
+                two_theta=two_theta,
+                q=q_norm,
             )
+            set_data(self, q_norm=q_norm, two_theta=two_theta)
 
 
 class XRDResultRSM(XRDResult):
@@ -561,14 +569,17 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
             (dict, dict): json_2theta_omega, json_q_vector
         """
         plots = []
-        if self.two_theta is None or self.intensity is None or self.omega is None:
+        two_theta = get_data(self, 'two_theta')
+        intensity = get_data(self, 'intensity')
+        omega = get_data(self, 'omega')
+        if two_theta is None or intensity is None or omega is None:
             return plots
 
         # Plot for 2theta-omega RSM
         # Zero values in intensity become -inf in log scale and are not plotted
-        x = self.omega.to('degree').magnitude
-        y = self.two_theta.to('degree').magnitude
-        z = self.intensity.magnitude
+        x = omega.to('degree').magnitude
+        y = two_theta.to('degree').magnitude
+        z = intensity.magnitude
         log_z = np.log10(z)
         x_range, y_range = get_bounding_range_2d(x, y)
 
@@ -637,8 +648,8 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
 
         # Plot for RSM in Q-vectors
         if self.q_parallel is not None and self.q_perpendicular is not None:
-            x = self.q_parallel.to('1/angstrom').magnitude.flatten()
-            y = self.q_perpendicular.to('1/angstrom').magnitude.flatten()
+            x = get_data(self, 'q_parallel').to('1/angstrom').magnitude.flatten()
+            y = get_data(self, 'q_perpendicular').to('1/angstrom').magnitude.flatten()
             # q_vectors lead to irregular grid
             # generate a regular grid using interpolation
             x_regular = np.linspace(x.min(), x.max(), z.shape[0])
@@ -721,19 +732,26 @@ def generate_plots(self, archive: 'EntryArchive', logger: 'BoundLogger'):
 
     def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger'):
         super().normalize(archive, logger)
+
         if self.name is None:
             self.name = 'RSM Scan Result'
-        var_axis = 'omega'
+
         if self.source_peak_wavelength is not None:
             for var_axis in ['omega', 'chi', 'phi']:
+                var_axis_value = get_data(self, var_axis)
                 if (
-                    self[var_axis] is not None
-                    and len(np.unique(self[var_axis].magnitude)) > 1
+                    var_axis_value is not None
+                    and len(np.unique(var_axis_value.magnitude)) > 1
                 ):
-                    self.q_parallel, self.q_perpendicular = calculate_q_vectors_RSM(
+                    q_parallel, q_perpendicular = calculate_q_vectors_RSM(
                         wavelength=self.source_peak_wavelength,
                         two_theta=self.two_theta * np.ones_like(self.intensity),
-                        omega=self[var_axis],
+                        omega=var_axis_value,
+                    )
+                    set_data(
+                        self,
+                        q_parallel=q_parallel,
+                        q_perpendicular=q_perpendicular,
                     )
                     break
 